mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
13 Commits
next-test-
...
maryhipp/d
Author | SHA1 | Date | |
---|---|---|---|
f0bfa7f0e0 | |||
c46b2b6fa6 | |||
058cc715d4 | |||
f69e3ee01c | |||
6e0665e3d7 | |||
5a35550144 | |||
8926a1a424 | |||
8566c1c7ff | |||
6eb4c1ccb6 | |||
ef474a3196 | |||
16b3718d6a | |||
30228ce2a4 | |||
efb5f2d202 |
@ -5,6 +5,7 @@ import pathlib
|
||||
import shutil
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
@ -14,6 +15,7 @@ from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
@ -32,6 +34,7 @@ from invokeai.backend.model_manager.config import (
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
@ -242,6 +245,47 @@ async def get_model_metadata(
|
||||
|
||||
return result
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}/metadata",
|
||||
operation_id="update_model_metadata",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The model metadata was updated successfully",
|
||||
"content": {"application/json": {"example": example_model_metadata}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def update_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
changes: ModelMetadataChanges = Body(description="The changes")
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Updates or creates a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
|
||||
|
||||
try:
|
||||
original_metadata = record_store.get_metadata(key)
|
||||
if original_metadata:
|
||||
if changes.trigger_phrases:
|
||||
original_metadata.trigger_phrases = changes.trigger_phrases
|
||||
|
||||
if changes.default_settings:
|
||||
original_metadata.default_settings = changes.default_settings
|
||||
|
||||
metadata_store.update_metadata(key, original_metadata)
|
||||
else:
|
||||
metadata_store.add_metadata(key, BaseMetadata(name="", author="",trigger_phrases=changes.trigger_phrases, default_settings=changes.default_settings))
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An error occurred while updating the model metadata: {e}",
|
||||
)
|
||||
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags",
|
||||
|
@ -3,8 +3,9 @@ from typing import Iterator, List, Optional, Tuple, Union
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
@ -13,9 +14,11 @@ from invokeai.app.invocations.fields import (
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
@ -23,6 +26,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
ExtraConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
from invokeai.backend.util.devices import torch_dtype
|
||||
|
||||
from .baseinvocation import (
|
||||
@ -66,11 +70,7 @@ class CompelInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
@ -82,10 +82,21 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
|
||||
ti_list = []
|
||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
loaded_model = context.models.load(key=name).model
|
||||
assert isinstance(loaded_model, TextualInversionModelRaw)
|
||||
ti_list.append((name, loaded_model))
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with (
|
||||
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
||||
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
@ -93,9 +104,8 @@ class CompelInvocation(BaseInvocation):
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
|
||||
):
|
||||
assert isinstance(text_encoder, CLIPTextModel)
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
@ -145,11 +155,7 @@ class SDXLPromptInvocationBase:
|
||||
zero_on_empty: bool,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
||||
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
@ -183,10 +189,25 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
|
||||
ti_list = []
|
||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_model = context.models.load_by_attrs(
|
||||
model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
|
||||
).model
|
||||
assert isinstance(ti_model, TextualInversionModelRaw)
|
||||
ti_list.append((name, ti_model))
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
logger.warning(f'trigger: "{trigger}" not found')
|
||||
except ValueError:
|
||||
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
||||
|
||||
with (
|
||||
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
||||
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
@ -194,9 +215,8 @@ class SDXLPromptInvocationBase:
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
|
||||
):
|
||||
assert isinstance(text_encoder, CLIPTextModel)
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
|
@ -228,16 +228,10 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
except (OSError, HTTPError) as excp:
|
||||
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
|
||||
job.error = traceback.format_exc()
|
||||
try:
|
||||
self._signal_job_error(job, excp)
|
||||
except:
|
||||
pass
|
||||
self._signal_job_error(job, excp)
|
||||
except DownloadJobCancelledException:
|
||||
try:
|
||||
self._signal_job_cancelled(job)
|
||||
self._cleanup_cancelled_job(job)
|
||||
except:
|
||||
pass
|
||||
self._signal_job_cancelled(job)
|
||||
self._cleanup_cancelled_job(job)
|
||||
|
||||
finally:
|
||||
job.job_ended = get_iso_timestamp()
|
||||
|
@ -4,9 +4,27 @@ Storage for Model Metadata
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings
|
||||
|
||||
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
|
||||
"""A set of changes to apply to model metadata.
|
||||
|
||||
Only limited changes are valid:
|
||||
- `trigger_phrases`: the list of trigger phrases for this model
|
||||
- `default_settings`: the user-configured default settings for this model
|
||||
"""
|
||||
|
||||
trigger_phrases: Optional[List[str]] = Field(default=None, description="The model's list of trigger phrases")
|
||||
"""The model's list of trigger phrases"""
|
||||
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(default=None, description="The user-configured default settings for this model")
|
||||
"""The user-configured default settings for this model"""
|
||||
|
||||
|
||||
class ModelMetadataStoreBase(ABC):
|
||||
|
@ -115,6 +115,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
@ -179,44 +181,45 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
if tags:
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
||||
|
@ -1,47 +1,8 @@
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
|
||||
def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:
|
||||
ti_triggers: List[str] = []
|
||||
def extract_ti_triggers_from_prompt(prompt: str) -> list[str]:
|
||||
ti_triggers = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||
ti_triggers.append(str(trigger))
|
||||
ti_triggers.append(trigger)
|
||||
return ti_triggers
|
||||
|
||||
|
||||
def generate_ti_list(
|
||||
prompt: str, base: BaseModelType, context: InvocationContext
|
||||
) -> List[Tuple[str, TextualInversionModelRaw]]:
|
||||
ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
|
||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||
name_or_key = trigger[1:-1]
|
||||
try:
|
||||
loaded_model = context.models.load(key=name_or_key)
|
||||
model = loaded_model.model
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
assert loaded_model.config.base == base
|
||||
ti_list.append((name_or_key, model))
|
||||
except UnknownModelException:
|
||||
try:
|
||||
loaded_model = context.models.load_by_attrs(
|
||||
model_name=name_or_key, base_model=base, model_type=ModelType.TextualInversion
|
||||
)
|
||||
model = loaded_model.model
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
assert loaded_model.config.base == base
|
||||
ti_list.append((name_or_key, model))
|
||||
except UnknownModelException:
|
||||
pass
|
||||
except ValueError:
|
||||
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
||||
except AssertionError:
|
||||
logger.warning(f'trigger: "{trigger}" not a valid textual inversion model for this graph')
|
||||
except Exception:
|
||||
logger.warning(f'Failed to load TI model for trigger: "{trigger}"')
|
||||
return ti_list
|
||||
|
@ -160,10 +160,11 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
nsfw=model_json["nsfw"],
|
||||
restrictions=LicenseRestrictions(
|
||||
AllowNoCredit=model_json["allowNoCredit"],
|
||||
AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
|
||||
AllowCommercialUse=CommercialUsage(model_json["allowCommercialUse"]),
|
||||
AllowDerivatives=model_json["allowDerivatives"],
|
||||
AllowDifferentLicense=model_json["allowDifferentLicense"],
|
||||
),
|
||||
trigger_phrases=version_json["trainedWords"],
|
||||
)
|
||||
|
||||
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
|
@ -24,6 +24,7 @@ from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from typing_extensions import Annotated
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
|
||||
@ -54,8 +55,8 @@ class LicenseRestrictions(BaseModel):
|
||||
AllowDifferentLicense: bool = Field(
|
||||
description="if true, derivatives of this model be redistributed under a different license", default=False
|
||||
)
|
||||
AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
|
||||
description="Type of commercial use allowed if no commercial use is allowed.", default=None
|
||||
AllowCommercialUse: Optional[CommercialUsage] = Field(
|
||||
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default=None
|
||||
)
|
||||
|
||||
|
||||
@ -68,12 +69,22 @@ class RemoteModelFile(BaseModel):
|
||||
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
|
||||
|
||||
|
||||
class ModelDefaultSettings(BaseModel):
|
||||
vae: str | None
|
||||
vae_precision: str | None
|
||||
scheduler: SCHEDULER_NAME_VALUES | None
|
||||
steps: int | None
|
||||
cfg_scale: float | None
|
||||
cfg_rescale_multiplier: float | None
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
||||
name: str = Field(description="model's name")
|
||||
author: str = Field(description="model's author")
|
||||
tags: Set[str] = Field(description="tags provided by model source")
|
||||
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
|
||||
trigger_phrases: Optional[List[str]] = Field(description="trigger phrases for this model", default=None)
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(description="default settings for this model", default=None)
|
||||
|
||||
|
||||
class BaseMetadata(ModelMetadataBase):
|
||||
@ -142,10 +153,7 @@ class CivitaiMetadata(ModelMetadataWithFiles):
|
||||
if self.restrictions.AllowCommercialUse is None:
|
||||
return False
|
||||
else:
|
||||
# accommodate schema change
|
||||
acu = self.restrictions.AllowCommercialUse
|
||||
commercial_usage = acu if isinstance(acu, set) else {acu}
|
||||
return CommercialUsage.No not in commercial_usage
|
||||
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
|
||||
|
||||
@property
|
||||
def allow_derivatives(self) -> bool:
|
||||
|
@ -78,6 +78,7 @@
|
||||
"aboutDesc": "Using Invoke for work? Check out:",
|
||||
"aboutHeading": "Own Your Creative Power",
|
||||
"accept": "Accept",
|
||||
"add": "Add",
|
||||
"advanced": "Advanced",
|
||||
"advancedOptions": "Advanced Options",
|
||||
"ai": "ai",
|
||||
@ -303,6 +304,12 @@
|
||||
"method": "High Resolution Fix Method"
|
||||
}
|
||||
},
|
||||
"prompt": {
|
||||
"addPromptTrigger": "Add Prompt Trigger",
|
||||
"compatibleEmbeddings": "Compatible Embeddings",
|
||||
"noPromptTriggers": "No triggers available",
|
||||
"noMatchingTriggers": "No matching triggers"
|
||||
},
|
||||
"embedding": {
|
||||
"addEmbedding": "Add Embedding",
|
||||
"incompatibleModel": "Incompatible base model:",
|
||||
@ -734,6 +741,8 @@
|
||||
"customConfig": "Custom Config",
|
||||
"customConfigFileLocation": "Custom Config File Location",
|
||||
"customSaveLocation": "Custom Save Location",
|
||||
"defaultSettings": "Default Settings",
|
||||
"defaultSettingsSaved": "Default Settings Saved",
|
||||
"delete": "Delete",
|
||||
"deleteConfig": "Delete Config",
|
||||
"deleteModel": "Delete Model",
|
||||
@ -768,6 +777,7 @@
|
||||
"mergedModelName": "Merged Model Name",
|
||||
"mergedModelSaveLocation": "Save Location",
|
||||
"mergeModels": "Merge Models",
|
||||
"metadata": "Metadata",
|
||||
"model": "Model",
|
||||
"modelAdded": "Model Added",
|
||||
"modelConversionFailed": "Model Conversion Failed",
|
||||
@ -839,9 +849,12 @@
|
||||
"statusConverting": "Converting",
|
||||
"syncModels": "Sync Models",
|
||||
"syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
|
||||
"triggerPhrases": "Trigger Phrases",
|
||||
"typePhraseHere": "Type phrase here",
|
||||
"upcastAttention": "Upcast Attention",
|
||||
"updateModel": "Update Model",
|
||||
"useCustomConfig": "Use Custom Config",
|
||||
"useDefaultSettings": "Use Default Settings",
|
||||
"v1": "v1",
|
||||
"v2_768": "v2 (768px)",
|
||||
"v2_base": "v2 (512px)",
|
||||
|
@ -55,6 +55,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
@ -153,3 +155,5 @@ addUpscaleRequestedListener(startAppListening);
|
||||
|
||||
// Dynamic prompts
|
||||
addDynamicPromptsListener(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening)
|
||||
|
@ -0,0 +1,88 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||
import { setCfgRescaleMultiplier, setCfgScale, setScheduler, setSteps, vaePrecisionChanged, vaeSelected } from 'features/parameters/store/generationSlice';
|
||||
import { isParameterCFGRescaleMultiplier, isParameterCFGScale, isParameterPrecision, isParameterScheduler, isParameterSteps, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { map } from 'lodash-es';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
|
||||
|
||||
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: setDefaultSettings,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
|
||||
const currentModel = state.generation.model;
|
||||
|
||||
if (!currentModel) {
|
||||
return
|
||||
}
|
||||
|
||||
const metadata = await dispatch(
|
||||
modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)
|
||||
).unwrap();
|
||||
|
||||
console.log({ metadata })
|
||||
|
||||
|
||||
if (!metadata || !metadata.default_settings) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings
|
||||
|
||||
if (vae) {
|
||||
// we store this as "default" within default settings
|
||||
// to distinguish it from no default set
|
||||
if (vae === "default") {
|
||||
dispatch(vaeSelected(null))
|
||||
} else {
|
||||
const { data } = modelsApi.endpoints.getVaeModels.select()(state)
|
||||
const vaeArray = map(data?.entities)
|
||||
const validVae = vaeArray.find(model => model.key === vae)
|
||||
|
||||
const result = zParameterVAEModel.safeParse(validVae);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
dispatch(vaeSelected(result.data));
|
||||
}
|
||||
}
|
||||
|
||||
if (vae_precision) {
|
||||
if (isParameterPrecision(vae_precision)) {
|
||||
dispatch(vaePrecisionChanged(vae_precision));
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg_scale) {
|
||||
if (isParameterCFGScale(cfg_scale)) {
|
||||
dispatch(setCfgScale(cfg_scale));
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg_rescale_multiplier) {
|
||||
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
|
||||
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
||||
}
|
||||
}
|
||||
|
||||
if (steps) {
|
||||
if (isParameterSteps(steps)) {
|
||||
dispatch(setSteps(steps));
|
||||
}
|
||||
}
|
||||
|
||||
if (scheduler) {
|
||||
if (isParameterScheduler(scheduler)) {
|
||||
dispatch(setScheduler(scheduler));
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: "Default settings" }) })))
|
||||
},
|
||||
});
|
||||
};
|
@ -1,4 +1,5 @@
|
||||
import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import type { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
|
||||
@ -82,6 +83,8 @@ export type AppConfig = {
|
||||
guidance: NumericalParameterConfig;
|
||||
cfgRescaleMultiplier: NumericalParameterConfig;
|
||||
img2imgStrength: NumericalParameterConfig;
|
||||
scheduler?: ParameterScheduler,
|
||||
vaePrecision?: ParameterPrecision
|
||||
// Canvas
|
||||
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
|
||||
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model
|
||||
|
@ -1,21 +0,0 @@
|
||||
import type { Meta, StoryObj } from '@storybook/react';
|
||||
|
||||
import { EmbeddingSelect } from './EmbeddingSelect';
|
||||
import type { EmbeddingSelectProps } from './types';
|
||||
|
||||
const meta: Meta<typeof EmbeddingSelect> = {
|
||||
title: 'Feature/Prompt/EmbeddingSelect',
|
||||
tags: ['autodocs'],
|
||||
component: EmbeddingSelect,
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof EmbeddingSelect>;
|
||||
|
||||
const Component = (props: EmbeddingSelectProps) => {
|
||||
return <EmbeddingSelect {...props}>Invoke</EmbeddingSelect>;
|
||||
};
|
||||
|
||||
export const Default: Story = {
|
||||
render: Component,
|
||||
};
|
@ -1,67 +0,0 @@
|
||||
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import type { EmbeddingSelectProps } from 'features/embedding/types';
|
||||
import { t } from 'i18next';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { TextualInversionModelConfig } from 'services/api/types';
|
||||
|
||||
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
|
||||
|
||||
export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(embedding: TextualInversionModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === embedding.base;
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
return !hasMainModel || !isCompatible;
|
||||
},
|
||||
[currentBaseModel]
|
||||
);
|
||||
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(embedding: TextualInversionModelConfig | null) => {
|
||||
if (!embedding) {
|
||||
return;
|
||||
}
|
||||
onSelect(embedding.name);
|
||||
},
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
const { options, onChange } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
getIsDisabled,
|
||||
onChange: _onChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<Combobox
|
||||
placeholder={isLoading ? t('common.loading') : t('embedding.addEmbedding')}
|
||||
defaultMenuIsOpen
|
||||
autoFocus
|
||||
value={null}
|
||||
options={options}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
onChange={onChange}
|
||||
onMenuClose={onClose}
|
||||
data-testid="add-embedding"
|
||||
sx={selectStyles}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
EmbeddingSelect.displayName = 'EmbeddingSelect';
|
||||
|
||||
const selectStyles: ChakraProps['sx'] = {
|
||||
w: 'full',
|
||||
};
|
@ -8,7 +8,7 @@ export const ModelPane = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
return (
|
||||
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
|
||||
{selectedModelKey ? <Model /> : <ImportModels />}
|
||||
{selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
@ -0,0 +1,66 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
|
||||
|
||||
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
|
||||
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd;
|
||||
|
||||
return {
|
||||
initialSteps: steps.initial,
|
||||
initialCfg: guidance.initial,
|
||||
initialScheduler: scheduler,
|
||||
initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
|
||||
initialVaePrecision: vaePrecision,
|
||||
};
|
||||
});
|
||||
|
||||
export const DefaultSettings = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
|
||||
const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
|
||||
useAppSelector(initialStatesSelector);
|
||||
|
||||
const defaultSettingsDefaults = useMemo(() => {
|
||||
return {
|
||||
vae: { isEnabled: !isNil(data?.default_settings?.vae), value: data?.default_settings?.vae || 'default' },
|
||||
vaePrecision: {
|
||||
isEnabled: !isNil(data?.default_settings?.vae_precision),
|
||||
value: data?.default_settings?.vae_precision || initialVaePrecision || 'fp32',
|
||||
},
|
||||
scheduler: {
|
||||
isEnabled: !isNil(data?.default_settings?.scheduler),
|
||||
value: data?.default_settings?.scheduler || initialScheduler || 'euler',
|
||||
},
|
||||
steps: { isEnabled: !isNil(data?.default_settings?.steps), value: data?.default_settings?.steps || initialSteps },
|
||||
cfgScale: {
|
||||
isEnabled: !isNil(data?.default_settings?.cfg_scale),
|
||||
value: data?.default_settings?.cfg_scale || initialCfg,
|
||||
},
|
||||
cfgRescaleMultiplier: {
|
||||
isEnabled: !isNil(data?.default_settings?.cfg_rescale_multiplier),
|
||||
value: data?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
|
||||
},
|
||||
};
|
||||
}, [
|
||||
data?.default_settings,
|
||||
initialSteps,
|
||||
initialCfg,
|
||||
initialScheduler,
|
||||
initialCfgRescaleMultiplier,
|
||||
initialVaePrecision,
|
||||
]);
|
||||
|
||||
if (isLoading) {
|
||||
return <Loading />;
|
||||
}
|
||||
|
||||
return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />;
|
||||
};
|
@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex,FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback,useMemo } from 'react';
|
||||
import type {UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultCfgRescaleMultiplierType = DefaultSettingsFormData['cfgRescaleMultiplier'];
|
||||
|
||||
export function DefaultCfgRescaleMultiplier(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultCfgRescaleMultiplierType),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultCfgRescaleMultiplierType).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultCfgRescaleMultiplierType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramCFGRescaleMultiplier">
|
||||
<FormLabel>{t('parameters.cfgRescaleMultiplier')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex,FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback,useMemo } from 'react';
|
||||
import type {UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultCfgType = DefaultSettingsFormData['cfgScale'];
|
||||
|
||||
export function DefaultCfgScale(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.guidance.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.guidance.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.guidance.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.guidance.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.guidance.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultCfgType),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultCfgType).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultCfgType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramCFGScale">
|
||||
<FormLabel>{t('parameters.cfgScale')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
|
||||
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type {UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultSchedulerType = DefaultSettingsFormData['scheduler'];
|
||||
|
||||
export function DefaultScheduler(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isParameterScheduler(v?.value)) {
|
||||
return;
|
||||
}
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultSchedulerType),
|
||||
value: v.value,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(
|
||||
() => SCHEDULER_OPTIONS.find((o) => o.value === (field.value as DefaultSchedulerType).value),
|
||||
[field]
|
||||
);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultSchedulerType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramScheduler">
|
||||
<FormLabel>{t('parameters.scheduler')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={SCHEDULER_OPTIONS} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,147 @@
|
||||
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoPencil } from 'react-icons/io5';
|
||||
import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
||||
import { DefaultCfgScale } from './DefaultCfgScale';
|
||||
import { DefaultScheduler } from './DefaultScheduler';
|
||||
import { DefaultSteps } from './DefaultSteps';
|
||||
import { DefaultVae } from './DefaultVae';
|
||||
import { DefaultVaePrecision } from './DefaultVaePrecision';
|
||||
import { SettingToggle } from './SettingToggle';
|
||||
|
||||
export interface FormField<T> {
|
||||
value: T;
|
||||
isEnabled: boolean;
|
||||
}
|
||||
|
||||
export type DefaultSettingsFormData = {
|
||||
vae: FormField<string>;
|
||||
vaePrecision: FormField<string>;
|
||||
scheduler: FormField<ParameterScheduler>;
|
||||
steps: FormField<number>;
|
||||
cfgScale: FormField<number>;
|
||||
cfgRescaleMultiplier: FormField<number>;
|
||||
};
|
||||
|
||||
export const DefaultSettingsForm = ({
|
||||
defaultSettingsDefaults,
|
||||
}: {
|
||||
defaultSettingsDefaults: DefaultSettingsFormData;
|
||||
}) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
|
||||
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
|
||||
|
||||
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
|
||||
defaultValues: defaultSettingsDefaults,
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<DefaultSettingsFormData>>(
|
||||
(data) => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const body = {
|
||||
vae: data.vae.isEnabled ? data.vae.value : null,
|
||||
vae_precision: data.vaePrecision.isEnabled ? data.vaePrecision.value : null,
|
||||
cfg_scale: data.cfgScale.isEnabled ? data.cfgScale.value : null,
|
||||
cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null,
|
||||
steps: data.steps.isEnabled ? data.steps.value : null,
|
||||
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
|
||||
};
|
||||
|
||||
editModelMetadata({
|
||||
key: selectedModelKey,
|
||||
body: { default_settings: body },
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.defaultSettingsSaved'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${error.data.detail} `,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
},
|
||||
[selectedModelKey, dispatch, editModelMetadata, t]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex gap="2" justifyContent="space-between" w="full" mb={5}>
|
||||
<Heading fontSize="md">{t('modelManager.defaultSettings')}</Heading>
|
||||
<Button
|
||||
size="sm"
|
||||
leftIcon={<IoPencil />}
|
||||
colorScheme="invokeYellow"
|
||||
isDisabled={!formState.isDirty}
|
||||
onClick={handleSubmit(onSubmit)}
|
||||
type="submit"
|
||||
isLoading={isLoading}
|
||||
>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={8}>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="vae" />
|
||||
<DefaultVae control={control} name="vae" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="vaePrecision" />
|
||||
<DefaultVaePrecision control={control} name="vaePrecision" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="scheduler" />
|
||||
<DefaultScheduler control={control} name="scheduler" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="steps" />
|
||||
<DefaultSteps control={control} name="steps" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="cfgScale" />
|
||||
<DefaultCfgScale control={control} name="cfgScale" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="cfgRescaleMultiplier" />
|
||||
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
};
|
@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex,FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback,useMemo } from 'react';
|
||||
import type {UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultSteps = DefaultSettingsFormData['steps'];
|
||||
|
||||
export function DefaultSteps(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.steps.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.steps.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.steps.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.steps.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.steps.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultSteps),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultSteps).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultSteps).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramSteps">
|
||||
<FormLabel>{t('parameters.steps')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,65 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { map } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type {UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultVaeType = DefaultSettingsFormData['vae'];
|
||||
|
||||
export function DefaultVae(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const { compatibleOptions } = useGetVaeModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => {
|
||||
const modelArray = map(data?.entities);
|
||||
const compatibleOptions = modelArray
|
||||
.filter((vae) => vae.base === modelData?.base)
|
||||
.map((vae) => ({ label: vae.name, value: vae.key }));
|
||||
|
||||
const defaultOption = { label: 'Default VAE', value: 'default' };
|
||||
|
||||
return { compatibleOptions: [defaultOption, ...compatibleOptions] };
|
||||
},
|
||||
});
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
const newValue = !v?.value ? 'default' : v.value;
|
||||
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultVaeType),
|
||||
value: newValue,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return compatibleOptions.find((vae) => vae.value === (field.value as DefaultVaeType).value);
|
||||
}, [compatibleOptions, field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultVaeType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramVAE">
|
||||
<FormLabel>{t('modelManager.vae')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={compatibleOptions} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { isParameterPrecision } from 'features/parameters/types/parameterSchemas';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type {UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
const options = [
|
||||
{ label: 'FP16', value: 'fp16' },
|
||||
{ label: 'FP32', value: 'fp32' },
|
||||
];
|
||||
|
||||
type DefaultVaePrecisionType = DefaultSettingsFormData['vaePrecision'];
|
||||
|
||||
export function DefaultVaePrecision(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isParameterPrecision(v?.value)) {
|
||||
return;
|
||||
}
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultVaePrecisionType),
|
||||
value: v.value,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === (field.value as DefaultVaePrecisionType).value), [field]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultVaePrecisionType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramVAEPrecision">
|
||||
<FormLabel>{t('modelManager.vaePrecision')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
import { Switch } from '@invoke-ai/ui-library';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback , useMemo } from 'react';
|
||||
import type { UseControllerProps} from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
|
||||
import type { DefaultSettingsFormData, FormField } from './DefaultSettingsForm';
|
||||
|
||||
interface Props<T> extends UseControllerProps<DefaultSettingsFormData> {
|
||||
name: keyof DefaultSettingsFormData;
|
||||
}
|
||||
|
||||
export function SettingToggle<T>(props: Props<T>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return !!(field.value as FormField<T>).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
const updatedValue: FormField<T> = {
|
||||
...(field.value as FormField<T>),
|
||||
isEnabled: e.target.checked,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
return <Switch isChecked={value} onChange={onChange} />;
|
||||
}
|
@ -0,0 +1,21 @@
|
||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { TriggerPhrases } from './TriggerPhrases';
|
||||
|
||||
export const ModelMetadata = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" height="full" gap="3">
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<TriggerPhrases />
|
||||
</Box>
|
||||
<DataViewer label="metadata" data={metadata || {}} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -0,0 +1,106 @@
|
||||
import {
|
||||
Button,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
Input,
|
||||
Tag,
|
||||
TagCloseButton,
|
||||
TagLabel,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ModelListHeader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelListHeader';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelMetadataQuery, useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
|
||||
|
||||
export const TriggerPhrases = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
const [phrase, setPhrase] = useState('');
|
||||
|
||||
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
|
||||
|
||||
const handlePhraseChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setPhrase(e.target.value);
|
||||
}, []);
|
||||
|
||||
const triggerPhrases = useMemo(() => {
|
||||
return metadata?.trigger_phrases || [];
|
||||
}, [metadata?.trigger_phrases]);
|
||||
|
||||
const errors = useMemo(() => {
|
||||
const errors = [];
|
||||
|
||||
if (phrase.length && triggerPhrases.includes(phrase)) {
|
||||
errors.push('Phrase is already in list');
|
||||
}
|
||||
|
||||
return errors;
|
||||
}, [phrase, triggerPhrases]);
|
||||
|
||||
const addTriggerPhrase = useCallback(async () => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!phrase.length || triggerPhrases.includes(phrase)) {
|
||||
return;
|
||||
}
|
||||
|
||||
await editModelMetadata({
|
||||
key: selectedModelKey,
|
||||
body: { trigger_phrases: [...triggerPhrases, phrase] },
|
||||
}).unwrap();
|
||||
setPhrase('');
|
||||
}, [editModelMetadata, selectedModelKey, phrase, triggerPhrases]);
|
||||
|
||||
const removeTriggerPhrase = useCallback(
|
||||
async (phraseToRemove: string) => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
|
||||
|
||||
await editModelMetadata({ key: selectedModelKey, body: { trigger_phrases: filteredPhrases } }).unwrap();
|
||||
},
|
||||
[editModelMetadata, selectedModelKey, triggerPhrases]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" w="full" gap="5">
|
||||
<ModelListHeader title={t('modelManager.triggerPhrases')} />
|
||||
<form>
|
||||
<FormControl w="full" isInvalid={Boolean(errors.length)}>
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex gap="3" alignItems="center" w="full">
|
||||
<Input value={phrase} onChange={handlePhraseChange} placeholder={t('modelManager.typePhraseHere')} />
|
||||
<Button
|
||||
type="submit"
|
||||
onClick={addTriggerPhrase}
|
||||
isDisabled={Boolean(errors.length)}
|
||||
isLoading={isLoading}
|
||||
>
|
||||
{t('common.add')}
|
||||
</Button>
|
||||
</Flex>
|
||||
{!!errors.length && errors.map((error) => <FormErrorMessage key={error}>{error}</FormErrorMessage>)}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
</form>
|
||||
|
||||
<Flex gap="4" flexWrap="wrap" mt="3" mb="3">
|
||||
{triggerPhrases.map((phrase, index) => (
|
||||
<Tag size="md" key={index}>
|
||||
<TagLabel>{phrase}</TagLabel>
|
||||
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
|
||||
</Tag>
|
||||
))}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -1,9 +1,58 @@
|
||||
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { ModelMetadata } from './Metadata/ModelMetadata';
|
||||
import { ModelAttrView } from './ModelAttrView';
|
||||
import { ModelEdit } from './ModelEdit';
|
||||
import { ModelView } from './ModelView';
|
||||
|
||||
export const Model = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
|
||||
return selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />;
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
if (isLoading) {
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
if (!data) {
|
||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex flexDir="column" gap={1} p={2}>
|
||||
<Heading as="h2" fontSize="lg">
|
||||
{data.name}
|
||||
</Heading>
|
||||
|
||||
{data.source && (
|
||||
<Text variant="subtext">
|
||||
{t('modelManager.source')}: {data?.source}
|
||||
</Text>
|
||||
)}
|
||||
<Box mt="4">
|
||||
<ModelAttrView label="Description" value={data.description} />
|
||||
</Box>
|
||||
</Flex>
|
||||
|
||||
<Tabs mt="4" h="100%">
|
||||
<TabList>
|
||||
<Tab>{t('modelManager.settings')}</Tab>
|
||||
<Tab>{t('modelManager.metadata')}</Tab>
|
||||
</TabList>
|
||||
|
||||
<TabPanels h="100%">
|
||||
<TabPanel>{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />}</TabPanel>
|
||||
<TabPanel h="full">
|
||||
<ModelMetadata />
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
@ -1,12 +1,11 @@
|
||||
import { Box, Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
|
||||
import { Box, Button, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoPencil } from 'react-icons/io5';
|
||||
import { useGetModelConfigQuery, useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
import type {
|
||||
CheckpointModelConfig,
|
||||
ControlNetModelConfig,
|
||||
@ -18,6 +17,7 @@ import type {
|
||||
VAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { DefaultSettings } from './DefaultSettings';
|
||||
import { ModelAttrView } from './ModelAttrView';
|
||||
import { ModelConvert } from './ModelConvert';
|
||||
|
||||
@ -26,7 +26,6 @@ export const ModelView = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const modelData = useMemo(() => {
|
||||
if (!data) {
|
||||
@ -73,85 +72,56 @@ export const ModelView = () => {
|
||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||
}
|
||||
return (
|
||||
<Flex flexDir="column" h="full">
|
||||
<Flex w="full" justifyContent="space-between">
|
||||
<Flex flexDir="column" gap={1} p={2}>
|
||||
<Heading as="h2" fontSize="lg">
|
||||
{modelData.name}
|
||||
</Heading>
|
||||
|
||||
{modelData.source && (
|
||||
<Text variant="subtext">
|
||||
{t('modelManager.source')}: {modelData.source}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<Flex flexDir="column" h="full" gap="2">
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<Flex gap="2" justifyContent="flex-end" w="full">
|
||||
<Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel}>
|
||||
{t('modelManager.edit')}
|
||||
</Button>
|
||||
|
||||
{modelData.type === 'main' && modelData.format === 'checkpoint' && <ModelConvert model={modelData} />}
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" p={2} gap={3}>
|
||||
<Flex>
|
||||
<ModelAttrView label="Description" value={modelData.description} />
|
||||
</Flex>
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.modelSettings')}
|
||||
</Heading>
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<Flex flexDir="column" gap={3}>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('common.format')} value={modelData.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
|
||||
</Flex>
|
||||
{modelData.type === 'main' && (
|
||||
<>
|
||||
<Flex gap={2}>
|
||||
{modelData.format === 'diffusers' && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
|
||||
)}
|
||||
{modelData.format === 'checkpoint' && (
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
|
||||
)}
|
||||
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
|
||||
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{modelData.type === 'ip_adapter' && (
|
||||
<Flex flexDir="column" gap={3}>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('common.format')} value={modelData.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
|
||||
</Flex>
|
||||
{modelData.type === 'main' && (
|
||||
<>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</Box>
|
||||
</Flex>
|
||||
{modelData.format === 'diffusers' && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
|
||||
)}
|
||||
{modelData.format === 'checkpoint' && (
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
|
||||
)}
|
||||
|
||||
{metadata && (
|
||||
<>
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.modelMetadata')}
|
||||
</Heading>
|
||||
<Flex h="full" w="full" p={2}>
|
||||
<DataViewer label="metadata" data={metadata} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
|
||||
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{modelData.type === 'ip_adapter' && (
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</Box>
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<DefaultSettings />
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { setNegativePrompt } from 'features/parameters/store/generationSlice';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@ -19,19 +19,14 @@ export const ParamNegativePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
|
||||
prompt,
|
||||
textareaRef,
|
||||
onChange: _onChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="negativePrompt"
|
||||
@ -45,10 +40,10 @@ export const ParamNegativePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/components/ShowDynamicPromptsPreviewButton';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { setPositivePrompt } from 'features/parameters/store/generationSlice';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { SDXLConcatButton } from 'features/sdxl/components/SDXLPrompts/SDXLConcatButton';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import type { HotkeyCallback } from 'react-hotkeys-hook';
|
||||
@ -25,7 +25,7 @@ export const ParamPositivePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown, onFocus } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
|
||||
prompt,
|
||||
textareaRef: textareaRef,
|
||||
onChange: handleChange,
|
||||
@ -42,12 +42,7 @@ export const ParamPositivePrompt = memo(() => {
|
||||
useHotkeys('alt+a', focus, []);
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
@ -61,12 +56,12 @@ export const ParamPositivePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
{baseModel === 'sdxl' && <SDXLConcatButton />}
|
||||
<ShowDynamicPromptsPreviewButton />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -0,0 +1,28 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { RiSparklingFill } from 'react-icons/ri';
|
||||
|
||||
export const UseDefaultSettingsButton = () => {
|
||||
const model = useAppSelector((s) => s.generation.model);
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleClickDefaultSettings = useCallback(() => {
|
||||
dispatch(setDefaultSettings());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
icon={<RiSparklingFill />}
|
||||
tooltip={t('modelManager.useDefaultSettings')}
|
||||
aria-label={t('modelManager.useDefaultSettings')}
|
||||
isDisabled={!model}
|
||||
onClick={handleClickDefaultSettings}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
/>
|
||||
);
|
||||
};
|
@ -5,3 +5,5 @@ import type { ImageDTO } from 'services/api/types';
|
||||
export const initialImageSelected = createAction<ImageDTO | undefined>('generation/initialImageSelected');
|
||||
|
||||
export const modelSelected = createAction<ParameterModel>('generation/modelSelected');
|
||||
|
||||
export const setDefaultSettings = createAction('generation/setDefaultSettings');
|
@ -230,6 +230,12 @@ export const generationSlice = createSlice({
|
||||
state.height = optimalDimension;
|
||||
}
|
||||
}
|
||||
if (action.payload.sd?.scheduler) {
|
||||
state.scheduler = action.payload.sd.scheduler;
|
||||
}
|
||||
if (action.payload.sd?.vaePrecision) {
|
||||
state.vaePrecision = action.payload.sd.vaePrecision;
|
||||
}
|
||||
});
|
||||
|
||||
// TODO: This is a temp fix to reduce issues with T2I adapter having a different downscaling
|
||||
|
@ -8,15 +8,15 @@ type Props = {
|
||||
onOpen: () => void;
|
||||
};
|
||||
|
||||
export const AddEmbeddingButton = memo((props: Props) => {
|
||||
export const AddPromptTriggerButton = memo((props: Props) => {
|
||||
const { onOpen, isOpen } = props;
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Tooltip label={t('embedding.addEmbedding')}>
|
||||
<Tooltip label={t('prompt.addPromptTrigger')}>
|
||||
<IconButton
|
||||
variant="promptOverlay"
|
||||
isDisabled={isOpen}
|
||||
aria-label={t('embedding.addEmbedding')}
|
||||
aria-label={t('prompt.addPromptTrigger')}
|
||||
icon={<PiCodeBold />}
|
||||
onClick={onOpen}
|
||||
/>
|
||||
@ -24,4 +24,4 @@ export const AddEmbeddingButton = memo((props: Props) => {
|
||||
);
|
||||
});
|
||||
|
||||
AddEmbeddingButton.displayName = 'AddEmbeddingButton';
|
||||
AddPromptTriggerButton.displayName = 'AddPromptTriggerButton';
|
@ -1,9 +1,9 @@
|
||||
import { Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
|
||||
import { EmbeddingSelect } from 'features/embedding/EmbeddingSelect';
|
||||
import type { EmbeddingPopoverProps } from 'features/embedding/types';
|
||||
import { PromptTriggerSelect } from 'features/prompt/PromptTriggerSelect';
|
||||
import type { PromptPopoverProps } from 'features/prompt/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
|
||||
export const PromptPopover = memo((props: PromptPopoverProps) => {
|
||||
const { onSelect, isOpen, onClose, width, children } = props;
|
||||
|
||||
return (
|
||||
@ -14,7 +14,7 @@ export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
|
||||
openDelay={0}
|
||||
closeDelay={0}
|
||||
closeOnBlur={true}
|
||||
returnFocusOnClose={true}
|
||||
returnFocusOnClose={false}
|
||||
isLazy
|
||||
>
|
||||
<PopoverAnchor>{children}</PopoverAnchor>
|
||||
@ -27,11 +27,11 @@ export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
|
||||
borderStyle="solid"
|
||||
>
|
||||
<PopoverBody p={0} width={`calc(${width}px - 0.25rem)`}>
|
||||
<EmbeddingSelect onClose={onClose} onSelect={onSelect} />
|
||||
<PromptTriggerSelect onClose={onClose} onSelect={onSelect} />
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
});
|
||||
|
||||
EmbeddingPopover.displayName = 'EmbeddingPopover';
|
||||
PromptPopover.displayName = 'PromptPopover';
|
@ -0,0 +1,21 @@
|
||||
import type { Meta, StoryObj } from '@storybook/react';
|
||||
|
||||
import { PromptTriggerSelect } from './PromptTriggerSelect';
|
||||
import type { PromptTriggerSelectProps } from './types';
|
||||
|
||||
const meta: Meta<typeof PromptTriggerSelect> = {
|
||||
title: 'Feature/Prompt/PromptTriggerSelect',
|
||||
tags: ['autodocs'],
|
||||
component: PromptTriggerSelect,
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof PromptTriggerSelect>;
|
||||
|
||||
const Component = (props: PromptTriggerSelectProps) => {
|
||||
return <PromptTriggerSelect {...props}>Invoke</PromptTriggerSelect>;
|
||||
};
|
||||
|
||||
export const Default: Story = {
|
||||
render: Component,
|
||||
};
|
@ -0,0 +1,86 @@
|
||||
import type { ChakraProps, ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { PromptTriggerSelectProps } from 'features/prompt/types';
|
||||
import { t } from 'i18next';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelMetadataQuery, useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
||||
|
||||
export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
const currentModelKey = useAppSelector((s) => s.generation.model?.key);
|
||||
|
||||
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
||||
const { data: metadata } = useGetModelMetadataQuery(currentModelKey ?? skipToken);
|
||||
|
||||
const _onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!v) {
|
||||
onSelect('');
|
||||
return;
|
||||
}
|
||||
|
||||
onSelect(v.value);
|
||||
},
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
const embeddingOptions = useMemo(() => {
|
||||
if (!data) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const compatibleEmbeddingsArray = map(data.entities).filter((model) => model.base === currentBaseModel);
|
||||
|
||||
return [
|
||||
{
|
||||
label: t('prompt.compatibleEmbeddings'),
|
||||
options: compatibleEmbeddingsArray.map((model) => ({ label: model.name, value: `<${model.name}>` })),
|
||||
},
|
||||
];
|
||||
}, [data, currentBaseModel, t]);
|
||||
|
||||
const options = useMemo(() => {
|
||||
if (!metadata || !metadata.trigger_phrases) {
|
||||
return [...embeddingOptions];
|
||||
}
|
||||
|
||||
const metadataOptions = [
|
||||
{
|
||||
label: t('modelManager.triggerPhrases'),
|
||||
options: metadata.trigger_phrases.map((phrase) => ({ label: phrase, value: phrase })),
|
||||
},
|
||||
];
|
||||
return [...metadataOptions, ...embeddingOptions];
|
||||
}, [embeddingOptions, metadata, t]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<Combobox
|
||||
placeholder={isLoading ? t('common.loading') : t('prompt.addPromptTrigger')}
|
||||
defaultMenuIsOpen
|
||||
autoFocus
|
||||
value={null}
|
||||
options={options}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
onChange={_onChange}
|
||||
onMenuClose={onClose}
|
||||
data-testid="add-prompt-trigger"
|
||||
sx={selectStyles}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
PromptTriggerSelect.displayName = 'PromptTriggerSelect';
|
||||
|
||||
const selectStyles: ChakraProps['sx'] = {
|
||||
w: 'full',
|
||||
};
|
@ -1,12 +1,12 @@
|
||||
import type { PropsWithChildren } from 'react';
|
||||
|
||||
export type EmbeddingSelectProps = {
|
||||
export type PromptTriggerSelectProps = {
|
||||
onSelect: (v: string) => void;
|
||||
onClose: () => void;
|
||||
};
|
||||
|
||||
export type EmbeddingPopoverProps = PropsWithChildren &
|
||||
EmbeddingSelectProps & {
|
||||
export type PromptPopoverProps = PropsWithChildren &
|
||||
PromptTriggerSelectProps & {
|
||||
isOpen: boolean;
|
||||
width?: number | string;
|
||||
};
|
@ -4,13 +4,13 @@ import type { ChangeEventHandler, KeyboardEventHandler, RefObject } from 'react'
|
||||
import { useCallback } from 'react';
|
||||
import { flushSync } from 'react-dom';
|
||||
|
||||
type UseInsertEmbeddingArg = {
|
||||
type UseInsertTriggerArg = {
|
||||
prompt: string;
|
||||
textareaRef: RefObject<HTMLTextAreaElement>;
|
||||
onChange: (v: string) => void;
|
||||
};
|
||||
|
||||
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertEmbeddingArg) => {
|
||||
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertTriggerArg) => {
|
||||
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||
|
||||
const onChange: ChangeEventHandler<HTMLTextAreaElement> = useCallback(
|
||||
@ -20,13 +20,13 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
[_onChange]
|
||||
);
|
||||
|
||||
const insertEmbedding = useCallback(
|
||||
const insertTrigger = useCallback(
|
||||
(v: string) => {
|
||||
if (!textareaRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
// this is where we insert the TI trigger
|
||||
// this is where we insert the trigger
|
||||
const caret = textareaRef.current.selectionStart;
|
||||
|
||||
if (isNil(caret)) {
|
||||
@ -35,13 +35,9 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
|
||||
let newPrompt = prompt.slice(0, caret);
|
||||
|
||||
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||
newPrompt += '<';
|
||||
}
|
||||
newPrompt += `${v}`;
|
||||
|
||||
newPrompt += `${v}>`;
|
||||
|
||||
// we insert the cursor after the `>`
|
||||
// we insert the cursor after the end of trigger
|
||||
const finalCaretPos = newPrompt.length;
|
||||
|
||||
newPrompt += prompt.slice(caret);
|
||||
@ -51,7 +47,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
_onChange(newPrompt);
|
||||
});
|
||||
|
||||
// set the caret position to just after the TI trigger
|
||||
// set the cursor position to just after the trigger
|
||||
textareaRef.current.selectionStart = finalCaretPos;
|
||||
textareaRef.current.selectionEnd = finalCaretPos;
|
||||
},
|
||||
@ -62,17 +58,17 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
textareaRef.current?.focus();
|
||||
}, [textareaRef]);
|
||||
|
||||
const handleClose = useCallback(() => {
|
||||
const handleClosePopover = useCallback(() => {
|
||||
onClose();
|
||||
onFocus();
|
||||
}, [onFocus, onClose]);
|
||||
|
||||
const onSelectEmbedding = useCallback(
|
||||
const onSelect = useCallback(
|
||||
(v: string) => {
|
||||
insertEmbedding(v);
|
||||
handleClose();
|
||||
insertTrigger(v);
|
||||
handleClosePopover();
|
||||
},
|
||||
[handleClose, insertEmbedding]
|
||||
[handleClosePopover, insertTrigger]
|
||||
);
|
||||
|
||||
const onKeyDown: KeyboardEventHandler<HTMLTextAreaElement> = useCallback(
|
||||
@ -90,7 +86,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
isOpen,
|
||||
onClose,
|
||||
onOpen,
|
||||
onSelectEmbedding,
|
||||
onSelect,
|
||||
onKeyDown,
|
||||
onFocus,
|
||||
};
|
@ -1,9 +1,9 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { setNegativeStylePromptSDXL } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
@ -20,7 +20,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown, onFocus } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
|
||||
prompt,
|
||||
textareaRef: textareaRef,
|
||||
onChange: handleChange,
|
||||
@ -29,12 +29,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
|
||||
useHotkeys('alt+a', onFocus, []);
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
@ -48,10 +43,10 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -1,9 +1,9 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { setPositiveStylePromptSDXL } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -19,19 +19,14 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
|
||||
prompt,
|
||||
textareaRef: textareaRef,
|
||||
onChange: handleChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
@ -45,10 +40,10 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -21,6 +21,7 @@ import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
||||
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
|
||||
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
|
||||
import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect';
|
||||
import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton';
|
||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
import { filter } from 'lodash-es';
|
||||
@ -71,7 +72,10 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
<TabPanel overflow="visible" px={4} pt={4}>
|
||||
<Flex gap={4} alignItems="center">
|
||||
<ParamMainModelSelect />
|
||||
<SyncModelsIconButton />
|
||||
<Flex>
|
||||
<UseDefaultSettingsButton />
|
||||
<SyncModelsIconButton />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Expander isOpen={isOpenExpander} onToggle={onToggleExpander}>
|
||||
<Flex gap={4} flexDir="column" pb={4}>
|
||||
|
@ -41,6 +41,8 @@ const initialConfigState: AppConfig = {
|
||||
boundingBoxHeight: { ...baseDimensionConfig },
|
||||
scaledBoundingBoxWidth: { ...baseDimensionConfig },
|
||||
scaledBoundingBoxHeight: { ...baseDimensionConfig },
|
||||
scheduler: "euler",
|
||||
vaePrecision: "fp32",
|
||||
steps: {
|
||||
initial: 30,
|
||||
sliderMin: 1,
|
||||
|
@ -24,12 +24,21 @@ export type UpdateModelArg = {
|
||||
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
||||
};
|
||||
|
||||
type UpdateModelMetadataArg = {
|
||||
key: paths['/api/v2/models/i/{key}/metadata']['patch']['parameters']['path']['key'];
|
||||
body: paths['/api/v2/models/i/{key}/metadata']['patch']['requestBody']['content']['application/json'];
|
||||
};
|
||||
|
||||
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||
type UpdateModelMetadataResponse =
|
||||
paths['/api/v2/models/i/{key}/metadata']['patch']['responses']['200']['content']['application/json'];
|
||||
|
||||
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type GetModelMetadataResponse =
|
||||
paths['/api/v2/models/i/{key}/metadata']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
|
||||
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
|
||||
|
||||
type DeleteMainModelArg = {
|
||||
@ -108,25 +117,25 @@ const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefi
|
||||
|
||||
const buildProvidesTags =
|
||||
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
|
||||
(result: EntityState<TEntity, string> | undefined) => {
|
||||
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: tagType,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
(result: EntityState<TEntity, string> | undefined) => {
|
||||
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: tagType,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
};
|
||||
return tags;
|
||||
};
|
||||
|
||||
const buildTransformResponse =
|
||||
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
|
||||
(response: { models: T[] }) => {
|
||||
return adapter.setAll(adapter.getInitialState(), response.models);
|
||||
};
|
||||
(response: { models: T[] }) => {
|
||||
return adapter.setAll(adapter.getInitialState(), response.models);
|
||||
};
|
||||
|
||||
/**
|
||||
* Builds an endpoint URL for the models router
|
||||
@ -172,6 +181,16 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
updateModelMetadata: build.mutation<UpdateModelMetadataResponse, UpdateModelMetadataArg>({
|
||||
query: ({ key, body }) => {
|
||||
return {
|
||||
url: buildModelsUrl(`i/${key}/metadata`),
|
||||
method: 'PATCH',
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
||||
query: ({ source, config, access_token }) => {
|
||||
return {
|
||||
@ -351,6 +370,7 @@ export const {
|
||||
useGetModelMetadataQuery,
|
||||
useDeleteModelImportMutation,
|
||||
usePruneModelImportsMutation,
|
||||
useUpdateModelMetadataMutation,
|
||||
} = modelsApi;
|
||||
|
||||
const upsertModelConfigs = (
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -133,7 +133,7 @@ def test_metadata_civitai_fetch(mm2_session: Session) -> None:
|
||||
assert metadata.id == 215485
|
||||
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
|
||||
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
|
||||
assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse
|
||||
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
|
||||
assert metadata.version_id == 242807
|
||||
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
|
||||
|
||||
|
Reference in New Issue
Block a user