mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/remove-hardcoded-cuda-device
This commit is contained in:
commit
9f9ce08e44
2
.gitignore
vendored
2
.gitignore
vendored
@ -201,8 +201,6 @@ checkpoints
|
|||||||
# If it's a Mac
|
# If it's a Mac
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
invokeai/frontend/web/dist/*
|
|
||||||
|
|
||||||
# Let the frontend manage its own gitignore
|
# Let the frontend manage its own gitignore
|
||||||
!invokeai/frontend/web/*
|
!invokeai/frontend/web/*
|
||||||
|
|
||||||
|
@ -2,17 +2,17 @@
|
|||||||
|
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
from fastapi import Query
|
from fastapi import Query, Body
|
||||||
from fastapi.routing import APIRouter, HTTPException
|
from fastapi.routing import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field, parse_obj_as
|
from pydantic import BaseModel, Field, parse_obj_as
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
|
from invokeai.backend.model_management import AddModelResult
|
||||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
||||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
|
|
||||||
class VaeRepo(BaseModel):
|
class VaeRepo(BaseModel):
|
||||||
repo_id: str = Field(description="The repo ID to use for this VAE")
|
repo_id: str = Field(description="The repo ID to use for this VAE")
|
||||||
path: Optional[str] = Field(description="The path to the VAE")
|
path: Optional[str] = Field(description="The path to the VAE")
|
||||||
@ -51,9 +51,12 @@ class CreateModelResponse(BaseModel):
|
|||||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||||
status: str = Field(description="The status of the API response")
|
status: str = Field(description="The status of the API response")
|
||||||
|
|
||||||
class ImportModelRequest(BaseModel):
|
class ImportModelResponse(BaseModel):
|
||||||
name: str = Field(description="A model path, repo_id or URL to import")
|
name: str = Field(description="The name of the imported model")
|
||||||
prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files')
|
# base_model: str = Field(description="The base model")
|
||||||
|
# model_type: str = Field(description="The model type")
|
||||||
|
info: AddModelResult = Field(description="The model info")
|
||||||
|
status: str = Field(description="The status of the API response")
|
||||||
|
|
||||||
class ConversionRequest(BaseModel):
|
class ConversionRequest(BaseModel):
|
||||||
name: str = Field(description="The name of the new model")
|
name: str = Field(description="The name of the new model")
|
||||||
@ -86,7 +89,6 @@ async def list_models(
|
|||||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.post(
|
||||||
"/",
|
"/",
|
||||||
operation_id="update_model",
|
operation_id="update_model",
|
||||||
@ -109,27 +111,38 @@ async def update_model(
|
|||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.post(
|
||||||
"/",
|
"/import",
|
||||||
operation_id="import_model",
|
operation_id="import_model",
|
||||||
responses={200: {"status": "success"}},
|
responses= {
|
||||||
|
201: {"description" : "The model imported successfully"},
|
||||||
|
404: {"description" : "The model could not be found"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
response_model=ImportModelResponse
|
||||||
)
|
)
|
||||||
async def import_model(
|
async def import_model(
|
||||||
model_request: ImportModelRequest
|
name: str = Query(description="A model path, repo_id or URL to import"),
|
||||||
) -> None:
|
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||||
""" Add Model """
|
) -> ImportModelResponse:
|
||||||
items_to_import = set([model_request.name])
|
""" Add a model using its local path, repo_id, or remote URL """
|
||||||
|
items_to_import = {name}
|
||||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||||
items_to_import = items_to_import,
|
items_to_import = items_to_import,
|
||||||
prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type)
|
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||||
)
|
)
|
||||||
if len(installed_models) > 0:
|
if info := installed_models.get(name):
|
||||||
logger.info(f'Successfully imported {model_request.name}')
|
logger.info(f'Successfully imported {name}, got {info}')
|
||||||
|
return ImportModelResponse(
|
||||||
|
name = name,
|
||||||
|
info = info,
|
||||||
|
status = "success",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(f'Model {model_request.name} not imported')
|
logger.error(f'Model {name} not imported')
|
||||||
raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported')
|
raise HTTPException(status_code=404, detail=f'Model {name} not found')
|
||||||
|
|
||||||
@models_router.delete(
|
@models_router.delete(
|
||||||
"/{model_name}",
|
"/{model_name}",
|
||||||
|
@ -4,9 +4,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
|
from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args,
|
||||||
|
get_type_hints)
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseConfig, BaseModel, Field
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
@ -65,8 +66,13 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_invocations_map(cls):
|
def get_invocations_map(cls):
|
||||||
# Get the type strings out of the literals and into a dictionary
|
# Get the type strings out of the literals and into a dictionary
|
||||||
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses()))
|
return dict(
|
||||||
|
map(
|
||||||
|
lambda t: (get_args(get_type_hints(t)["type"])[0], t),
|
||||||
|
BaseInvocation.get_all_subclasses(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_output_type(cls):
|
def get_output_type(cls):
|
||||||
return signature(cls.invoke).return_annotation
|
return signature(cls.invoke).return_annotation
|
||||||
@ -75,11 +81,11 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||||
"""Invoke with provided context and return outputs."""
|
"""Invoke with provided context and return outputs."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
#fmt: off
|
# fmt: off
|
||||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
||||||
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
|
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
# TODO: figure out a better way to provide these hints
|
# TODO: figure out a better way to provide these hints
|
||||||
@ -98,16 +104,19 @@ class UIConfig(TypedDict, total=False):
|
|||||||
"model",
|
"model",
|
||||||
"control",
|
"control",
|
||||||
"image_collection",
|
"image_collection",
|
||||||
|
"vae_model",
|
||||||
|
"lora_model",
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
tags: List[str]
|
tags: List[str]
|
||||||
title: str
|
title: str
|
||||||
|
|
||||||
|
|
||||||
class CustomisedSchemaExtra(TypedDict):
|
class CustomisedSchemaExtra(TypedDict):
|
||||||
ui: UIConfig
|
ui: UIConfig
|
||||||
|
|
||||||
|
|
||||||
class InvocationConfig(BaseModel.Config):
|
class InvocationConfig(BaseConfig):
|
||||||
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
|
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
|
||||||
|
|
||||||
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
|
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
|
||||||
|
@ -1,27 +1,28 @@
|
|||||||
from typing import Literal, Optional, Union
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from contextlib import ExitStack
|
|
||||||
import re
|
import re
|
||||||
|
from contextlib import ExitStack
|
||||||
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
import torch
|
||||||
from .model import ClipField
|
from compel import Compel
|
||||||
|
from compel.prompt_parser import (Blend, Conjunction,
|
||||||
|
CrossAttentionControlSubstitute,
|
||||||
|
FlattenedPrompt, Fragment)
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from ...backend.util.devices import torch_dtype
|
from ...backend.model_management.models import ModelNotFoundException
|
||||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
|
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||||
from compel import Compel
|
from ...backend.util.devices import torch_dtype
|
||||||
from compel.prompt_parser import (
|
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||||
Blend,
|
InvocationConfig, InvocationContext)
|
||||||
CrossAttentionControlSubstitute,
|
from .model import ClipField
|
||||||
FlattenedPrompt,
|
|
||||||
Fragment, Conjunction,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ConditioningField(BaseModel):
|
class ConditioningField(BaseModel):
|
||||||
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
|
conditioning_name: Optional[str] = Field(
|
||||||
|
default=None, description="The name of conditioning data")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["conditioning_name"]}
|
schema_extra = {"required": ["conditioning_name"]}
|
||||||
|
|
||||||
@ -51,83 +52,92 @@ class CompelInvocation(BaseInvocation):
|
|||||||
"title": "Prompt (Compel)",
|
"title": "Prompt (Compel)",
|
||||||
"tags": ["prompt", "compel"],
|
"tags": ["prompt", "compel"],
|
||||||
"type_hints": {
|
"type_hints": {
|
||||||
"model": "model"
|
"model": "model"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
|
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**self.clip.tokenizer.dict(),
|
**self.clip.tokenizer.dict(),
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**self.clip.text_encoder.dict(),
|
**self.clip.text_encoder.dict(),
|
||||||
)
|
)
|
||||||
with tokenizer_info as orig_tokenizer,\
|
|
||||||
text_encoder_info as text_encoder:
|
|
||||||
|
|
||||||
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
def _lora_loader():
|
||||||
|
for lora in self.clip.loras:
|
||||||
|
lora_info = context.services.model_manager.get_model(
|
||||||
|
**lora.dict(exclude={"weight"}))
|
||||||
|
yield (lora_info.context.model, lora.weight)
|
||||||
|
del lora_info
|
||||||
|
return
|
||||||
|
|
||||||
ti_list = []
|
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
|
||||||
name = trigger[1:-1]
|
|
||||||
try:
|
|
||||||
ti_list.append(
|
|
||||||
context.services.model_manager.get_model(
|
|
||||||
model_name=name,
|
|
||||||
base_model=self.clip.text_encoder.base_model,
|
|
||||||
model_type=ModelType.TextualInversion,
|
|
||||||
).context.model
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
#print(e)
|
|
||||||
#import traceback
|
|
||||||
#print(traceback.format_exc())
|
|
||||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
|
ti_list = []
|
||||||
ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||||
|
name = trigger[1:-1]
|
||||||
compel = Compel(
|
try:
|
||||||
tokenizer=tokenizer,
|
ti_list.append(
|
||||||
text_encoder=text_encoder,
|
context.services.model_manager.get_model(
|
||||||
textual_inversion_manager=ti_manager,
|
model_name=name,
|
||||||
dtype_for_device_getter=torch_dtype,
|
base_model=self.clip.text_encoder.base_model,
|
||||||
truncate_long_prompts=True, # TODO:
|
model_type=ModelType.TextualInversion,
|
||||||
|
).context.model
|
||||||
)
|
)
|
||||||
|
except ModelNotFoundException:
|
||||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
# print(e)
|
||||||
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
|
#import traceback
|
||||||
|
#print(traceback.format_exc())
|
||||||
|
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||||
|
|
||||||
if context.services.configuration.log_tokenization:
|
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||||
|
text_encoder_info as text_encoder:
|
||||||
|
|
||||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
compel = Compel(
|
||||||
|
tokenizer=tokenizer,
|
||||||
# TODO: long prompt support
|
text_encoder=text_encoder,
|
||||||
#if not self.truncate_long_prompts:
|
textual_inversion_manager=ti_manager,
|
||||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
dtype_for_device_getter=torch_dtype,
|
||||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
truncate_long_prompts=True, # TODO:
|
||||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
|
||||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
|
||||||
|
|
||||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
|
||||||
context.services.latents.save(conditioning_name, (c, ec))
|
|
||||||
|
|
||||||
return CompelOutput(
|
|
||||||
conditioning=ConditioningField(
|
|
||||||
conditioning_name=conditioning_name,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||||
|
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
|
||||||
|
|
||||||
|
if context.services.configuration.log_tokenization:
|
||||||
|
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||||
|
|
||||||
|
c, options = compel.build_conditioning_tensor_for_prompt_object(
|
||||||
|
prompt)
|
||||||
|
|
||||||
|
# TODO: long prompt support
|
||||||
|
# if not self.truncate_long_prompts:
|
||||||
|
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||||
|
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||||
|
tokens_count_including_eos_bos=get_max_token_count(
|
||||||
|
tokenizer, conjunction),
|
||||||
|
cross_attention_control_args=options.get(
|
||||||
|
"cross_attention_control", None),)
|
||||||
|
|
||||||
|
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||||
|
|
||||||
|
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||||
|
context.services.latents.save(conditioning_name, (c, ec))
|
||||||
|
|
||||||
|
return CompelOutput(
|
||||||
|
conditioning=ConditioningField(
|
||||||
|
conditioning_name=conditioning_name,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_max_token_count(
|
def get_max_token_count(
|
||||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
|
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],
|
||||||
) -> int:
|
truncate_if_too_long=False) -> int:
|
||||||
if type(prompt) is Blend:
|
if type(prompt) is Blend:
|
||||||
blend: Blend = prompt
|
blend: Blend = prompt
|
||||||
return max(
|
return max(
|
||||||
@ -146,13 +156,13 @@ def get_max_token_count(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return len(
|
return len(
|
||||||
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
|
get_tokens_for_prompt_object(
|
||||||
)
|
tokenizer, prompt, truncate_if_too_long))
|
||||||
|
|
||||||
|
|
||||||
def get_tokens_for_prompt_object(
|
def get_tokens_for_prompt_object(
|
||||||
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
|
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
|
||||||
) -> [str]:
|
) -> List[str]:
|
||||||
if type(parsed_prompt) is Blend:
|
if type(parsed_prompt) is Blend:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Blend is not supported here - you need to get tokens for each of its .children"
|
"Blend is not supported here - you need to get tokens for each of its .children"
|
||||||
@ -181,7 +191,7 @@ def log_tokenization_for_conjunction(
|
|||||||
):
|
):
|
||||||
display_label_prefix = display_label_prefix or ""
|
display_label_prefix = display_label_prefix or ""
|
||||||
for i, p in enumerate(c.prompts):
|
for i, p in enumerate(c.prompts):
|
||||||
if len(c.prompts)>1:
|
if len(c.prompts) > 1:
|
||||||
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
||||||
else:
|
else:
|
||||||
this_display_label_prefix = display_label_prefix
|
this_display_label_prefix = display_label_prefix
|
||||||
@ -236,7 +246,8 @@ def log_tokenization_for_prompt_object(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
def log_tokenization_for_text(
|
||||||
|
text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||||
"""shows how the prompt is tokenized
|
"""shows how the prompt is tokenized
|
||||||
# usually tokens have '</w>' to indicate end-of-word,
|
# usually tokens have '</w>' to indicate end-of-word,
|
||||||
# but for readability it has been replaced with ' '
|
# but for readability it has been replaced with ' '
|
||||||
|
@ -4,18 +4,17 @@ from contextlib import ExitStack
|
|||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, validator
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
|
|
||||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
|
||||||
from ...backend.image_util.seamless import configure_model_padding
|
from ...backend.image_util.seamless import configure_model_padding
|
||||||
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||||
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
|
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
|
||||||
@ -24,7 +23,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
|||||||
PostprocessingSettings
|
PostprocessingSettings
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from ...backend.util.devices import torch_dtype
|
from ...backend.util.devices import torch_dtype
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||||
InvocationConfig, InvocationContext)
|
InvocationConfig, InvocationContext)
|
||||||
from .compel import ConditioningField
|
from .compel import ConditioningField
|
||||||
@ -32,14 +31,17 @@ from .controlnet_image_processors import ControlField
|
|||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
from .model import ModelInfo, UNetField, VaeField
|
from .model import ModelInfo, UNetField, VaeField
|
||||||
|
|
||||||
|
|
||||||
class LatentsField(BaseModel):
|
class LatentsField(BaseModel):
|
||||||
"""A latents field used for passing latents between invocations"""
|
"""A latents field used for passing latents between invocations"""
|
||||||
|
|
||||||
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
latents_name: Optional[str] = Field(
|
||||||
|
default=None, description="The name of the latents")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["latents_name"]}
|
schema_extra = {"required": ["latents_name"]}
|
||||||
|
|
||||||
|
|
||||||
class LatentsOutput(BaseInvocationOutput):
|
class LatentsOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output latents"""
|
"""Base class for invocations that output latents"""
|
||||||
#fmt: off
|
#fmt: off
|
||||||
@ -53,11 +55,11 @@ class LatentsOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
|
|
||||||
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
||||||
return LatentsOutput(
|
return LatentsOutput(
|
||||||
latents=LatentsField(latents_name=latents_name),
|
latents=LatentsField(latents_name=latents_name),
|
||||||
width=latents.size()[3] * 8,
|
width=latents.size()[3] * 8,
|
||||||
height=latents.size()[2] * 8,
|
height=latents.size()[2] * 8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[
|
SAMPLER_NAME_VALUES = Literal[
|
||||||
@ -70,16 +72,19 @@ def get_scheduler(
|
|||||||
scheduler_info: ModelInfo,
|
scheduler_info: ModelInfo,
|
||||||
scheduler_name: str,
|
scheduler_name: str,
|
||||||
) -> Scheduler:
|
) -> Scheduler:
|
||||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
|
||||||
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
|
scheduler_name, SCHEDULER_MAP['ddim'])
|
||||||
|
orig_scheduler_info = context.services.model_manager.get_model(
|
||||||
|
**scheduler_info.dict())
|
||||||
with orig_scheduler_info as orig_scheduler:
|
with orig_scheduler_info as orig_scheduler:
|
||||||
scheduler_config = orig_scheduler.config
|
scheduler_config = orig_scheduler.config
|
||||||
|
|
||||||
if "_backup" in scheduler_config:
|
if "_backup" in scheduler_config:
|
||||||
scheduler_config = scheduler_config["_backup"]
|
scheduler_config = scheduler_config["_backup"]
|
||||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
scheduler_config = {**scheduler_config, **
|
||||||
|
scheduler_extra_config, "_backup": scheduler_config}
|
||||||
scheduler = scheduler_class.from_config(scheduler_config)
|
scheduler = scheduler_class.from_config(scheduler_config)
|
||||||
|
|
||||||
# hack copied over from generate.py
|
# hack copied over from generate.py
|
||||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||||
scheduler.uses_inpainting_model = lambda: False
|
scheduler.uses_inpainting_model = lambda: False
|
||||||
@ -124,18 +129,18 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
"ui": {
|
"ui": {
|
||||||
"tags": ["latents"],
|
"tags": ["latents"],
|
||||||
"type_hints": {
|
"type_hints": {
|
||||||
"model": "model",
|
"model": "model",
|
||||||
"control": "control",
|
"control": "control",
|
||||||
# "cfg_scale": "float",
|
# "cfg_scale": "float",
|
||||||
"cfg_scale": "number"
|
"cfg_scale": "number"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
|
self, context: InvocationContext, source_node_id: str,
|
||||||
) -> None:
|
intermediate_state: PipelineIntermediateState) -> None:
|
||||||
stable_diffusion_step_callback(
|
stable_diffusion_step_callback(
|
||||||
context=context,
|
context=context,
|
||||||
intermediate_state=intermediate_state,
|
intermediate_state=intermediate_state,
|
||||||
@ -143,9 +148,12 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData:
|
def get_conditioning_data(
|
||||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
self, context: InvocationContext, scheduler) -> ConditioningData:
|
||||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
c, extra_conditioning_info = context.services.latents.get(
|
||||||
|
self.positive_conditioning.conditioning_name)
|
||||||
|
uc, _ = context.services.latents.get(
|
||||||
|
self.negative_conditioning.conditioning_name)
|
||||||
|
|
||||||
conditioning_data = ConditioningData(
|
conditioning_data = ConditioningData(
|
||||||
unconditioned_embeddings=uc,
|
unconditioned_embeddings=uc,
|
||||||
@ -153,10 +161,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
guidance_scale=self.cfg_scale,
|
guidance_scale=self.cfg_scale,
|
||||||
extra=extra_conditioning_info,
|
extra=extra_conditioning_info,
|
||||||
postprocessing_settings=PostprocessingSettings(
|
postprocessing_settings=PostprocessingSettings(
|
||||||
threshold=0.0,#threshold,
|
threshold=0.0, # threshold,
|
||||||
warmup=0.2,#warmup,
|
warmup=0.2, # warmup,
|
||||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
h_symmetry_time_pct=None, # h_symmetry_time_pct,
|
||||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
v_symmetry_time_pct=None # v_symmetry_time_pct,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -164,31 +172,32 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
scheduler,
|
scheduler,
|
||||||
|
|
||||||
# for ddim scheduler
|
# for ddim scheduler
|
||||||
eta=0.0, #ddim_eta
|
eta=0.0, # ddim_eta
|
||||||
|
|
||||||
# for ancestral and sde schedulers
|
# for ancestral and sde schedulers
|
||||||
generator=torch.Generator(device=uc.device).manual_seed(0),
|
generator=torch.Generator(device=uc.device).manual_seed(0),
|
||||||
)
|
)
|
||||||
return conditioning_data
|
return conditioning_data
|
||||||
|
|
||||||
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
def create_pipeline(
|
||||||
|
self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
||||||
# TODO:
|
# TODO:
|
||||||
#configure_model_padding(
|
# configure_model_padding(
|
||||||
# unet,
|
# unet,
|
||||||
# self.seamless,
|
# self.seamless,
|
||||||
# self.seamless_axes,
|
# self.seamless_axes,
|
||||||
#)
|
# )
|
||||||
|
|
||||||
class FakeVae:
|
class FakeVae:
|
||||||
class FakeVaeConfig:
|
class FakeVaeConfig:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.block_out_channels = [0]
|
self.block_out_channels = [0]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.config = FakeVae.FakeVaeConfig()
|
self.config = FakeVae.FakeVaeConfig()
|
||||||
|
|
||||||
return StableDiffusionGeneratorPipeline(
|
return StableDiffusionGeneratorPipeline(
|
||||||
vae=FakeVae(), # TODO: oh...
|
vae=FakeVae(), # TODO: oh...
|
||||||
text_encoder=None,
|
text_encoder=None,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
unet=unet,
|
unet=unet,
|
||||||
@ -198,11 +207,12 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
requires_safety_checker=False,
|
requires_safety_checker=False,
|
||||||
precision="float16" if unet.dtype == torch.float16 else "float32",
|
precision="float16" if unet.dtype == torch.float16 else "float32",
|
||||||
)
|
)
|
||||||
|
|
||||||
def prep_control_data(
|
def prep_control_data(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
|
# really only need model for dtype and device
|
||||||
|
model: StableDiffusionGeneratorPipeline,
|
||||||
control_input: List[ControlField],
|
control_input: List[ControlField],
|
||||||
latents_shape: List[int],
|
latents_shape: List[int],
|
||||||
do_classifier_free_guidance: bool = True,
|
do_classifier_free_guidance: bool = True,
|
||||||
@ -238,15 +248,17 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
print("Using HF model subfolders")
|
print("Using HF model subfolders")
|
||||||
print(" control_name: ", control_name)
|
print(" control_name: ", control_name)
|
||||||
print(" control_subfolder: ", control_subfolder)
|
print(" control_subfolder: ", control_subfolder)
|
||||||
control_model = ControlNetModel.from_pretrained(control_name,
|
control_model = ControlNetModel.from_pretrained(
|
||||||
subfolder=control_subfolder,
|
control_name, subfolder=control_subfolder,
|
||||||
torch_dtype=model.unet.dtype).to(model.device)
|
torch_dtype=model.unet.dtype).to(
|
||||||
|
model.device)
|
||||||
else:
|
else:
|
||||||
control_model = ControlNetModel.from_pretrained(control_info.control_model,
|
control_model = ControlNetModel.from_pretrained(
|
||||||
torch_dtype=model.unet.dtype).to(model.device)
|
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
|
||||||
control_models.append(control_model)
|
control_models.append(control_model)
|
||||||
control_image_field = control_info.image
|
control_image_field = control_info.image
|
||||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
input_image = context.services.images.get_pil_image(
|
||||||
|
control_image_field.image_name)
|
||||||
# self.image.image_type, self.image.image_name
|
# self.image.image_type, self.image.image_name
|
||||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||||
# and add in batch_size, num_images_per_prompt?
|
# and add in batch_size, num_images_per_prompt?
|
||||||
@ -263,41 +275,50 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
dtype=control_model.dtype,
|
dtype=control_model.dtype,
|
||||||
control_mode=control_info.control_mode,
|
control_mode=control_info.control_mode,
|
||||||
)
|
)
|
||||||
control_item = ControlNetData(model=control_model,
|
control_item = ControlNetData(
|
||||||
image_tensor=control_image,
|
model=control_model, image_tensor=control_image,
|
||||||
weight=control_info.control_weight,
|
weight=control_info.control_weight,
|
||||||
begin_step_percent=control_info.begin_step_percent,
|
begin_step_percent=control_info.begin_step_percent,
|
||||||
end_step_percent=control_info.end_step_percent,
|
end_step_percent=control_info.end_step_percent,
|
||||||
control_mode=control_info.control_mode,
|
control_mode=control_info.control_mode,)
|
||||||
)
|
|
||||||
control_data.append(control_item)
|
control_data.append(control_item)
|
||||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||||
return control_data
|
return control_data
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
graph_execution_state = context.services.graph_execution_manager.get(
|
||||||
|
context.graph_execution_state_id)
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
self.dispatch_progress(context, source_node_id, state)
|
self.dispatch_progress(context, source_node_id, state)
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
def _lora_loader():
|
||||||
with unet_info as unet:
|
for lora in self.unet.loras:
|
||||||
|
lora_info = context.services.model_manager.get_model(
|
||||||
|
**lora.dict(exclude={"weight"}))
|
||||||
|
yield (lora_info.context.model, lora.weight)
|
||||||
|
del lora_info
|
||||||
|
return
|
||||||
|
|
||||||
|
unet_info = context.services.model_manager.get_model(
|
||||||
|
**self.unet.unet.dict())
|
||||||
|
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||||
|
unet_info as unet:
|
||||||
|
|
||||||
scheduler = get_scheduler(
|
scheduler = get_scheduler(
|
||||||
context=context,
|
context=context,
|
||||||
scheduler_info=self.unet.scheduler,
|
scheduler_info=self.unet.scheduler,
|
||||||
scheduler_name=self.scheduler,
|
scheduler_name=self.scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline = self.create_pipeline(unet, scheduler)
|
pipeline = self.create_pipeline(unet, scheduler)
|
||||||
conditioning_data = self.get_conditioning_data(context, scheduler)
|
conditioning_data = self.get_conditioning_data(context, scheduler)
|
||||||
|
|
||||||
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
|
|
||||||
|
|
||||||
control_data = self.prep_control_data(
|
control_data = self.prep_control_data(
|
||||||
model=pipeline, context=context, control_input=self.control,
|
model=pipeline, context=context, control_input=self.control,
|
||||||
latents_shape=noise.shape,
|
latents_shape=noise.shape,
|
||||||
@ -305,16 +326,15 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
do_classifier_free_guidance=True,
|
do_classifier_free_guidance=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
|
# TODO: Verify the noise is the right size
|
||||||
# TODO: Verify the noise is the right size
|
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
||||||
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
noise=noise,
|
||||||
noise=noise,
|
num_inference_steps=self.steps,
|
||||||
num_inference_steps=self.steps,
|
conditioning_data=conditioning_data,
|
||||||
conditioning_data=conditioning_data,
|
control_data=control_data, # list[ControlNetData]
|
||||||
control_data=control_data, # list[ControlNetData]
|
callback=step_callback,
|
||||||
callback=step_callback,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -323,14 +343,18 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
context.services.latents.save(name, result_latents)
|
context.services.latents.save(name, result_latents)
|
||||||
return build_latents_output(latents_name=name, latents=result_latents)
|
return build_latents_output(latents_name=name, latents=result_latents)
|
||||||
|
|
||||||
|
|
||||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||||
"""Generates latents using latents as base image."""
|
"""Generates latents using latents as base image."""
|
||||||
|
|
||||||
type: Literal["l2l"] = "l2l"
|
type: Literal["l2l"] = "l2l"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
latents: Optional[LatentsField] = Field(
|
||||||
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
|
description="The latents to use as a base image")
|
||||||
|
strength: float = Field(
|
||||||
|
default=0.7, ge=0, le=1,
|
||||||
|
description="The strength of the latents to use")
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@ -345,22 +369,31 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
latent = context.services.latents.get(self.latents.latents_name)
|
latent = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
graph_execution_state = context.services.graph_execution_manager.get(
|
||||||
|
context.graph_execution_state_id)
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
self.dispatch_progress(context, source_node_id, state)
|
self.dispatch_progress(context, source_node_id, state)
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
def _lora_loader():
|
||||||
**self.unet.unet.dict(),
|
for lora in self.unet.loras:
|
||||||
)
|
lora_info = context.services.model_manager.get_model(
|
||||||
|
**lora.dict(exclude={"weight"}))
|
||||||
|
yield (lora_info.context.model, lora.weight)
|
||||||
|
del lora_info
|
||||||
|
return
|
||||||
|
|
||||||
with unet_info as unet:
|
unet_info = context.services.model_manager.get_model(
|
||||||
|
**self.unet.unet.dict())
|
||||||
|
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||||
|
unet_info as unet:
|
||||||
|
|
||||||
scheduler = get_scheduler(
|
scheduler = get_scheduler(
|
||||||
context=context,
|
context=context,
|
||||||
@ -370,7 +403,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
|
|
||||||
pipeline = self.create_pipeline(unet, scheduler)
|
pipeline = self.create_pipeline(unet, scheduler)
|
||||||
conditioning_data = self.get_conditioning_data(context, scheduler)
|
conditioning_data = self.get_conditioning_data(context, scheduler)
|
||||||
|
|
||||||
control_data = self.prep_control_data(
|
control_data = self.prep_control_data(
|
||||||
model=pipeline, context=context, control_input=self.control,
|
model=pipeline, context=context, control_input=self.control,
|
||||||
latents_shape=noise.shape,
|
latents_shape=noise.shape,
|
||||||
@ -380,8 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||||
latent, device=unet.device, dtype=latent.dtype
|
latent, device=unet.device, dtype=latent.dtype)
|
||||||
)
|
|
||||||
|
|
||||||
timesteps, _ = pipeline.get_img2img_timesteps(
|
timesteps, _ = pipeline.get_img2img_timesteps(
|
||||||
self.steps,
|
self.steps,
|
||||||
@ -389,18 +421,15 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
device=unet.device,
|
device=unet.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
|
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||||
|
latents=initial_latents,
|
||||||
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
|
timesteps=timesteps,
|
||||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
noise=noise,
|
||||||
latents=initial_latents,
|
num_inference_steps=self.steps,
|
||||||
timesteps=timesteps,
|
conditioning_data=conditioning_data,
|
||||||
noise=noise,
|
control_data=control_data, # list[ControlNetData]
|
||||||
num_inference_steps=self.steps,
|
callback=step_callback
|
||||||
conditioning_data=conditioning_data,
|
)
|
||||||
control_data=control_data, # list[ControlNetData]
|
|
||||||
callback=step_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -417,9 +446,12 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
type: Literal["l2i"] = "l2i"
|
type: Literal["l2i"] = "l2i"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
latents: Optional[LatentsField] = Field(
|
||||||
|
description="The latents to generate an image from")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
tiled: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Decode latents by overlaping tiles(less memory consumption)")
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@ -450,7 +482,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
# copied from diffusers pipeline
|
# copied from diffusers pipeline
|
||||||
latents = latents / vae.config.scaling_factor
|
latents = latents / vae.config.scaling_factor
|
||||||
image = vae.decode(latents, return_dict=False)[0]
|
image = vae.decode(latents, return_dict=False)[0]
|
||||||
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
|
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
|
||||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||||
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||||
|
|
||||||
@ -473,9 +505,9 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
LATENTS_INTERPOLATION_MODE = Literal[
|
|
||||||
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
|
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear",
|
||||||
]
|
"bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||||
|
|
||||||
|
|
||||||
class ResizeLatentsInvocation(BaseInvocation):
|
class ResizeLatentsInvocation(BaseInvocation):
|
||||||
@ -484,21 +516,25 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
type: Literal["lresize"] = "lresize"
|
type: Literal["lresize"] = "lresize"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
latents: Optional[LatentsField] = Field(
|
||||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
description="The latents to resize")
|
||||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
width: int = Field(
|
||||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||||
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
height: int = Field(
|
||||||
|
ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||||
|
mode: LATENTS_INTERPOLATION_MODE = Field(
|
||||||
|
default="bilinear", description="The interpolation mode")
|
||||||
|
antialias: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
resized_latents = torch.nn.functional.interpolate(
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
latents,
|
latents, size=(self.height // 8, self.width // 8),
|
||||||
size=(self.height // 8, self.width // 8),
|
mode=self.mode, antialias=self.antialias
|
||||||
mode=self.mode,
|
if self.mode in ["bilinear", "bicubic"] else False,)
|
||||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -515,21 +551,24 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
type: Literal["lscale"] = "lscale"
|
type: Literal["lscale"] = "lscale"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
latents: Optional[LatentsField] = Field(
|
||||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
description="The latents to scale")
|
||||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
scale_factor: float = Field(
|
||||||
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
gt=0, description="The factor by which to scale the latents")
|
||||||
|
mode: LATENTS_INTERPOLATION_MODE = Field(
|
||||||
|
default="bilinear", description="The interpolation mode")
|
||||||
|
antialias: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
# resizing
|
# resizing
|
||||||
resized_latents = torch.nn.functional.interpolate(
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
latents,
|
latents, scale_factor=self.scale_factor, mode=self.mode,
|
||||||
scale_factor=self.scale_factor,
|
antialias=self.antialias
|
||||||
mode=self.mode,
|
if self.mode in ["bilinear", "bicubic"] else False,)
|
||||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -548,7 +587,9 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
# Inputs
|
# Inputs
|
||||||
image: Union[ImageField, None] = Field(description="The image to encode")
|
image: Union[ImageField, None] = Field(description="The image to encode")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
|
tiled: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Encode latents by overlaping tiles(less memory consumption)")
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
|
@ -1,31 +1,38 @@
|
|||||||
from typing import Literal, Optional, Union, List
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
import copy
|
import copy
|
||||||
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
|
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||||
|
InvocationConfig, InvocationContext)
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
class ModelInfo(BaseModel):
|
||||||
model_name: str = Field(description="Info to load submodel")
|
model_name: str = Field(description="Info to load submodel")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
model_type: ModelType = Field(description="Info to load submodel")
|
model_type: ModelType = Field(description="Info to load submodel")
|
||||||
submodel: Optional[SubModelType] = Field(description="Info to load submodel")
|
submodel: Optional[SubModelType] = Field(
|
||||||
|
default=None, description="Info to load submodel"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LoraInfo(ModelInfo):
|
class LoraInfo(ModelInfo):
|
||||||
weight: float = Field(description="Lora's weight which to use when apply to model")
|
weight: float = Field(description="Lora's weight which to use when apply to model")
|
||||||
|
|
||||||
|
|
||||||
class UNetField(BaseModel):
|
class UNetField(BaseModel):
|
||||||
unet: ModelInfo = Field(description="Info to load unet submodel")
|
unet: ModelInfo = Field(description="Info to load unet submodel")
|
||||||
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
||||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||||
|
|
||||||
|
|
||||||
class ClipField(BaseModel):
|
class ClipField(BaseModel):
|
||||||
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
|
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
|
||||||
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
|
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
|
||||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||||
|
|
||||||
|
|
||||||
class VaeField(BaseModel):
|
class VaeField(BaseModel):
|
||||||
# TODO: better naming?
|
# TODO: better naming?
|
||||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||||
@ -34,43 +41,48 @@ class VaeField(BaseModel):
|
|||||||
class ModelLoaderOutput(BaseInvocationOutput):
|
class ModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
#fmt: off
|
# fmt: off
|
||||||
type: Literal["model_loader_output"] = "model_loader_output"
|
type: Literal["model_loader_output"] = "model_loader_output"
|
||||||
|
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
class PipelineModelField(BaseModel):
|
class MainModelField(BaseModel):
|
||||||
"""Pipeline model field"""
|
"""Main model field"""
|
||||||
|
|
||||||
model_name: str = Field(description="Name of the model")
|
model_name: str = Field(description="Name of the model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
class PipelineModelLoaderInvocation(BaseInvocation):
|
class LoRAModelField(BaseModel):
|
||||||
"""Loads a pipeline model, outputting its submodels."""
|
"""LoRA model field"""
|
||||||
|
|
||||||
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
|
model_name: str = Field(description="Name of the LoRA model")
|
||||||
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
model: PipelineModelField = Field(description="The model to load")
|
|
||||||
|
class MainModelLoaderInvocation(BaseInvocation):
|
||||||
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
|
type: Literal["main_model_loader"] = "main_model_loader"
|
||||||
|
|
||||||
|
model: MainModelField = Field(description="The model to load")
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {
|
||||||
|
"title": "Model Loader",
|
||||||
"tags": ["model", "loader"],
|
"tags": ["model", "loader"],
|
||||||
"type_hints": {
|
"type_hints": {"model": "model"},
|
||||||
"model": "model"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||||
|
|
||||||
base_model = self.model.base_model
|
base_model = self.model.base_model
|
||||||
model_name = self.model.model_name
|
model_name = self.model.model_name
|
||||||
model_type = ModelType.Main
|
model_type = ModelType.Main
|
||||||
@ -112,7 +124,6 @@ class PipelineModelLoaderInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
return ModelLoaderOutput(
|
return ModelLoaderOutput(
|
||||||
unet=UNetField(
|
unet=UNetField(
|
||||||
unet=ModelInfo(
|
unet=ModelInfo(
|
||||||
@ -151,47 +162,66 @@ class PipelineModelLoaderInvocation(BaseInvocation):
|
|||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=SubModelType.Vae,
|
submodel=SubModelType.Vae,
|
||||||
),
|
),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LoraLoaderOutput(BaseInvocationOutput):
|
class LoraLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
#fmt: off
|
# fmt: off
|
||||||
type: Literal["lora_loader_output"] = "lora_loader_output"
|
type: Literal["lora_loader_output"] = "lora_loader_output"
|
||||||
|
|
||||||
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
||||||
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
class LoraLoaderInvocation(BaseInvocation):
|
class LoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
type: Literal["lora_loader"] = "lora_loader"
|
type: Literal["lora_loader"] = "lora_loader"
|
||||||
|
|
||||||
lora_name: str = Field(description="Lora model name")
|
lora: Union[LoRAModelField, None] = Field(
|
||||||
|
default=None, description="Lora model name"
|
||||||
|
)
|
||||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
||||||
|
|
||||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||||
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"title": "Lora Loader",
|
||||||
|
"tags": ["lora", "loader"],
|
||||||
|
"type_hints": {"lora": "lora_model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
# TODO: ui rewrite
|
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||||
base_model = BaseModelType.StableDiffusion1
|
if self.lora is None:
|
||||||
|
raise Exception("No LoRA provided")
|
||||||
|
|
||||||
|
base_model = self.lora.base_model
|
||||||
|
lora_name = self.lora.model_name
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.services.model_manager.model_exists(
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_name=self.lora_name,
|
model_name=lora_name,
|
||||||
model_type=ModelType.Lora,
|
model_type=ModelType.Lora,
|
||||||
):
|
):
|
||||||
raise Exception(f"Unkown lora name: {self.lora_name}!")
|
raise Exception(f"Unkown lora name: {lora_name}!")
|
||||||
|
|
||||||
if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras):
|
if self.unet is not None and any(
|
||||||
raise Exception(f"Lora \"{self.lora_name}\" already applied to unet")
|
lora.model_name == lora_name for lora in self.unet.loras
|
||||||
|
):
|
||||||
|
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||||
|
|
||||||
if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras):
|
if self.clip is not None and any(
|
||||||
raise Exception(f"Lora \"{self.lora_name}\" already applied to clip")
|
lora.model_name == lora_name for lora in self.clip.loras
|
||||||
|
):
|
||||||
|
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||||
|
|
||||||
output = LoraLoaderOutput()
|
output = LoraLoaderOutput()
|
||||||
|
|
||||||
@ -200,7 +230,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
output.unet.loras.append(
|
output.unet.loras.append(
|
||||||
LoraInfo(
|
LoraInfo(
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_name=self.lora_name,
|
model_name=lora_name,
|
||||||
model_type=ModelType.Lora,
|
model_type=ModelType.Lora,
|
||||||
submodel=None,
|
submodel=None,
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
@ -212,7 +242,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
output.clip.loras.append(
|
output.clip.loras.append(
|
||||||
LoraInfo(
|
LoraInfo(
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_name=self.lora_name,
|
model_name=lora_name,
|
||||||
model_type=ModelType.Lora,
|
model_type=ModelType.Lora,
|
||||||
submodel=None,
|
submodel=None,
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
@ -221,3 +251,58 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class VAEModelField(BaseModel):
|
||||||
|
"""Vae model field"""
|
||||||
|
|
||||||
|
model_name: str = Field(description="Name of the model")
|
||||||
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
|
class VaeLoaderOutput(BaseInvocationOutput):
|
||||||
|
"""Model loader output"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["vae_loader_output"] = "vae_loader_output"
|
||||||
|
|
||||||
|
vae: VaeField = Field(default=None, description="Vae model")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class VaeLoaderInvocation(BaseInvocation):
|
||||||
|
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||||
|
|
||||||
|
type: Literal["vae_loader"] = "vae_loader"
|
||||||
|
|
||||||
|
vae_model: VAEModelField = Field(description="The VAE to load")
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"title": "VAE Loader",
|
||||||
|
"tags": ["vae", "loader"],
|
||||||
|
"type_hints": {"vae_model": "vae_model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||||
|
base_model = self.vae_model.base_model
|
||||||
|
model_name = self.vae_model.model_name
|
||||||
|
model_type = ModelType.Vae
|
||||||
|
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
base_model=base_model,
|
||||||
|
model_name=model_name,
|
||||||
|
model_type=model_type,
|
||||||
|
):
|
||||||
|
raise Exception(f"Unkown vae name: {model_name}!")
|
||||||
|
return VaeLoaderOutput(
|
||||||
|
vae=VaeField(
|
||||||
|
vae=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@ -228,10 +228,10 @@ class InvokeAISettings(BaseSettings):
|
|||||||
upcase_environ = dict()
|
upcase_environ = dict()
|
||||||
for key,value in os.environ.items():
|
for key,value in os.environ.items():
|
||||||
upcase_environ[key.upper()] = value
|
upcase_environ[key.upper()] = value
|
||||||
|
|
||||||
fields = cls.__fields__
|
fields = cls.__fields__
|
||||||
cls.argparse_groups = {}
|
cls.argparse_groups = {}
|
||||||
|
|
||||||
for name, field in fields.items():
|
for name, field in fields.items():
|
||||||
if name not in cls._excluded():
|
if name not in cls._excluded():
|
||||||
current_default = field.default
|
current_default = field.default
|
||||||
@ -348,7 +348,7 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
'''
|
'''
|
||||||
singleton_config: ClassVar[InvokeAIAppConfig] = None
|
singleton_config: ClassVar[InvokeAIAppConfig] = None
|
||||||
singleton_init: ClassVar[Dict] = None
|
singleton_init: ClassVar[Dict] = None
|
||||||
|
|
||||||
#fmt: off
|
#fmt: off
|
||||||
type: Literal["InvokeAI"] = "InvokeAI"
|
type: Literal["InvokeAI"] = "InvokeAI"
|
||||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||||
@ -367,7 +367,8 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
|
|
||||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||||
max_loaded_models : int = Field(default=3, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
|
max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
|
||||||
|
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
||||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
|
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
|
||||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||||
@ -385,9 +386,9 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
||||||
|
|
||||||
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
||||||
|
|
||||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
||||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||||
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
||||||
@ -396,7 +397,7 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
|
|
||||||
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
|
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
|
||||||
'''
|
'''
|
||||||
Update settings with contents of init file, environment, and
|
Update settings with contents of init file, environment, and
|
||||||
command-line settings.
|
command-line settings.
|
||||||
:param conf: alternate Omegaconf dictionary object
|
:param conf: alternate Omegaconf dictionary object
|
||||||
:param argv: aternate sys.argv list
|
:param argv: aternate sys.argv list
|
||||||
@ -411,7 +412,7 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
InvokeAISettings.initconf = conf
|
InvokeAISettings.initconf = conf
|
||||||
|
|
||||||
# parse args again in order to pick up settings in configuration file
|
# parse args again in order to pick up settings in configuration file
|
||||||
super().parse_args(argv)
|
super().parse_args(argv)
|
||||||
|
|
||||||
@ -431,7 +432,7 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
cls.singleton_config = cls(**kwargs)
|
cls.singleton_config = cls(**kwargs)
|
||||||
cls.singleton_init = kwargs
|
cls.singleton_init = kwargs
|
||||||
return cls.singleton_config
|
return cls.singleton_config
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def root_path(self)->Path:
|
def root_path(self)->Path:
|
||||||
'''
|
'''
|
||||||
|
@ -33,13 +33,13 @@ class ModelManagerServiceBase(ABC):
|
|||||||
logger: types.ModuleType,
|
logger: types.ModuleType,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize with the path to the models.yaml config file.
|
Initialize with the path to the models.yaml config file.
|
||||||
Optional parameters are the torch device type, precision, max_models,
|
Optional parameters are the torch device type, precision, max_models,
|
||||||
and sequential_offload boolean. Note that the default device
|
and sequential_offload boolean. Note that the default device
|
||||||
type and precision are set up for a CUDA system running at half precision.
|
type and precision are set up for a CUDA system running at half precision.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
@ -50,8 +50,8 @@ class ModelManagerServiceBase(ABC):
|
|||||||
node: Optional[BaseInvocation] = None,
|
node: Optional[BaseInvocation] = None,
|
||||||
context: Optional[InvocationContext] = None,
|
context: Optional[InvocationContext] = None,
|
||||||
) -> ModelInfo:
|
) -> ModelInfo:
|
||||||
"""Retrieve the indicated model with name and type.
|
"""Retrieve the indicated model with name and type.
|
||||||
submodel can be used to get a part (such as the vae)
|
submodel can be used to get a part (such as the vae)
|
||||||
of a diffusers pipeline."""
|
of a diffusers pipeline."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -115,8 +115,8 @@ class ModelManagerServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
with an assertion error if provided attributes are incorrect or
|
with an assertion error if provided attributes are incorrect or
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
@ -129,12 +129,35 @@ class ModelManagerServiceBase(ABC):
|
|||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Delete the named model from configuration. If delete_files is true,
|
Delete the named model from configuration. If delete_files is true,
|
||||||
then the underlying weight file or diffusers directory will be deleted
|
then the underlying weight file or diffusers directory will be deleted
|
||||||
as well. Call commit() to write to disk.
|
as well. Call commit() to write to disk.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def heuristic_import(self,
|
||||||
|
items_to_import: Set[str],
|
||||||
|
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||||
|
)->Dict[str, AddModelResult]:
|
||||||
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
|
successfully imported items.
|
||||||
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
|
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||||
|
|
||||||
|
The prediction type helper is necessary to distinguish between
|
||||||
|
models based on Stable Diffusion 2 Base (requiring
|
||||||
|
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||||
|
(requiring SchedulerPredictionType.VPrediction). It is
|
||||||
|
generally impossible to do this programmatically, so the
|
||||||
|
prediction_type_helper usually asks the user to choose.
|
||||||
|
|
||||||
|
The result is a set of successfully installed models. Each element
|
||||||
|
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||||
|
that model.
|
||||||
|
'''
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def commit(self, conf_file: Path = None) -> None:
|
def commit(self, conf_file: Path = None) -> None:
|
||||||
"""
|
"""
|
||||||
@ -153,7 +176,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
logger: types.ModuleType,
|
logger: types.ModuleType,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize with the path to the models.yaml config file.
|
Initialize with the path to the models.yaml config file.
|
||||||
Optional parameters are the torch device type, precision, max_models,
|
Optional parameters are the torch device type, precision, max_models,
|
||||||
and sequential_offload boolean. Note that the default device
|
and sequential_offload boolean. Note that the default device
|
||||||
type and precision are set up for a CUDA system running at half precision.
|
type and precision are set up for a CUDA system running at half precision.
|
||||||
@ -183,6 +206,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
if hasattr(config,'max_cache_size') \
|
if hasattr(config,'max_cache_size') \
|
||||||
else config.max_loaded_models * 2.5
|
else config.max_loaded_models * 2.5
|
||||||
|
|
||||||
|
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
||||||
|
|
||||||
sequential_offload = config.sequential_guidance
|
sequential_offload = config.sequential_guidance
|
||||||
|
|
||||||
self.mgr = ModelManager(
|
self.mgr = ModelManager(
|
||||||
@ -238,7 +263,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
model_info=model_info
|
model_info=model_info
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_info
|
return model_info
|
||||||
|
|
||||||
def model_exists(
|
def model_exists(
|
||||||
@ -291,8 +316,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
with an assertion error if provided attributes are incorrect or
|
with an assertion error if provided attributes are incorrect or
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
"""
|
"""
|
||||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||||
@ -305,8 +330,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Delete the named model from configuration. If delete_files is true,
|
Delete the named model from configuration. If delete_files is true,
|
||||||
then the underlying weight file or diffusers directory will be deleted
|
then the underlying weight file or diffusers directory will be deleted
|
||||||
as well. Call commit() to write to disk.
|
as well. Call commit() to write to disk.
|
||||||
"""
|
"""
|
||||||
self.mgr.del_model(model_name, base_model, model_type)
|
self.mgr.del_model(model_name, base_model, model_type)
|
||||||
@ -360,4 +385,25 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
@property
|
@property
|
||||||
def logger(self):
|
def logger(self):
|
||||||
return self.mgr.logger
|
return self.mgr.logger
|
||||||
|
|
||||||
|
def heuristic_import(self,
|
||||||
|
items_to_import: Set[str],
|
||||||
|
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||||
|
)->Dict[str, AddModelResult]:
|
||||||
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
|
successfully imported items.
|
||||||
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
|
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||||
|
|
||||||
|
The prediction type helper is necessary to distinguish between
|
||||||
|
models based on Stable Diffusion 2 Base (requiring
|
||||||
|
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||||
|
(requiring SchedulerPredictionType.VPrediction). It is
|
||||||
|
generally impossible to do this programmatically, so the
|
||||||
|
prediction_type_helper usually asks the user to choose.
|
||||||
|
|
||||||
|
The result is a set of successfully installed models. Each element
|
||||||
|
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||||
|
that model.
|
||||||
|
'''
|
||||||
|
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||||
|
@ -76,6 +76,10 @@ class MigrateTo3(object):
|
|||||||
Create a unique name for a model for use within models.yaml.
|
Create a unique name for a model for use within models.yaml.
|
||||||
'''
|
'''
|
||||||
done = False
|
done = False
|
||||||
|
|
||||||
|
# some model names have slashes in them, which really screws things up
|
||||||
|
name = name.replace('/','_')
|
||||||
|
|
||||||
key = ModelManager.create_key(name,info.base_type,info.model_type)
|
key = ModelManager.create_key(name,info.base_type,info.model_type)
|
||||||
unique_name = key
|
unique_name = key
|
||||||
counter = 1
|
counter = 1
|
||||||
@ -219,11 +223,11 @@ class MigrateTo3(object):
|
|||||||
repo_id = 'openai/clip-vit-large-patch14'
|
repo_id = 'openai/clip-vit-large-patch14'
|
||||||
self._migrate_pretrained(CLIPTokenizer,
|
self._migrate_pretrained(CLIPTokenizer,
|
||||||
repo_id= repo_id,
|
repo_id= repo_id,
|
||||||
dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer',
|
dest= target_dir / 'clip-vit-large-patch14',
|
||||||
**kwargs)
|
**kwargs)
|
||||||
self._migrate_pretrained(CLIPTextModel,
|
self._migrate_pretrained(CLIPTextModel,
|
||||||
repo_id = repo_id,
|
repo_id = repo_id,
|
||||||
dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder',
|
dest = target_dir / 'clip-vit-large-patch14',
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
# sd-2
|
# sd-2
|
||||||
|
@ -18,7 +18,7 @@ from tqdm import tqdm
|
|||||||
import invokeai.configs as configs
|
import invokeai.configs as configs
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
|
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
||||||
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
|
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
|
||||||
from invokeai.backend.util import download_with_resume
|
from invokeai.backend.util import download_with_resume
|
||||||
from ..util.logging import InvokeAILogger
|
from ..util.logging import InvokeAILogger
|
||||||
@ -166,17 +166,22 @@ class ModelInstall(object):
|
|||||||
# add requested models
|
# add requested models
|
||||||
for path in selections.install_models:
|
for path in selections.install_models:
|
||||||
logger.info(f'Installing {path} [{job}/{jobs}]')
|
logger.info(f'Installing {path} [{job}/{jobs}]')
|
||||||
self.heuristic_install(path)
|
self.heuristic_import(path)
|
||||||
job += 1
|
job += 1
|
||||||
|
|
||||||
self.mgr.commit()
|
self.mgr.commit()
|
||||||
|
|
||||||
def heuristic_install(self,
|
def heuristic_import(self,
|
||||||
model_path_id_or_url: Union[str,Path],
|
model_path_id_or_url: Union[str,Path],
|
||||||
models_installed: Set[Path]=None)->Set[Path]:
|
models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
|
||||||
|
'''
|
||||||
|
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||||
|
:param models_installed: Set of installed models, used for recursive invocation
|
||||||
|
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
||||||
|
'''
|
||||||
|
|
||||||
if not models_installed:
|
if not models_installed:
|
||||||
models_installed = set()
|
models_installed = dict()
|
||||||
|
|
||||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||||
self.current_id = model_path_id_or_url
|
self.current_id = model_path_id_or_url
|
||||||
@ -185,24 +190,24 @@ class ModelInstall(object):
|
|||||||
try:
|
try:
|
||||||
# checkpoint file, or similar
|
# checkpoint file, or similar
|
||||||
if path.is_file():
|
if path.is_file():
|
||||||
models_installed.add(self._install_path(path))
|
models_installed.update(self._install_path(path))
|
||||||
|
|
||||||
# folders style or similar
|
# folders style or similar
|
||||||
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||||
models_installed.add(self._install_path(path))
|
models_installed.update(self._install_path(path))
|
||||||
|
|
||||||
# recursive scan
|
# recursive scan
|
||||||
elif path.is_dir():
|
elif path.is_dir():
|
||||||
for child in path.iterdir():
|
for child in path.iterdir():
|
||||||
self.heuristic_install(child, models_installed=models_installed)
|
self.heuristic_import(child, models_installed=models_installed)
|
||||||
|
|
||||||
# huggingface repo
|
# huggingface repo
|
||||||
elif len(str(path).split('/')) == 2:
|
elif len(str(path).split('/')) == 2:
|
||||||
models_installed.add(self._install_repo(str(path)))
|
models_installed.update(self._install_repo(str(path)))
|
||||||
|
|
||||||
# a URL
|
# a URL
|
||||||
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
|
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
|
||||||
models_installed.add(self._install_url(model_path_id_or_url))
|
models_installed.update(self._install_url(model_path_id_or_url))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||||
@ -214,24 +219,25 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
# install a model from a local path. The optional info parameter is there to prevent
|
# install a model from a local path. The optional info parameter is there to prevent
|
||||||
# the model from being probed twice in the event that it has already been probed.
|
# the model from being probed twice in the event that it has already been probed.
|
||||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path:
|
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
|
||||||
try:
|
try:
|
||||||
# logger.debug(f'Probing {path}')
|
model_result = None
|
||||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||||
model_name = path.stem if info.format=='checkpoint' else path.name
|
model_name = path.stem if path.is_file() else path.name
|
||||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||||
attributes = self._make_attributes(path,info)
|
attributes = self._make_attributes(path,info)
|
||||||
self.mgr.add_model(model_name = model_name,
|
model_result = self.mgr.add_model(model_name = model_name,
|
||||||
base_model = info.base_type,
|
base_model = info.base_type,
|
||||||
model_type = info.model_type,
|
model_type = info.model_type,
|
||||||
model_attributes = attributes,
|
model_attributes = attributes,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f'{str(e)} Skipping registration.')
|
logger.warning(f'{str(e)} Skipping registration.')
|
||||||
return path
|
return {}
|
||||||
|
return {str(path): model_result}
|
||||||
|
|
||||||
def _install_url(self, url: str)->Path:
|
def _install_url(self, url: str)->dict:
|
||||||
# copy to a staging area, probe, import and delete
|
# copy to a staging area, probe, import and delete
|
||||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||||
location = download_with_resume(url,Path(staging))
|
location = download_with_resume(url,Path(staging))
|
||||||
@ -244,7 +250,7 @@ class ModelInstall(object):
|
|||||||
# staged version will be garbage-collected at this time
|
# staged version will be garbage-collected at this time
|
||||||
return self._install_path(Path(models_path), info)
|
return self._install_path(Path(models_path), info)
|
||||||
|
|
||||||
def _install_repo(self, repo_id: str)->Path:
|
def _install_repo(self, repo_id: str)->dict:
|
||||||
hinfo = HfApi().model_info(repo_id)
|
hinfo = HfApi().model_info(repo_id)
|
||||||
|
|
||||||
# we try to figure out how to download this most economically
|
# we try to figure out how to download this most economically
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.backend.model_management
|
Initialization file for invokeai.backend.model_management
|
||||||
"""
|
"""
|
||||||
from .model_manager import ModelManager, ModelInfo
|
from .model_manager import ModelManager, ModelInfo, AddModelResult
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ import invokeai.backend.util.logging as logger
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
from .model_manager import ModelManager
|
from .model_manager import ModelManager
|
||||||
from .model_cache import ModelCache
|
from picklescan.scanner import scan_file_path
|
||||||
from .models import BaseModelType, ModelVariantType
|
from .models import BaseModelType, ModelVariantType
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1014,7 +1014,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
checkpoint = load_file(checkpoint_path)
|
checkpoint = load_file(checkpoint_path)
|
||||||
else:
|
else:
|
||||||
if scan_needed:
|
if scan_needed:
|
||||||
ModelCache.scan_model(checkpoint_path, checkpoint_path)
|
# scan model
|
||||||
|
scan_result = scan_file_path(checkpoint_path)
|
||||||
|
if scan_result.infected_files != 0:
|
||||||
|
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
|
||||||
checkpoint = torch.load(checkpoint_path)
|
checkpoint = torch.load(checkpoint_path)
|
||||||
|
|
||||||
# sometimes there is a state_dict key and sometimes not
|
# sometimes there is a state_dict key and sometimes not
|
||||||
|
@ -1,18 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from pathlib import Path
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Optional, Dict, Tuple, Any
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from compel.embeddings_provider import BaseTextualInversionManager
|
||||||
|
from diffusers.models import UNet2DConditionModel
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from torch.utils.hooks import RemovableHandle
|
from torch.utils.hooks import RemovableHandle
|
||||||
|
|
||||||
from diffusers.models import UNet2DConditionModel
|
|
||||||
from transformers import CLIPTextModel
|
from transformers import CLIPTextModel
|
||||||
|
|
||||||
from compel.embeddings_provider import BaseTextualInversionManager
|
|
||||||
|
|
||||||
class LoRALayerBase:
|
class LoRALayerBase:
|
||||||
#rank: Optional[int]
|
#rank: Optional[int]
|
||||||
@ -539,9 +538,10 @@ class ModelPatcher:
|
|||||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||||
|
|
||||||
# enable autocast to calc fp16 loras on cpu
|
# enable autocast to calc fp16 loras on cpu
|
||||||
with torch.autocast(device_type="cpu"):
|
#with torch.autocast(device_type="cpu"):
|
||||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
layer.to(dtype=torch.float32)
|
||||||
layer_weight = layer.get_weight() * lora_weight * layer_scale
|
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||||
|
layer_weight = layer.get_weight() * lora_weight * layer_scale
|
||||||
|
|
||||||
if module.weight.shape != layer_weight.shape:
|
if module.weight.shape != layer_weight.shape:
|
||||||
# TODO: debug on lycoris
|
# TODO: debug on lycoris
|
||||||
@ -655,6 +655,9 @@ class TextualInversionModel:
|
|||||||
else:
|
else:
|
||||||
result.embedding = next(iter(state_dict.values()))
|
result.embedding = next(iter(state_dict.values()))
|
||||||
|
|
||||||
|
if len(result.embedding.shape) == 1:
|
||||||
|
result.embedding = result.embedding.unsqueeze(0)
|
||||||
|
|
||||||
if not isinstance(result.embedding, torch.Tensor):
|
if not isinstance(result.embedding, torch.Tensor):
|
||||||
raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
||||||
|
|
||||||
|
@ -233,14 +233,14 @@ import hashlib
|
|||||||
import textwrap
|
import textwrap
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List, Tuple, Union, Set, Callable, types
|
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from omegaconf.dictconfig import DictConfig
|
from omegaconf.dictconfig import DictConfig
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
@ -249,7 +249,7 @@ from .model_cache import ModelCache, ModelLocker
|
|||||||
from .models import (
|
from .models import (
|
||||||
BaseModelType, ModelType, SubModelType,
|
BaseModelType, ModelType, SubModelType,
|
||||||
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
||||||
ModelConfigBase,
|
ModelConfigBase, ModelNotFoundException,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We are only starting to number the config file with release 3.
|
# We are only starting to number the config file with release 3.
|
||||||
@ -278,8 +278,13 @@ class InvalidModelError(Exception):
|
|||||||
"Raised when an invalid model is requested"
|
"Raised when an invalid model is requested"
|
||||||
pass
|
pass
|
||||||
|
|
||||||
MAX_CACHE_SIZE = 6.0 # GB
|
class AddModelResult(BaseModel):
|
||||||
|
name: str = Field(description="The name of the model after import")
|
||||||
|
model_type: ModelType = Field(description="The type of model")
|
||||||
|
base_model: BaseModelType = Field(description="The base model")
|
||||||
|
config: ModelConfigBase = Field(description="The configuration of the model")
|
||||||
|
|
||||||
|
MAX_CACHE_SIZE = 6.0 # GB
|
||||||
|
|
||||||
class ConfigMeta(BaseModel):
|
class ConfigMeta(BaseModel):
|
||||||
version: str
|
version: str
|
||||||
@ -404,7 +409,7 @@ class ModelManager(object):
|
|||||||
if model_key not in self.models:
|
if model_key not in self.models:
|
||||||
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
||||||
if model_key not in self.models:
|
if model_key not in self.models:
|
||||||
raise Exception(f"Model not found - {model_key}")
|
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||||
|
|
||||||
model_config = self.models[model_key]
|
model_config = self.models[model_key]
|
||||||
model_path = self.app_config.root_path / model_config.path
|
model_path = self.app_config.root_path / model_config.path
|
||||||
@ -416,7 +421,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
self.models.pop(model_key, None)
|
self.models.pop(model_key, None)
|
||||||
raise Exception(f"Model not found - {model_key}")
|
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||||
|
|
||||||
# vae/movq override
|
# vae/movq override
|
||||||
# TODO:
|
# TODO:
|
||||||
@ -571,13 +576,16 @@ class ModelManager(object):
|
|||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_attributes: dict,
|
model_attributes: dict,
|
||||||
clobber: bool = False,
|
clobber: bool = False,
|
||||||
) -> None:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
On a successful update, the config will be changed in memory and the
|
On a successful update, the config will be changed in memory and the
|
||||||
method will return True. Will fail with an assertion error if provided
|
method will return True. Will fail with an assertion error if provided
|
||||||
attributes are incorrect or the model name is missing.
|
attributes are incorrect or the model name is missing.
|
||||||
|
|
||||||
|
The returned dict has the same format as the dict returned by
|
||||||
|
model_info().
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
@ -601,12 +609,18 @@ class ModelManager(object):
|
|||||||
old_model_cache.unlink()
|
old_model_cache.unlink()
|
||||||
|
|
||||||
# remove in-memory cache
|
# remove in-memory cache
|
||||||
# note: it not garantie to release memory(model can has other references)
|
# note: it not guaranteed to release memory(model can has other references)
|
||||||
cache_ids = self.cache_keys.pop(model_key, [])
|
cache_ids = self.cache_keys.pop(model_key, [])
|
||||||
for cache_id in cache_ids:
|
for cache_id in cache_ids:
|
||||||
self.cache.uncache_model(cache_id)
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
self.models[model_key] = model_config
|
self.models[model_key] = model_config
|
||||||
|
return AddModelResult(
|
||||||
|
name = model_name,
|
||||||
|
model_type = model_type,
|
||||||
|
base_model = base_model,
|
||||||
|
config = model_config,
|
||||||
|
)
|
||||||
|
|
||||||
def search_models(self, search_folder):
|
def search_models(self, search_folder):
|
||||||
self.logger.info(f"Finding Models In: {search_folder}")
|
self.logger.info(f"Finding Models In: {search_folder}")
|
||||||
@ -717,19 +731,19 @@ class ModelManager(object):
|
|||||||
|
|
||||||
if model_path.is_relative_to(self.app_config.root_path):
|
if model_path.is_relative_to(self.app_config.root_path):
|
||||||
model_path = model_path.relative_to(self.app_config.root_path)
|
model_path = model_path.relative_to(self.app_config.root_path)
|
||||||
try:
|
try:
|
||||||
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
||||||
self.models[model_key] = model_config
|
self.models[model_key] = model_config
|
||||||
new_models_found = True
|
new_models_found = True
|
||||||
except NotImplementedError as e:
|
except NotImplementedError as e:
|
||||||
self.logger.warning(e)
|
self.logger.warning(e)
|
||||||
|
|
||||||
imported_models = self.autoimport()
|
imported_models = self.autoimport()
|
||||||
|
|
||||||
if (new_models_found or imported_models) and self.config_path:
|
if (new_models_found or imported_models) and self.config_path:
|
||||||
self.commit()
|
self.commit()
|
||||||
|
|
||||||
def autoimport(self)->set[Path]:
|
def autoimport(self)->Dict[str, AddModelResult]:
|
||||||
'''
|
'''
|
||||||
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
||||||
'''
|
'''
|
||||||
@ -742,7 +756,6 @@ class ModelManager(object):
|
|||||||
prediction_type_helper = ask_user_for_prediction_type,
|
prediction_type_helper = ask_user_for_prediction_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
installed = set()
|
|
||||||
scanned_dirs = set()
|
scanned_dirs = set()
|
||||||
|
|
||||||
config = self.app_config
|
config = self.app_config
|
||||||
@ -756,13 +769,14 @@ class ModelManager(object):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
self.logger.info(f'Scanning {autodir} for models to import')
|
self.logger.info(f'Scanning {autodir} for models to import')
|
||||||
|
installed = dict()
|
||||||
|
|
||||||
autodir = self.app_config.root_path / autodir
|
autodir = self.app_config.root_path / autodir
|
||||||
if not autodir.exists():
|
if not autodir.exists():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
items_scanned = 0
|
items_scanned = 0
|
||||||
new_models_found = set()
|
new_models_found = dict()
|
||||||
|
|
||||||
for root, dirs, files in os.walk(autodir):
|
for root, dirs, files in os.walk(autodir):
|
||||||
items_scanned += len(dirs) + len(files)
|
items_scanned += len(dirs) + len(files)
|
||||||
@ -772,7 +786,7 @@ class ModelManager(object):
|
|||||||
scanned_dirs.add(path)
|
scanned_dirs.add(path)
|
||||||
continue
|
continue
|
||||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||||
new_models_found.update(installer.heuristic_install(path))
|
new_models_found.update(installer.heuristic_import(path))
|
||||||
scanned_dirs.add(path)
|
scanned_dirs.add(path)
|
||||||
|
|
||||||
for f in files:
|
for f in files:
|
||||||
@ -780,7 +794,7 @@ class ModelManager(object):
|
|||||||
if path in known_paths or path.parent in scanned_dirs:
|
if path in known_paths or path.parent in scanned_dirs:
|
||||||
continue
|
continue
|
||||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||||
new_models_found.update(installer.heuristic_install(path))
|
new_models_found.update(installer.heuristic_import(path))
|
||||||
|
|
||||||
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
||||||
installed.update(new_models_found)
|
installed.update(new_models_found)
|
||||||
@ -790,7 +804,7 @@ class ModelManager(object):
|
|||||||
def heuristic_import(self,
|
def heuristic_import(self,
|
||||||
items_to_import: Set[str],
|
items_to_import: Set[str],
|
||||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||||
)->Set[str]:
|
)->Dict[str, AddModelResult]:
|
||||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
successfully imported items.
|
successfully imported items.
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
@ -803,17 +817,20 @@ class ModelManager(object):
|
|||||||
generally impossible to do this programmatically, so the
|
generally impossible to do this programmatically, so the
|
||||||
prediction_type_helper usually asks the user to choose.
|
prediction_type_helper usually asks the user to choose.
|
||||||
|
|
||||||
|
The result is a set of successfully installed models. Each element
|
||||||
|
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||||
|
that model.
|
||||||
'''
|
'''
|
||||||
# avoid circular import here
|
# avoid circular import here
|
||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
successfully_installed = set()
|
successfully_installed = dict()
|
||||||
|
|
||||||
installer = ModelInstall(config = self.app_config,
|
installer = ModelInstall(config = self.app_config,
|
||||||
prediction_type_helper = prediction_type_helper,
|
prediction_type_helper = prediction_type_helper,
|
||||||
model_manager = self)
|
model_manager = self)
|
||||||
for thing in items_to_import:
|
for thing in items_to_import:
|
||||||
try:
|
try:
|
||||||
installed = installer.heuristic_install(thing)
|
installed = installer.heuristic_import(thing)
|
||||||
successfully_installed.update(installed)
|
successfully_installed.update(installed)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.warning(f'{thing} could not be imported: {str(e)}')
|
self.logger.warning(f'{thing} could not be imported: {str(e)}')
|
||||||
|
@ -2,7 +2,7 @@ import inspect
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Literal, get_origin
|
from typing import Literal, get_origin
|
||||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
|
||||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||||
from .vae import VaeModel
|
from .vae import VaeModel
|
||||||
from .lora import LoRAModel
|
from .lora import LoRAModel
|
||||||
|
@ -15,6 +15,9 @@ from contextlib import suppress
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||||
|
|
||||||
|
class ModelNotFoundException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
class BaseModelType(str, Enum):
|
class BaseModelType(str, Enum):
|
||||||
StableDiffusion1 = "sd-1"
|
StableDiffusion1 = "sd-1"
|
||||||
StableDiffusion2 = "sd-2"
|
StableDiffusion2 = "sd-2"
|
||||||
|
@ -8,6 +8,7 @@ from .base import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
classproperty,
|
classproperty,
|
||||||
|
ModelNotFoundException,
|
||||||
)
|
)
|
||||||
# TODO: naming
|
# TODO: naming
|
||||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||||
@ -37,8 +38,15 @@ class TextualInversionModel(ModelBase):
|
|||||||
if child_type is not None:
|
if child_type is not None:
|
||||||
raise Exception("There is no child models in textual inversion")
|
raise Exception("There is no child models in textual inversion")
|
||||||
|
|
||||||
|
checkpoint_path = self.model_path
|
||||||
|
if os.path.isdir(checkpoint_path):
|
||||||
|
checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin")
|
||||||
|
|
||||||
|
if not os.path.exists(checkpoint_path):
|
||||||
|
raise ModelNotFoundException()
|
||||||
|
|
||||||
model = TextualInversionModelRaw.from_checkpoint(
|
model = TextualInversionModelRaw.from_checkpoint(
|
||||||
file_path=self.model_path,
|
file_path=checkpoint_path,
|
||||||
dtype=torch_dtype,
|
dtype=torch_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
|
|||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
<script type="module" crossorigin src="./assets/index-8a3e9251.js"></script>
|
<script type="module" crossorigin src="./assets/index-c0367e37.js"></script>
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
<body dir="ltr">
|
<body dir="ltr">
|
||||||
|
17
invokeai/frontend/web/dist/locales/en.json
vendored
17
invokeai/frontend/web/dist/locales/en.json
vendored
@ -24,16 +24,13 @@
|
|||||||
},
|
},
|
||||||
"common": {
|
"common": {
|
||||||
"hotkeysLabel": "Hotkeys",
|
"hotkeysLabel": "Hotkeys",
|
||||||
"themeLabel": "Theme",
|
"darkMode": "Dark Mode",
|
||||||
|
"lightMode": "Light Mode",
|
||||||
"languagePickerLabel": "Language",
|
"languagePickerLabel": "Language",
|
||||||
"reportBugLabel": "Report Bug",
|
"reportBugLabel": "Report Bug",
|
||||||
"githubLabel": "Github",
|
"githubLabel": "Github",
|
||||||
"discordLabel": "Discord",
|
"discordLabel": "Discord",
|
||||||
"settingsLabel": "Settings",
|
"settingsLabel": "Settings",
|
||||||
"darkTheme": "Dark",
|
|
||||||
"lightTheme": "Light",
|
|
||||||
"greenTheme": "Green",
|
|
||||||
"oceanTheme": "Ocean",
|
|
||||||
"langArabic": "العربية",
|
"langArabic": "العربية",
|
||||||
"langEnglish": "English",
|
"langEnglish": "English",
|
||||||
"langDutch": "Nederlands",
|
"langDutch": "Nederlands",
|
||||||
@ -55,6 +52,7 @@
|
|||||||
"unifiedCanvas": "Unified Canvas",
|
"unifiedCanvas": "Unified Canvas",
|
||||||
"linear": "Linear",
|
"linear": "Linear",
|
||||||
"nodes": "Node Editor",
|
"nodes": "Node Editor",
|
||||||
|
"modelmanager": "Model Manager",
|
||||||
"postprocessing": "Post Processing",
|
"postprocessing": "Post Processing",
|
||||||
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
||||||
"postProcessing": "Post Processing",
|
"postProcessing": "Post Processing",
|
||||||
@ -336,6 +334,7 @@
|
|||||||
"modelManager": {
|
"modelManager": {
|
||||||
"modelManager": "Model Manager",
|
"modelManager": "Model Manager",
|
||||||
"model": "Model",
|
"model": "Model",
|
||||||
|
"vae": "VAE",
|
||||||
"allModels": "All Models",
|
"allModels": "All Models",
|
||||||
"checkpointModels": "Checkpoints",
|
"checkpointModels": "Checkpoints",
|
||||||
"diffusersModels": "Diffusers",
|
"diffusersModels": "Diffusers",
|
||||||
@ -351,6 +350,7 @@
|
|||||||
"scanForModels": "Scan For Models",
|
"scanForModels": "Scan For Models",
|
||||||
"addManually": "Add Manually",
|
"addManually": "Add Manually",
|
||||||
"manual": "Manual",
|
"manual": "Manual",
|
||||||
|
"baseModel": "Base Model",
|
||||||
"name": "Name",
|
"name": "Name",
|
||||||
"nameValidationMsg": "Enter a name for your model",
|
"nameValidationMsg": "Enter a name for your model",
|
||||||
"description": "Description",
|
"description": "Description",
|
||||||
@ -363,6 +363,7 @@
|
|||||||
"repoIDValidationMsg": "Online repository of your model",
|
"repoIDValidationMsg": "Online repository of your model",
|
||||||
"vaeLocation": "VAE Location",
|
"vaeLocation": "VAE Location",
|
||||||
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
||||||
|
"variant": "Variant",
|
||||||
"vaeRepoID": "VAE Repo ID",
|
"vaeRepoID": "VAE Repo ID",
|
||||||
"vaeRepoIDValidationMsg": "Online repository of your VAE",
|
"vaeRepoIDValidationMsg": "Online repository of your VAE",
|
||||||
"width": "Width",
|
"width": "Width",
|
||||||
@ -524,7 +525,8 @@
|
|||||||
"initialImage": "Initial Image",
|
"initialImage": "Initial Image",
|
||||||
"showOptionsPanel": "Show Options Panel",
|
"showOptionsPanel": "Show Options Panel",
|
||||||
"hidePreview": "Hide Preview",
|
"hidePreview": "Hide Preview",
|
||||||
"showPreview": "Show Preview"
|
"showPreview": "Show Preview",
|
||||||
|
"controlNetControlMode": "Control Mode"
|
||||||
},
|
},
|
||||||
"settings": {
|
"settings": {
|
||||||
"models": "Models",
|
"models": "Models",
|
||||||
@ -547,7 +549,8 @@
|
|||||||
"general": "General",
|
"general": "General",
|
||||||
"generation": "Generation",
|
"generation": "Generation",
|
||||||
"ui": "User Interface",
|
"ui": "User Interface",
|
||||||
"availableSchedulers": "Available Schedulers"
|
"favoriteSchedulers": "Favorite Schedulers",
|
||||||
|
"favoriteSchedulersPlaceholder": "No schedulers favorited"
|
||||||
},
|
},
|
||||||
"toast": {
|
"toast": {
|
||||||
"serverError": "Server Error",
|
"serverError": "Server Error",
|
||||||
|
@ -67,6 +67,7 @@
|
|||||||
"@fontsource-variable/inter": "^5.0.3",
|
"@fontsource-variable/inter": "^5.0.3",
|
||||||
"@fontsource/inter": "^5.0.3",
|
"@fontsource/inter": "^5.0.3",
|
||||||
"@mantine/core": "^6.0.14",
|
"@mantine/core": "^6.0.14",
|
||||||
|
"@mantine/form": "^6.0.15",
|
||||||
"@mantine/hooks": "^6.0.14",
|
"@mantine/hooks": "^6.0.14",
|
||||||
"@reduxjs/toolkit": "^1.9.5",
|
"@reduxjs/toolkit": "^1.9.5",
|
||||||
"@roarr/browser-log-writer": "^1.1.5",
|
"@roarr/browser-log-writer": "^1.1.5",
|
||||||
|
@ -53,6 +53,7 @@
|
|||||||
"linear": "Linear",
|
"linear": "Linear",
|
||||||
"nodes": "Node Editor",
|
"nodes": "Node Editor",
|
||||||
"batch": "Batch Manager",
|
"batch": "Batch Manager",
|
||||||
|
"modelmanager": "Model Manager",
|
||||||
"postprocessing": "Post Processing",
|
"postprocessing": "Post Processing",
|
||||||
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
||||||
"postProcessing": "Post Processing",
|
"postProcessing": "Post Processing",
|
||||||
@ -334,6 +335,7 @@
|
|||||||
"modelManager": {
|
"modelManager": {
|
||||||
"modelManager": "Model Manager",
|
"modelManager": "Model Manager",
|
||||||
"model": "Model",
|
"model": "Model",
|
||||||
|
"vae": "VAE",
|
||||||
"allModels": "All Models",
|
"allModels": "All Models",
|
||||||
"checkpointModels": "Checkpoints",
|
"checkpointModels": "Checkpoints",
|
||||||
"diffusersModels": "Diffusers",
|
"diffusersModels": "Diffusers",
|
||||||
@ -349,6 +351,7 @@
|
|||||||
"scanForModels": "Scan For Models",
|
"scanForModels": "Scan For Models",
|
||||||
"addManually": "Add Manually",
|
"addManually": "Add Manually",
|
||||||
"manual": "Manual",
|
"manual": "Manual",
|
||||||
|
"baseModel": "Base Model",
|
||||||
"name": "Name",
|
"name": "Name",
|
||||||
"nameValidationMsg": "Enter a name for your model",
|
"nameValidationMsg": "Enter a name for your model",
|
||||||
"description": "Description",
|
"description": "Description",
|
||||||
@ -361,6 +364,7 @@
|
|||||||
"repoIDValidationMsg": "Online repository of your model",
|
"repoIDValidationMsg": "Online repository of your model",
|
||||||
"vaeLocation": "VAE Location",
|
"vaeLocation": "VAE Location",
|
||||||
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
||||||
|
"variant": "Variant",
|
||||||
"vaeRepoID": "VAE Repo ID",
|
"vaeRepoID": "VAE Repo ID",
|
||||||
"vaeRepoIDValidationMsg": "Online repository of your VAE",
|
"vaeRepoIDValidationMsg": "Online repository of your VAE",
|
||||||
"width": "Width",
|
"width": "Width",
|
||||||
|
@ -4,6 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { PartialAppConfig } from 'app/types/invokeai';
|
import { PartialAppConfig } from 'app/types/invokeai';
|
||||||
import ImageUploader from 'common/components/ImageUploader';
|
import ImageUploader from 'common/components/ImageUploader';
|
||||||
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
||||||
|
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
|
||||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||||
import SiteHeader from 'features/system/components/SiteHeader';
|
import SiteHeader from 'features/system/components/SiteHeader';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
@ -15,11 +16,10 @@ import InvokeTabs from 'features/ui/components/InvokeTabs';
|
|||||||
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
||||||
import i18n from 'i18n';
|
import i18n from 'i18n';
|
||||||
import { ReactNode, memo, useEffect } from 'react';
|
import { ReactNode, memo, useEffect } from 'react';
|
||||||
|
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
|
||||||
|
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||||
import GlobalHotkeys from './GlobalHotkeys';
|
import GlobalHotkeys from './GlobalHotkeys';
|
||||||
import Toaster from './Toaster';
|
import Toaster from './Toaster';
|
||||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
|
||||||
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
|
|
||||||
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
|
|
||||||
|
|
||||||
const DEFAULT_CONFIG = {};
|
const DEFAULT_CONFIG = {};
|
||||||
|
|
||||||
|
@ -1,4 +1,8 @@
|
|||||||
import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react';
|
import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { TypesafeDraggableData } from './typesafeDnd';
|
import { TypesafeDraggableData } from './typesafeDnd';
|
||||||
|
|
||||||
@ -28,7 +32,24 @@ const STYLES: ChakraProps['sx'] = {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => {
|
||||||
|
const gallerySelectionCount = state.gallery.selection.length;
|
||||||
|
const batchSelectionCount = state.batch.selection.length;
|
||||||
|
|
||||||
|
return {
|
||||||
|
gallerySelectionCount,
|
||||||
|
batchSelectionCount,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
const DragPreview = (props: OverlayDragImageProps) => {
|
const DragPreview = (props: OverlayDragImageProps) => {
|
||||||
|
const { gallerySelectionCount, batchSelectionCount } =
|
||||||
|
useAppSelector(selector);
|
||||||
|
|
||||||
if (!props.dragData) {
|
if (!props.dragData) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -57,7 +78,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (props.dragData.payloadType === 'IMAGE_NAMES') {
|
if (props.dragData.payloadType === 'BATCH_SELECTION') {
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
@ -70,7 +91,26 @@ const DragPreview = (props: OverlayDragImageProps) => {
|
|||||||
...STYLES,
|
...STYLES,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Heading>{props.dragData.payload.imageNames.length}</Heading>
|
<Heading>{batchSelectionCount}</Heading>
|
||||||
|
<Heading size="sm">Images</Heading>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (props.dragData.payloadType === 'GALLERY_SELECTION') {
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
cursor: 'none',
|
||||||
|
userSelect: 'none',
|
||||||
|
position: 'relative',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
flexDir: 'column',
|
||||||
|
...STYLES,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Heading>{gallerySelectionCount}</Heading>
|
||||||
<Heading size="sm">Images</Heading>
|
<Heading size="sm">Images</Heading>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
|
@ -77,14 +77,18 @@ export type ImageDraggableData = BaseDragData & {
|
|||||||
payload: { imageDTO: ImageDTO };
|
payload: { imageDTO: ImageDTO };
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ImageNamesDraggableData = BaseDragData & {
|
export type GallerySelectionDraggableData = BaseDragData & {
|
||||||
payloadType: 'IMAGE_NAMES';
|
payloadType: 'GALLERY_SELECTION';
|
||||||
payload: { imageNames: string[] };
|
};
|
||||||
|
|
||||||
|
export type BatchSelectionDraggableData = BaseDragData & {
|
||||||
|
payloadType: 'BATCH_SELECTION';
|
||||||
};
|
};
|
||||||
|
|
||||||
export type TypesafeDraggableData =
|
export type TypesafeDraggableData =
|
||||||
| ImageDraggableData
|
| ImageDraggableData
|
||||||
| ImageNamesDraggableData;
|
| GallerySelectionDraggableData
|
||||||
|
| BatchSelectionDraggableData;
|
||||||
|
|
||||||
interface UseDroppableTypesafeArguments
|
interface UseDroppableTypesafeArguments
|
||||||
extends Omit<UseDroppableArguments, 'data'> {
|
extends Omit<UseDroppableArguments, 'data'> {
|
||||||
@ -155,11 +159,13 @@ export const isValidDrop = (
|
|||||||
case 'SET_NODES_IMAGE':
|
case 'SET_NODES_IMAGE':
|
||||||
return payloadType === 'IMAGE_DTO';
|
return payloadType === 'IMAGE_DTO';
|
||||||
case 'SET_MULTI_NODES_IMAGE':
|
case 'SET_MULTI_NODES_IMAGE':
|
||||||
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
|
return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION';
|
||||||
case 'ADD_TO_BATCH':
|
case 'ADD_TO_BATCH':
|
||||||
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
|
return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION';
|
||||||
case 'MOVE_BOARD':
|
case 'MOVE_BOARD':
|
||||||
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
|
return (
|
||||||
|
payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION' || 'BATCH_SELECTION'
|
||||||
|
);
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -20,10 +20,8 @@ const serializationDenylist: {
|
|||||||
nodes: nodesPersistDenylist,
|
nodes: nodesPersistDenylist,
|
||||||
postprocessing: postprocessingPersistDenylist,
|
postprocessing: postprocessingPersistDenylist,
|
||||||
system: systemPersistDenylist,
|
system: systemPersistDenylist,
|
||||||
// config: configPersistDenyList,
|
|
||||||
ui: uiPersistDenylist,
|
ui: uiPersistDenylist,
|
||||||
controlNet: controlNetDenylist,
|
controlNet: controlNetDenylist,
|
||||||
// hotkeys: hotkeysPersistDenylist,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export const serialize: SerializeFunction = (data, key) => {
|
export const serialize: SerializeFunction = (data, key) => {
|
||||||
|
@ -1,21 +1,21 @@
|
|||||||
import { startAppListening } from '..';
|
|
||||||
import { imageDeleted } from 'services/api/thunks/image';
|
|
||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { clamp } from 'lodash-es';
|
|
||||||
import {
|
|
||||||
imageSelected,
|
|
||||||
imageRemoved,
|
|
||||||
selectImagesIds,
|
|
||||||
} from 'features/gallery/store/gallerySlice';
|
|
||||||
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
import {
|
||||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
imageRemoved,
|
||||||
import { api } from 'services/api';
|
imageSelected,
|
||||||
|
selectFilteredImages,
|
||||||
|
} from 'features/gallery/store/gallerySlice';
|
||||||
import {
|
import {
|
||||||
imageDeletionConfirmed,
|
imageDeletionConfirmed,
|
||||||
isModalOpenChanged,
|
isModalOpenChanged,
|
||||||
} from 'features/imageDeletion/store/imageDeletionSlice';
|
} from 'features/imageDeletion/store/imageDeletionSlice';
|
||||||
|
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||||
|
import { clamp } from 'lodash-es';
|
||||||
|
import { api } from 'services/api';
|
||||||
|
import { imageDeleted } from 'services/api/thunks/image';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'image' });
|
const moduleLog = log.child({ namespace: 'image' });
|
||||||
|
|
||||||
@ -37,7 +37,9 @@ export const addRequestedImageDeletionListener = () => {
|
|||||||
state.gallery.selection[state.gallery.selection.length - 1];
|
state.gallery.selection[state.gallery.selection.length - 1];
|
||||||
|
|
||||||
if (lastSelectedImage === image_name) {
|
if (lastSelectedImage === image_name) {
|
||||||
const ids = selectImagesIds(state);
|
const filteredImages = selectFilteredImages(state);
|
||||||
|
|
||||||
|
const ids = filteredImages.map((i) => i.image_name);
|
||||||
|
|
||||||
const deletedImageIndex = ids.findIndex(
|
const deletedImageIndex = ids.findIndex(
|
||||||
(result) => result.toString() === image_name
|
(result) => result.toString() === image_name
|
||||||
|
@ -1,24 +1,23 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
import { startAppListening } from '../';
|
|
||||||
import { log } from 'app/logging/useLogger';
|
|
||||||
import {
|
import {
|
||||||
TypesafeDraggableData,
|
TypesafeDraggableData,
|
||||||
TypesafeDroppableData,
|
TypesafeDroppableData,
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
} from 'app/components/ImageDnd/typesafeDnd';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
|
||||||
import {
|
import {
|
||||||
imageAddedToBatch,
|
imageAddedToBatch,
|
||||||
imagesAddedToBatch,
|
imagesAddedToBatch,
|
||||||
} from 'features/batch/store/batchSlice';
|
} from 'features/batch/store/batchSlice';
|
||||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||||
|
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import {
|
import {
|
||||||
fieldValueChanged,
|
fieldValueChanged,
|
||||||
imageCollectionFieldValueChanged,
|
imageCollectionFieldValueChanged,
|
||||||
} from 'features/nodes/store/nodesSlice';
|
} from 'features/nodes/store/nodesSlice';
|
||||||
import { boardsApi } from 'services/api/endpoints/boards';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { boardImagesApi } from 'services/api/endpoints/boardImages';
|
import { boardImagesApi } from 'services/api/endpoints/boardImages';
|
||||||
|
import { startAppListening } from '../';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'dnd' });
|
const moduleLog = log.child({ namespace: 'dnd' });
|
||||||
|
|
||||||
@ -33,6 +32,7 @@ export const addImageDroppedListener = () => {
|
|||||||
effect: (action, { dispatch, getState }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const { activeData, overData } = action.payload;
|
const { activeData, overData } = action.payload;
|
||||||
const { actionType } = overData;
|
const { actionType } = overData;
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
// set current image
|
// set current image
|
||||||
if (
|
if (
|
||||||
@ -64,9 +64,9 @@ export const addImageDroppedListener = () => {
|
|||||||
// add multiple images to batch
|
// add multiple images to batch
|
||||||
if (
|
if (
|
||||||
actionType === 'ADD_TO_BATCH' &&
|
actionType === 'ADD_TO_BATCH' &&
|
||||||
activeData.payloadType === 'IMAGE_NAMES'
|
activeData.payloadType === 'GALLERY_SELECTION'
|
||||||
) {
|
) {
|
||||||
dispatch(imagesAddedToBatch(activeData.payload.imageNames));
|
dispatch(imagesAddedToBatch(state.gallery.selection));
|
||||||
}
|
}
|
||||||
|
|
||||||
// set control image
|
// set control image
|
||||||
@ -128,14 +128,14 @@ export const addImageDroppedListener = () => {
|
|||||||
// set multiple nodes images (multiple images handler)
|
// set multiple nodes images (multiple images handler)
|
||||||
if (
|
if (
|
||||||
actionType === 'SET_MULTI_NODES_IMAGE' &&
|
actionType === 'SET_MULTI_NODES_IMAGE' &&
|
||||||
activeData.payloadType === 'IMAGE_NAMES'
|
activeData.payloadType === 'GALLERY_SELECTION'
|
||||||
) {
|
) {
|
||||||
const { fieldName, nodeId } = overData.context;
|
const { fieldName, nodeId } = overData.context;
|
||||||
dispatch(
|
dispatch(
|
||||||
imageCollectionFieldValueChanged({
|
imageCollectionFieldValueChanged({
|
||||||
nodeId,
|
nodeId,
|
||||||
fieldName,
|
fieldName,
|
||||||
value: activeData.payload.imageNames.map((image_name) => ({
|
value: state.gallery.selection.map((image_name) => ({
|
||||||
image_name,
|
image_name,
|
||||||
})),
|
})),
|
||||||
})
|
})
|
||||||
|
@ -8,31 +8,32 @@ import {
|
|||||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||||
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
||||||
|
|
||||||
|
import batchReducer from 'features/batch/store/batchSlice';
|
||||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||||
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
|
||||||
|
import boardsReducer from 'features/gallery/store/boardSlice';
|
||||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||||
|
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
|
||||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||||
|
import loraReducer from 'features/lora/store/loraSlice';
|
||||||
|
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||||
import generationReducer from 'features/parameters/store/generationSlice';
|
import generationReducer from 'features/parameters/store/generationSlice';
|
||||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||||
import systemReducer from 'features/system/store/systemSlice';
|
|
||||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
|
||||||
import boardsReducer from 'features/gallery/store/boardSlice';
|
|
||||||
import configReducer from 'features/system/store/configSlice';
|
import configReducer from 'features/system/store/configSlice';
|
||||||
|
import systemReducer from 'features/system/store/systemSlice';
|
||||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||||
import uiReducer from 'features/ui/store/uiSlice';
|
import uiReducer from 'features/ui/store/uiSlice';
|
||||||
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
|
|
||||||
import batchReducer from 'features/batch/store/batchSlice';
|
|
||||||
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
|
|
||||||
|
|
||||||
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||||
|
|
||||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
import { api } from 'services/api';
|
||||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
|
||||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
|
||||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||||
import { serialize } from './enhancers/reduxRemember/serialize';
|
import { serialize } from './enhancers/reduxRemember/serialize';
|
||||||
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
||||||
import { api } from 'services/api';
|
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||||
|
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||||
|
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||||
|
|
||||||
const allReducers = {
|
const allReducers = {
|
||||||
canvas: canvasReducer,
|
canvas: canvasReducer,
|
||||||
@ -50,6 +51,7 @@ const allReducers = {
|
|||||||
dynamicPrompts: dynamicPromptsReducer,
|
dynamicPrompts: dynamicPromptsReducer,
|
||||||
batch: batchReducer,
|
batch: batchReducer,
|
||||||
imageDeletion: imageDeletionReducer,
|
imageDeletion: imageDeletionReducer,
|
||||||
|
lora: loraReducer,
|
||||||
[api.reducerPath]: api.reducer,
|
[api.reducerPath]: api.reducer,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -69,6 +71,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
|||||||
'controlNet',
|
'controlNet',
|
||||||
'dynamicPrompts',
|
'dynamicPrompts',
|
||||||
'batch',
|
'batch',
|
||||||
|
'lora',
|
||||||
// 'boards',
|
// 'boards',
|
||||||
// 'hotkeys',
|
// 'hotkeys',
|
||||||
// 'config',
|
// 'config',
|
||||||
|
@ -4,22 +4,25 @@ import {
|
|||||||
Collapse,
|
Collapse,
|
||||||
Flex,
|
Flex,
|
||||||
Spacer,
|
Spacer,
|
||||||
Switch,
|
Text,
|
||||||
useColorMode,
|
useColorMode,
|
||||||
|
useDisclosure,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
|
import { AnimatePresence, motion } from 'framer-motion';
|
||||||
import { PropsWithChildren, memo } from 'react';
|
import { PropsWithChildren, memo } from 'react';
|
||||||
import { mode } from 'theme/util/mode';
|
import { mode } from 'theme/util/mode';
|
||||||
|
|
||||||
export type IAIToggleCollapseProps = PropsWithChildren & {
|
export type IAIToggleCollapseProps = PropsWithChildren & {
|
||||||
label: string;
|
label: string;
|
||||||
isOpen: boolean;
|
activeLabel?: string;
|
||||||
onToggle: () => void;
|
defaultIsOpen?: boolean;
|
||||||
withSwitch?: boolean;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const IAICollapse = (props: IAIToggleCollapseProps) => {
|
const IAICollapse = (props: IAIToggleCollapseProps) => {
|
||||||
const { label, isOpen, onToggle, children, withSwitch = false } = props;
|
const { label, activeLabel, children, defaultIsOpen = false } = props;
|
||||||
|
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen });
|
||||||
const { colorMode } = useColorMode();
|
const { colorMode } = useColorMode();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box>
|
<Box>
|
||||||
<Flex
|
<Flex
|
||||||
@ -28,6 +31,7 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
|
|||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
p: 2,
|
p: 2,
|
||||||
px: 4,
|
px: 4,
|
||||||
|
gap: 2,
|
||||||
borderTopRadius: 'base',
|
borderTopRadius: 'base',
|
||||||
borderBottomRadius: isOpen ? 0 : 'base',
|
borderBottomRadius: isOpen ? 0 : 'base',
|
||||||
bg: isOpen
|
bg: isOpen
|
||||||
@ -48,19 +52,40 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{label}
|
{label}
|
||||||
|
<AnimatePresence>
|
||||||
|
{activeLabel && (
|
||||||
|
<motion.div
|
||||||
|
key="statusText"
|
||||||
|
initial={{
|
||||||
|
opacity: 0,
|
||||||
|
}}
|
||||||
|
animate={{
|
||||||
|
opacity: 1,
|
||||||
|
transition: { duration: 0.1 },
|
||||||
|
}}
|
||||||
|
exit={{
|
||||||
|
opacity: 0,
|
||||||
|
transition: { duration: 0.1 },
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Text
|
||||||
|
sx={{ color: 'accent.500', _dark: { color: 'accent.300' } }}
|
||||||
|
>
|
||||||
|
{activeLabel}
|
||||||
|
</Text>
|
||||||
|
</motion.div>
|
||||||
|
)}
|
||||||
|
</AnimatePresence>
|
||||||
<Spacer />
|
<Spacer />
|
||||||
{withSwitch && <Switch isChecked={isOpen} pointerEvents="none" />}
|
<ChevronUpIcon
|
||||||
{!withSwitch && (
|
sx={{
|
||||||
<ChevronUpIcon
|
w: '1rem',
|
||||||
sx={{
|
h: '1rem',
|
||||||
w: '1rem',
|
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
|
||||||
h: '1rem',
|
transitionProperty: 'common',
|
||||||
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
|
transitionDuration: 'normal',
|
||||||
transitionProperty: 'common',
|
}}
|
||||||
transitionDuration: 'normal',
|
/>
|
||||||
}}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Flex>
|
</Flex>
|
||||||
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
|
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
|
||||||
<Box
|
<Box
|
||||||
|
@ -61,7 +61,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
|||||||
'&:focus-within': {
|
'&:focus-within': {
|
||||||
borderColor: mode(accent200, accent600)(colorMode),
|
borderColor: mode(accent200, accent600)(colorMode),
|
||||||
},
|
},
|
||||||
'&:disabled': {
|
'&[data-disabled]': {
|
||||||
backgroundColor: mode(base300, base700)(colorMode),
|
backgroundColor: mode(base300, base700)(colorMode),
|
||||||
color: mode(base600, base400)(colorMode),
|
color: mode(base600, base400)(colorMode),
|
||||||
},
|
},
|
||||||
|
@ -64,7 +64,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
|
|||||||
'&:focus-within': {
|
'&:focus-within': {
|
||||||
borderColor: mode(accent200, accent600)(colorMode),
|
borderColor: mode(accent200, accent600)(colorMode),
|
||||||
},
|
},
|
||||||
'&:disabled': {
|
'&[data-disabled]': {
|
||||||
backgroundColor: mode(base300, base700)(colorMode),
|
backgroundColor: mode(base300, base700)(colorMode),
|
||||||
color: mode(base600, base400)(colorMode),
|
color: mode(base600, base400)(colorMode),
|
||||||
},
|
},
|
||||||
|
@ -36,7 +36,6 @@ const IAISwitch = (props: Props) => {
|
|||||||
isDisabled={isDisabled}
|
isDisabled={isDisabled}
|
||||||
width={width}
|
width={width}
|
||||||
display="flex"
|
display="flex"
|
||||||
gap={4}
|
|
||||||
alignItems="center"
|
alignItems="center"
|
||||||
{...formControlProps}
|
{...formControlProps}
|
||||||
>
|
>
|
||||||
@ -47,6 +46,7 @@ const IAISwitch = (props: Props) => {
|
|||||||
sx={{
|
sx={{
|
||||||
cursor: isDisabled ? 'not-allowed' : 'pointer',
|
cursor: isDisabled ? 'not-allowed' : 'pointer',
|
||||||
...formLabelProps?.sx,
|
...formLabelProps?.sx,
|
||||||
|
pe: 4,
|
||||||
}}
|
}}
|
||||||
{...formLabelProps}
|
{...formLabelProps}
|
||||||
>
|
>
|
||||||
|
@ -1,28 +1,29 @@
|
|||||||
import { Box, Icon, Skeleton } from '@chakra-ui/react';
|
import { Box, Icon, Skeleton } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { FaExclamationCircle } from 'react-icons/fa';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
import { MouseEvent, memo, useCallback, useMemo } from 'react';
|
|
||||||
import {
|
import {
|
||||||
batchImageRangeEndSelected,
|
batchImageRangeEndSelected,
|
||||||
batchImageSelected,
|
batchImageSelected,
|
||||||
batchImageSelectionToggled,
|
batchImageSelectionToggled,
|
||||||
imageRemovedFromBatch,
|
imageRemovedFromBatch,
|
||||||
} from 'features/batch/store/batchSlice';
|
} from 'features/batch/store/batchSlice';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import { MouseEvent, memo, useCallback, useMemo } from 'react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { FaExclamationCircle } from 'react-icons/fa';
|
||||||
import { RootState, stateSelector } from 'app/store/store';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
|
|
||||||
const isSelectedSelector = createSelector(
|
const makeSelector = (image_name: string) =>
|
||||||
[stateSelector, (state: RootState, imageName: string) => imageName],
|
createSelector(
|
||||||
(state, imageName) => ({
|
[stateSelector],
|
||||||
selection: state.batch.selection,
|
(state) => ({
|
||||||
isSelected: state.batch.selection.includes(imageName),
|
selectionCount: state.batch.selection.length,
|
||||||
}),
|
isSelected: state.batch.selection.includes(image_name),
|
||||||
defaultSelectorOptions
|
}),
|
||||||
);
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
type BatchImageProps = {
|
type BatchImageProps = {
|
||||||
imageName: string;
|
imageName: string;
|
||||||
@ -37,10 +38,13 @@ const BatchImage = (props: BatchImageProps) => {
|
|||||||
} = useGetImageDTOQuery(props.imageName);
|
} = useGetImageDTOQuery(props.imageName);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { isSelected, selection } = useAppSelector((state) =>
|
const selector = useMemo(
|
||||||
isSelectedSelector(state, props.imageName)
|
() => makeSelector(props.imageName),
|
||||||
|
[props.imageName]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const { isSelected, selectionCount } = useAppSelector(selector);
|
||||||
|
|
||||||
const handleClickRemove = useCallback(() => {
|
const handleClickRemove = useCallback(() => {
|
||||||
dispatch(imageRemovedFromBatch(props.imageName));
|
dispatch(imageRemovedFromBatch(props.imageName));
|
||||||
}, [dispatch, props.imageName]);
|
}, [dispatch, props.imageName]);
|
||||||
@ -59,13 +63,10 @@ const BatchImage = (props: BatchImageProps) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
|
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
|
||||||
if (selection.length > 1) {
|
if (selectionCount > 1) {
|
||||||
return {
|
return {
|
||||||
id: 'batch',
|
id: 'batch',
|
||||||
payloadType: 'IMAGE_NAMES',
|
payloadType: 'BATCH_SELECTION',
|
||||||
payload: {
|
|
||||||
imageNames: selection,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,7 +77,7 @@ const BatchImage = (props: BatchImageProps) => {
|
|||||||
payload: { imageDTO },
|
payload: { imageDTO },
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}, [imageDTO, selection]);
|
}, [imageDTO, selectionCount]);
|
||||||
|
|
||||||
if (isError) {
|
if (isError) {
|
||||||
return <Icon as={FaExclamationCircle} />;
|
return <Icon as={FaExclamationCircle} />;
|
||||||
|
@ -1,25 +1,22 @@
|
|||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
|
||||||
import { ImageDTO } from 'services/api/types';
|
|
||||||
import {
|
|
||||||
ControlNetConfig,
|
|
||||||
controlNetImageChanged,
|
|
||||||
controlNetSelector,
|
|
||||||
} from '../store/controlNetSlice';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { Box, Flex, SystemStyleObject } from '@chakra-ui/react';
|
import { Box, Flex, SystemStyleObject } from '@chakra-ui/react';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import { IAILoadingImageFallback } from 'common/components/IAIImageFallback';
|
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
|
||||||
import { FaUndo } from 'react-icons/fa';
|
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import {
|
import {
|
||||||
TypesafeDraggableData,
|
TypesafeDraggableData,
|
||||||
TypesafeDroppableData,
|
TypesafeDroppableData,
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
} from 'app/components/ImageDnd/typesafeDnd';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
|
import { IAILoadingImageFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
import { PostUploadAction } from 'services/api/thunks/image';
|
import { PostUploadAction } from 'services/api/thunks/image';
|
||||||
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
|
controlNetImageChanged,
|
||||||
|
controlNetSelector,
|
||||||
|
} from '../store/controlNetSlice';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
controlNetSelector,
|
controlNetSelector,
|
||||||
@ -83,15 +80,14 @@ const ControlNetImagePreview = (props: Props) => {
|
|||||||
}
|
}
|
||||||
}, [controlImage, controlNetId]);
|
}, [controlImage, controlNetId]);
|
||||||
|
|
||||||
const droppableData = useMemo<TypesafeDroppableData | undefined>(() => {
|
const droppableData = useMemo<TypesafeDroppableData | undefined>(
|
||||||
if (controlNetId) {
|
() => ({
|
||||||
return {
|
id: controlNetId,
|
||||||
id: controlNetId,
|
actionType: 'SET_CONTROLNET_IMAGE',
|
||||||
actionType: 'SET_CONTROLNET_IMAGE',
|
context: { controlNetId },
|
||||||
context: { controlNetId },
|
}),
|
||||||
};
|
[controlNetId]
|
||||||
}
|
);
|
||||||
}, [controlNetId]);
|
|
||||||
|
|
||||||
const postUploadAction = useMemo<PostUploadAction>(
|
const postUploadAction = useMemo<PostUploadAction>(
|
||||||
() => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }),
|
() => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }),
|
||||||
|
@ -0,0 +1,36 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { isControlNetEnabledToggled } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => {
|
||||||
|
const { isEnabled } = state.controlNet;
|
||||||
|
|
||||||
|
return { isEnabled };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamControlNetFeatureToggle = () => {
|
||||||
|
const { isEnabled } = useAppSelector(selector);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleChange = useCallback(() => {
|
||||||
|
dispatch(isControlNetEnabledToggled());
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAISwitch
|
||||||
|
label="Enable ControlNet"
|
||||||
|
isChecked={isEnabled}
|
||||||
|
onChange={handleChange}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ParamControlNetFeatureToggle;
|
@ -0,0 +1,15 @@
|
|||||||
|
import { filter } from 'lodash-es';
|
||||||
|
import { ControlNetConfig } from '../store/controlNetSlice';
|
||||||
|
|
||||||
|
export const getValidControlNets = (
|
||||||
|
controlNets: Record<string, ControlNetConfig>
|
||||||
|
) => {
|
||||||
|
const validControlNets = filter(
|
||||||
|
controlNets,
|
||||||
|
(c) =>
|
||||||
|
c.isEnabled &&
|
||||||
|
(Boolean(c.processedControlImage) ||
|
||||||
|
(c.processorType === 'none' && Boolean(c.controlImage)))
|
||||||
|
);
|
||||||
|
return validControlNets;
|
||||||
|
};
|
@ -1,40 +1,30 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAICollapse from 'common/components/IAICollapse';
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
import { useCallback } from 'react';
|
|
||||||
import { isEnabledToggled } from '../store/slice';
|
|
||||||
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
|
|
||||||
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
|
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
|
||||||
import { Flex } from '@chakra-ui/react';
|
import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled';
|
||||||
|
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
(state) => {
|
(state) => {
|
||||||
const { isEnabled } = state.dynamicPrompts;
|
const { isEnabled } = state.dynamicPrompts;
|
||||||
|
|
||||||
return { isEnabled };
|
return { activeLabel: isEnabled ? 'Enabled' : undefined };
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const ParamDynamicPromptsCollapse = () => {
|
const ParamDynamicPromptsCollapse = () => {
|
||||||
const dispatch = useAppDispatch();
|
const { activeLabel } = useAppSelector(selector);
|
||||||
const { isEnabled } = useAppSelector(selector);
|
|
||||||
|
|
||||||
const handleToggleIsEnabled = useCallback(() => {
|
|
||||||
dispatch(isEnabledToggled());
|
|
||||||
}, [dispatch]);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICollapse
|
<IAICollapse label="Dynamic Prompts" activeLabel={activeLabel}>
|
||||||
isOpen={isEnabled}
|
|
||||||
onToggle={handleToggleIsEnabled}
|
|
||||||
label="Dynamic Prompts"
|
|
||||||
withSwitch
|
|
||||||
>
|
|
||||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||||
|
<ParamDynamicPromptsToggle />
|
||||||
<ParamDynamicPromptsCombinatorial />
|
<ParamDynamicPromptsCombinatorial />
|
||||||
<ParamDynamicPromptsMaxPrompts />
|
<ParamDynamicPromptsMaxPrompts />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -1,23 +1,23 @@
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { combinatorialToggled } from '../store/slice';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import { useCallback } from 'react';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import { combinatorialToggled } from '../store/slice';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
(state) => {
|
(state) => {
|
||||||
const { combinatorial } = state.dynamicPrompts;
|
const { combinatorial, isEnabled } = state.dynamicPrompts;
|
||||||
|
|
||||||
return { combinatorial };
|
return { combinatorial, isDisabled: !isEnabled };
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const ParamDynamicPromptsCombinatorial = () => {
|
const ParamDynamicPromptsCombinatorial = () => {
|
||||||
const { combinatorial } = useAppSelector(selector);
|
const { combinatorial, isDisabled } = useAppSelector(selector);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const handleChange = useCallback(() => {
|
const handleChange = useCallback(() => {
|
||||||
@ -26,6 +26,7 @@ const ParamDynamicPromptsCombinatorial = () => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<IAISwitch
|
<IAISwitch
|
||||||
|
isDisabled={isDisabled}
|
||||||
label="Combinatorial Generation"
|
label="Combinatorial Generation"
|
||||||
isChecked={combinatorial}
|
isChecked={combinatorial}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
|
@ -0,0 +1,36 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import { isEnabledToggled } from '../store/slice';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => {
|
||||||
|
const { isEnabled } = state.dynamicPrompts;
|
||||||
|
|
||||||
|
return { isEnabled };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamDynamicPromptsToggle = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { isEnabled } = useAppSelector(selector);
|
||||||
|
|
||||||
|
const handleToggleIsEnabled = useCallback(() => {
|
||||||
|
dispatch(isEnabledToggled());
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAISwitch
|
||||||
|
label="Enable Dynamic Prompts"
|
||||||
|
isChecked={isEnabled}
|
||||||
|
onChange={handleToggleIsEnabled}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ParamDynamicPromptsToggle;
|
@ -1,25 +1,31 @@
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import IAISlider from 'common/components/IAISlider';
|
|
||||||
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import { useCallback } from 'react';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
(state) => {
|
(state) => {
|
||||||
const { maxPrompts, combinatorial } = state.dynamicPrompts;
|
const { maxPrompts, combinatorial, isEnabled } = state.dynamicPrompts;
|
||||||
const { min, sliderMax, inputMax } =
|
const { min, sliderMax, inputMax } =
|
||||||
state.config.sd.dynamicPrompts.maxPrompts;
|
state.config.sd.dynamicPrompts.maxPrompts;
|
||||||
|
|
||||||
return { maxPrompts, min, sliderMax, inputMax, combinatorial };
|
return {
|
||||||
|
maxPrompts,
|
||||||
|
min,
|
||||||
|
sliderMax,
|
||||||
|
inputMax,
|
||||||
|
isDisabled: !isEnabled || !combinatorial,
|
||||||
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const ParamDynamicPromptsMaxPrompts = () => {
|
const ParamDynamicPromptsMaxPrompts = () => {
|
||||||
const { maxPrompts, min, sliderMax, inputMax, combinatorial } =
|
const { maxPrompts, min, sliderMax, inputMax, isDisabled } =
|
||||||
useAppSelector(selector);
|
useAppSelector(selector);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
@ -37,7 +43,7 @@ const ParamDynamicPromptsMaxPrompts = () => {
|
|||||||
return (
|
return (
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Max Prompts"
|
label="Max Prompts"
|
||||||
isDisabled={!combinatorial}
|
isDisabled={isDisabled}
|
||||||
min={min}
|
min={min}
|
||||||
max={sliderMax}
|
max={sliderMax}
|
||||||
value={maxPrompts}
|
value={maxPrompts}
|
||||||
|
@ -1,19 +1,19 @@
|
|||||||
import { Box, Flex, Image } from '@chakra-ui/react';
|
import { Box, Flex, Image } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { isEqual } from 'lodash-es';
|
|
||||||
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
|
||||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
|
||||||
import { memo, useMemo } from 'react';
|
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice';
|
|
||||||
import {
|
import {
|
||||||
TypesafeDraggableData,
|
TypesafeDraggableData,
|
||||||
TypesafeDroppableData,
|
TypesafeDroppableData,
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
} from 'app/components/ImageDnd/typesafeDnd';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
|
import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice';
|
||||||
|
import { isEqual } from 'lodash-es';
|
||||||
|
import { memo, useMemo } from 'react';
|
||||||
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
|
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||||
|
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||||
|
|
||||||
export const imagesSelector = createSelector(
|
export const imagesSelector = createSelector(
|
||||||
[stateSelector, selectLastSelectedImage],
|
[stateSelector, selectLastSelectedImage],
|
||||||
|
@ -1,34 +1,35 @@
|
|||||||
import { Box } from '@chakra-ui/react';
|
import { Box } from '@chakra-ui/react';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { MouseEvent, memo, useCallback, useMemo } from 'react';
|
|
||||||
import { FaTrash } from 'react-icons/fa';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { ImageDTO } from 'services/api/types';
|
|
||||||
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
|
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import ImageContextMenu from './ImageContextMenu';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
|
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
|
||||||
|
import { MouseEvent, memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { FaTrash } from 'react-icons/fa';
|
||||||
|
import { ImageDTO } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
imageRangeEndSelected,
|
imageRangeEndSelected,
|
||||||
imageSelected,
|
imageSelected,
|
||||||
imageSelectionToggled,
|
imageSelectionToggled,
|
||||||
} from '../store/gallerySlice';
|
} from '../store/gallerySlice';
|
||||||
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
|
import ImageContextMenu from './ImageContextMenu';
|
||||||
|
|
||||||
export const selector = createSelector(
|
export const makeSelector = (image_name: string) =>
|
||||||
[stateSelector, (state, { image_name }: ImageDTO) => image_name],
|
createSelector(
|
||||||
({ gallery }, image_name) => {
|
[stateSelector],
|
||||||
const isSelected = gallery.selection.includes(image_name);
|
({ gallery }) => {
|
||||||
const selection = gallery.selection;
|
const isSelected = gallery.selection.includes(image_name);
|
||||||
return {
|
const selectionCount = gallery.selection.length;
|
||||||
isSelected,
|
return {
|
||||||
selection,
|
isSelected,
|
||||||
};
|
selectionCount,
|
||||||
},
|
};
|
||||||
defaultSelectorOptions
|
},
|
||||||
);
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
interface HoverableImageProps {
|
interface HoverableImageProps {
|
||||||
imageDTO: ImageDTO;
|
imageDTO: ImageDTO;
|
||||||
@ -38,13 +39,13 @@ interface HoverableImageProps {
|
|||||||
* Gallery image component with delete/use all/use seed buttons on hover.
|
* Gallery image component with delete/use all/use seed buttons on hover.
|
||||||
*/
|
*/
|
||||||
const GalleryImage = (props: HoverableImageProps) => {
|
const GalleryImage = (props: HoverableImageProps) => {
|
||||||
const { isSelected, selection } = useAppSelector((state) =>
|
|
||||||
selector(state, props.imageDTO)
|
|
||||||
);
|
|
||||||
|
|
||||||
const { imageDTO } = props;
|
const { imageDTO } = props;
|
||||||
const { image_url, thumbnail_url, image_name } = imageDTO;
|
const { image_url, thumbnail_url, image_name } = imageDTO;
|
||||||
|
|
||||||
|
const localSelector = useMemo(() => makeSelector(image_name), [image_name]);
|
||||||
|
|
||||||
|
const { isSelected, selectionCount } = useAppSelector(localSelector);
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
@ -74,11 +75,10 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
|
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
|
||||||
if (selection.length > 1) {
|
if (selectionCount > 1) {
|
||||||
return {
|
return {
|
||||||
id: 'gallery-image',
|
id: 'gallery-image',
|
||||||
payloadType: 'IMAGE_NAMES',
|
payloadType: 'GALLERY_SELECTION',
|
||||||
payload: { imageNames: selection },
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,7 +89,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
payload: { imageDTO },
|
payload: { imageDTO },
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}, [imageDTO, selection]);
|
}, [imageDTO, selectionCount]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>
|
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>
|
||||||
|
@ -7,7 +7,6 @@ import {
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { dateComparator } from 'common/util/dateComparator';
|
import { dateComparator } from 'common/util/dateComparator';
|
||||||
import { imageDeletionConfirmed } from 'features/imageDeletion/store/imageDeletionSlice';
|
|
||||||
import { keyBy, uniq } from 'lodash-es';
|
import { keyBy, uniq } from 'lodash-es';
|
||||||
import { boardsApi } from 'services/api/endpoints/boards';
|
import { boardsApi } from 'services/api/endpoints/boards';
|
||||||
import {
|
import {
|
||||||
@ -174,11 +173,6 @@ export const gallerySlice = createSlice({
|
|||||||
state.limit = limit;
|
state.limit = limit;
|
||||||
state.total = total;
|
state.total = total;
|
||||||
});
|
});
|
||||||
builder.addCase(imageDeletionConfirmed, (state, action) => {
|
|
||||||
// Image deleted
|
|
||||||
const { image_name } = action.payload.imageDTO;
|
|
||||||
imagesAdapter.removeOne(state, image_name);
|
|
||||||
});
|
|
||||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
const { image_name, image_url, thumbnail_url } = action.payload;
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import {
|
import {
|
||||||
imageDeletionConfirmed,
|
imageDeletionConfirmed,
|
||||||
imageToDeleteCleared,
|
imageToDeleteCleared,
|
||||||
|
isModalOpenChanged,
|
||||||
selectImageUsage,
|
selectImageUsage,
|
||||||
} from '../store/imageDeletionSlice';
|
} from '../store/imageDeletionSlice';
|
||||||
|
|
||||||
@ -63,6 +64,7 @@ const DeleteImageModal = () => {
|
|||||||
|
|
||||||
const handleClose = useCallback(() => {
|
const handleClose = useCallback(() => {
|
||||||
dispatch(imageToDeleteCleared());
|
dispatch(imageToDeleteCleared());
|
||||||
|
dispatch(isModalOpenChanged(false));
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
const handleDelete = useCallback(() => {
|
const handleDelete = useCallback(() => {
|
||||||
|
@ -31,6 +31,7 @@ const imageDeletion = createSlice({
|
|||||||
},
|
},
|
||||||
imageToDeleteCleared: (state) => {
|
imageToDeleteCleared: (state) => {
|
||||||
state.imageToDelete = null;
|
state.imageToDelete = null;
|
||||||
|
state.isModalOpen = false;
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -0,0 +1,59 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { FaTrash } from 'react-icons/fa';
|
||||||
|
import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
lora: Lora;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ParamLora = (props: Props) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { lora } = props;
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
dispatch(loraWeightChanged({ id: lora.id, weight: v }));
|
||||||
|
},
|
||||||
|
[dispatch, lora.id]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleReset = useCallback(() => {
|
||||||
|
dispatch(loraWeightChanged({ id: lora.id, weight: 1 }));
|
||||||
|
}, [dispatch, lora.id]);
|
||||||
|
|
||||||
|
const handleRemoveLora = useCallback(() => {
|
||||||
|
dispatch(loraRemoved(lora.id));
|
||||||
|
}, [dispatch, lora.id]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex sx={{ gap: 2.5, alignItems: 'flex-end' }}>
|
||||||
|
<IAISlider
|
||||||
|
label={lora.name}
|
||||||
|
value={lora.weight}
|
||||||
|
onChange={handleChange}
|
||||||
|
min={-1}
|
||||||
|
max={2}
|
||||||
|
step={0.01}
|
||||||
|
withInput
|
||||||
|
withReset
|
||||||
|
handleReset={handleReset}
|
||||||
|
withSliderMarks
|
||||||
|
sliderMarks={[-1, 0, 1, 2]}
|
||||||
|
/>
|
||||||
|
<IAIIconButton
|
||||||
|
size="sm"
|
||||||
|
onClick={handleRemoveLora}
|
||||||
|
tooltip="Remove LoRA"
|
||||||
|
aria-label="Remove LoRA"
|
||||||
|
icon={<FaTrash />}
|
||||||
|
colorScheme="error"
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamLora);
|
@ -0,0 +1,36 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
|
import { size } from 'lodash-es';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import ParamLoraList from './ParamLoraList';
|
||||||
|
import ParamLoraSelect from './ParamLoraSelect';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => {
|
||||||
|
const loraCount = size(state.lora.loras);
|
||||||
|
return {
|
||||||
|
activeLabel: loraCount > 0 ? `${loraCount} Active` : undefined,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamLoraCollapse = () => {
|
||||||
|
const { activeLabel } = useAppSelector(selector);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAICollapse label={'LoRA'} activeLabel={activeLabel}>
|
||||||
|
<Flex sx={{ flexDir: 'column', gap: 2 }}>
|
||||||
|
<ParamLoraSelect />
|
||||||
|
<ParamLoraList />
|
||||||
|
</Flex>
|
||||||
|
</IAICollapse>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamLoraCollapse);
|
@ -0,0 +1,24 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import { map } from 'lodash-es';
|
||||||
|
import ParamLora from './ParamLora';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ lora }) => {
|
||||||
|
const { loras } = lora;
|
||||||
|
|
||||||
|
return { loras };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamLoraList = () => {
|
||||||
|
const { loras } = useAppSelector(selector);
|
||||||
|
|
||||||
|
return map(loras, (lora) => <ParamLora key={lora.name} lora={lora} />);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ParamLoraList;
|
@ -0,0 +1,107 @@
|
|||||||
|
import { Text } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { forwardRef, useCallback, useMemo } from 'react';
|
||||||
|
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import { loraAdded } from '../store/loraSlice';
|
||||||
|
|
||||||
|
type LoraSelectItem = {
|
||||||
|
label: string;
|
||||||
|
value: string;
|
||||||
|
description?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ lora }) => ({
|
||||||
|
loras: lora.loras,
|
||||||
|
}),
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamLoraSelect = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { loras } = useAppSelector(selector);
|
||||||
|
const { data: lorasQueryData } = useGetLoRAModelsQuery();
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!lorasQueryData) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: LoraSelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(lorasQueryData.entities, (lora, id) => {
|
||||||
|
if (!lora || Boolean(id in loras)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: lora.name,
|
||||||
|
description: lora.description,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [loras, lorasQueryData]);
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: string[]) => {
|
||||||
|
const loraEntity = lorasQueryData?.entities[v[0]];
|
||||||
|
if (!loraEntity) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
v[0] && dispatch(loraAdded(loraEntity));
|
||||||
|
},
|
||||||
|
[dispatch, lorasQueryData?.entities]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineMultiSelect
|
||||||
|
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}
|
||||||
|
value={[]}
|
||||||
|
data={data}
|
||||||
|
maxDropdownHeight={400}
|
||||||
|
nothingFound="No matching LoRAs"
|
||||||
|
itemComponent={SelectItem}
|
||||||
|
disabled={data.length === 0}
|
||||||
|
filter={(value, selected, item: LoraSelectItem) =>
|
||||||
|
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||||
|
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||||
|
}
|
||||||
|
onChange={handleChange}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||||
|
value: string;
|
||||||
|
label: string;
|
||||||
|
description?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
|
||||||
|
({ label, description, ...others }: ItemProps, ref) => {
|
||||||
|
return (
|
||||||
|
<div ref={ref} {...others}>
|
||||||
|
<div>
|
||||||
|
<Text>{label}</Text>
|
||||||
|
{description && (
|
||||||
|
<Text size="xs" color="base.600">
|
||||||
|
{description}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
SelectItem.displayName = 'SelectItem';
|
||||||
|
|
||||||
|
export default ParamLoraSelect;
|
46
invokeai/frontend/web/src/features/lora/store/loraSlice.ts
Normal file
46
invokeai/frontend/web/src/features/lora/store/loraSlice.ts
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
export type Lora = {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
weight: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
|
||||||
|
weight: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
export type LoraState = {
|
||||||
|
loras: Record<string, Lora>;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const intialLoraState: LoraState = {
|
||||||
|
loras: {},
|
||||||
|
};
|
||||||
|
|
||||||
|
export const loraSlice = createSlice({
|
||||||
|
name: 'lora',
|
||||||
|
initialState: intialLoraState,
|
||||||
|
reducers: {
|
||||||
|
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
|
||||||
|
const { name, id } = action.payload;
|
||||||
|
state.loras[id] = { id, name, ...defaultLoRAConfig };
|
||||||
|
},
|
||||||
|
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||||
|
const id = action.payload;
|
||||||
|
delete state.loras[id];
|
||||||
|
},
|
||||||
|
loraWeightChanged: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<{ id: string; weight: number }>
|
||||||
|
) => {
|
||||||
|
const { id, weight } = action.payload;
|
||||||
|
state.loras[id].weight = weight;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions;
|
||||||
|
|
||||||
|
export default loraSlice.reducer;
|
@ -3,20 +3,22 @@ import { memo } from 'react';
|
|||||||
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||||
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
|
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
|
||||||
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
|
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
|
||||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
|
||||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
|
||||||
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
|
||||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
|
||||||
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
|
|
||||||
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
||||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
||||||
|
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||||
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
||||||
|
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||||
|
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
||||||
|
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||||
|
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
|
||||||
|
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
||||||
|
import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent';
|
||||||
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
||||||
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
||||||
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
||||||
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
|
||||||
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
|
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||||
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
|
||||||
|
|
||||||
type InputFieldComponentProps = {
|
type InputFieldComponentProps = {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -152,6 +154,26 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (type === 'vae_model' && template.type === 'vae_model') {
|
||||||
|
return (
|
||||||
|
<VaeModelInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'lora_model' && template.type === 'lora_model') {
|
||||||
|
return (
|
||||||
|
<LoRAModelInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (type === 'array' && template.type === 'array') {
|
if (type === 'array' && template.type === 'array') {
|
||||||
return (
|
return (
|
||||||
<ArrayInputFieldComponent
|
<ArrayInputFieldComponent
|
||||||
|
@ -7,18 +7,16 @@ import {
|
|||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
|
||||||
import { FieldComponentProps } from './types';
|
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
|
||||||
import { ImageDTO } from 'services/api/types';
|
|
||||||
import { Flex } from '@chakra-ui/react';
|
import { Flex } from '@chakra-ui/react';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import {
|
import {
|
||||||
NodesImageDropData,
|
|
||||||
TypesafeDraggableData,
|
TypesafeDraggableData,
|
||||||
TypesafeDroppableData,
|
TypesafeDroppableData,
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
} from 'app/components/ImageDnd/typesafeDnd';
|
||||||
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
import { PostUploadAction } from 'services/api/thunks/image';
|
import { PostUploadAction } from 'services/api/thunks/image';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
const ImageInputFieldComponent = (
|
const ImageInputFieldComponent = (
|
||||||
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
|
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
|
||||||
@ -34,23 +32,6 @@ const ImageInputFieldComponent = (
|
|||||||
isSuccess,
|
isSuccess,
|
||||||
} = useGetImageDTOQuery(field.value?.image_name ?? skipToken);
|
} = useGetImageDTOQuery(field.value?.image_name ?? skipToken);
|
||||||
|
|
||||||
const handleDrop = useCallback(
|
|
||||||
({ image_name }: ImageDTO) => {
|
|
||||||
if (field.value?.image_name === image_name) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch(
|
|
||||||
fieldValueChanged({
|
|
||||||
nodeId,
|
|
||||||
fieldName: field.name,
|
|
||||||
value: { image_name },
|
|
||||||
})
|
|
||||||
);
|
|
||||||
},
|
|
||||||
[dispatch, field.name, field.value, nodeId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleReset = useCallback(() => {
|
const handleReset = useCallback(() => {
|
||||||
dispatch(
|
dispatch(
|
||||||
fieldValueChanged({
|
fieldValueChanged({
|
||||||
@ -71,15 +52,14 @@ const ImageInputFieldComponent = (
|
|||||||
}
|
}
|
||||||
}, [field.name, imageDTO, nodeId]);
|
}, [field.name, imageDTO, nodeId]);
|
||||||
|
|
||||||
const droppableData = useMemo<TypesafeDroppableData | undefined>(() => {
|
const droppableData = useMemo<TypesafeDroppableData | undefined>(
|
||||||
if (imageDTO) {
|
() => ({
|
||||||
return {
|
id: `node-${nodeId}-${field.name}`,
|
||||||
id: `node-${nodeId}-${field.name}`,
|
actionType: 'SET_NODES_IMAGE',
|
||||||
actionType: 'SET_NODES_IMAGE',
|
context: { nodeId, fieldName: field.name },
|
||||||
context: { nodeId, fieldName: field.name },
|
}),
|
||||||
};
|
[field.name, nodeId]
|
||||||
}
|
);
|
||||||
}, [field.name, imageDTO, nodeId]);
|
|
||||||
|
|
||||||
const postUploadAction = useMemo<PostUploadAction>(
|
const postUploadAction = useMemo<PostUploadAction>(
|
||||||
() => ({
|
() => ({
|
||||||
|
@ -0,0 +1,102 @@
|
|||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import {
|
||||||
|
VaeModelInputFieldTemplate,
|
||||||
|
VaeModelInputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||||
|
import { forEach, isString } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
const LoRAModelInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<
|
||||||
|
VaeModelInputFieldValue,
|
||||||
|
VaeModelInputFieldTemplate
|
||||||
|
>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const { data: loraModels } = useGetLoRAModelsQuery();
|
||||||
|
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() => loraModels?.entities[field.value ?? loraModels.ids[0]],
|
||||||
|
[loraModels?.entities, loraModels?.ids, field.value]
|
||||||
|
);
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!loraModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(loraModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.name,
|
||||||
|
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [loraModels]);
|
||||||
|
|
||||||
|
const handleValueChanged = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
fieldValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value: v,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (field.value && loraModels?.ids.includes(field.value)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstLora = loraModels?.ids[0];
|
||||||
|
|
||||||
|
if (!isString(firstLora)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
handleValueChanged(firstLora);
|
||||||
|
}, [field.value, handleValueChanged, loraModels?.ids]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSelect
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
label={
|
||||||
|
selectedModel?.base_model &&
|
||||||
|
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||||
|
}
|
||||||
|
value={field.value}
|
||||||
|
placeholder="Pick one"
|
||||||
|
data={data}
|
||||||
|
onChange={handleValueChanged}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(LoRAModelInputFieldComponent);
|
@ -6,13 +6,13 @@ import {
|
|||||||
ModelInputFieldValue,
|
ModelInputFieldValue,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
|
|
||||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
|
||||||
import { FieldComponentProps } from './types';
|
|
||||||
import { forEach, isString } from 'lodash-es';
|
|
||||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||||
|
import { forEach, isString } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
const ModelInputFieldComponent = (
|
const ModelInputFieldComponent = (
|
||||||
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
||||||
@ -22,18 +22,16 @@ const ModelInputFieldComponent = (
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { data: pipelineModels } = useListModelsQuery({
|
const { data: mainModels } = useGetMainModelsQuery();
|
||||||
model_type: 'main',
|
|
||||||
});
|
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!pipelineModels) {
|
if (!mainModels) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
const data: SelectItem[] = [];
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
forEach(pipelineModels.entities, (model, id) => {
|
forEach(mainModels.entities, (model, id) => {
|
||||||
if (!model) {
|
if (!model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -46,11 +44,11 @@ const ModelInputFieldComponent = (
|
|||||||
});
|
});
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}, [pipelineModels]);
|
}, [mainModels]);
|
||||||
|
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
|
() => mainModels?.entities[field.value ?? mainModels.ids[0]],
|
||||||
[pipelineModels?.entities, pipelineModels?.ids, field.value]
|
[mainModels?.entities, mainModels?.ids, field.value]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleValueChanged = useCallback(
|
const handleValueChanged = useCallback(
|
||||||
@ -71,18 +69,18 @@ const ModelInputFieldComponent = (
|
|||||||
);
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (field.value && pipelineModels?.ids.includes(field.value)) {
|
if (field.value && mainModels?.ids.includes(field.value)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const firstModel = pipelineModels?.ids[0];
|
const firstModel = mainModels?.ids[0];
|
||||||
|
|
||||||
if (!isString(firstModel)) {
|
if (!isString(firstModel)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
handleValueChanged(firstModel);
|
handleValueChanged(firstModel);
|
||||||
}, [field.value, handleValueChanged, pipelineModels?.ids]);
|
}, [field.value, handleValueChanged, mainModels?.ids]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
|
@ -0,0 +1,95 @@
|
|||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import {
|
||||||
|
VaeModelInputFieldTemplate,
|
||||||
|
VaeModelInputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
const VaeModelInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<
|
||||||
|
VaeModelInputFieldValue,
|
||||||
|
VaeModelInputFieldTemplate
|
||||||
|
>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const { data: vaeModels } = useGetVaeModelsQuery();
|
||||||
|
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() => vaeModels?.entities[field.value ?? vaeModels.ids[0]],
|
||||||
|
[vaeModels?.entities, vaeModels?.ids, field.value]
|
||||||
|
);
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!vaeModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(vaeModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.name,
|
||||||
|
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [vaeModels]);
|
||||||
|
|
||||||
|
const handleValueChanged = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
fieldValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value: v,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (field.value && vaeModels?.ids.includes(field.value)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
handleValueChanged('auto');
|
||||||
|
}, [field.value, handleValueChanged, vaeModels?.ids]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSelect
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
label={
|
||||||
|
selectedModel?.base_model &&
|
||||||
|
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||||
|
}
|
||||||
|
value={field.value}
|
||||||
|
placeholder="Pick one"
|
||||||
|
data={data}
|
||||||
|
onChange={handleValueChanged}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(VaeModelInputFieldComponent);
|
@ -1,5 +1,8 @@
|
|||||||
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { cloneDeep, uniqBy } from 'lodash-es';
|
||||||
import { OpenAPIV3 } from 'openapi-types';
|
import { OpenAPIV3 } from 'openapi-types';
|
||||||
|
import { RgbaColor } from 'react-colorful';
|
||||||
import {
|
import {
|
||||||
addEdge,
|
addEdge,
|
||||||
applyEdgeChanges,
|
applyEdgeChanges,
|
||||||
@ -11,12 +14,9 @@ import {
|
|||||||
NodeChange,
|
NodeChange,
|
||||||
OnConnectStartParams,
|
OnConnectStartParams,
|
||||||
} from 'reactflow';
|
} from 'reactflow';
|
||||||
import { ImageField } from 'services/api/types';
|
|
||||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||||
|
import { ImageField } from 'services/api/types';
|
||||||
import { InvocationTemplate, InvocationValue } from '../types/types';
|
import { InvocationTemplate, InvocationValue } from '../types/types';
|
||||||
import { RgbaColor } from 'react-colorful';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { cloneDeep, isArray, uniq, uniqBy } from 'lodash-es';
|
|
||||||
|
|
||||||
export type NodesState = {
|
export type NodesState = {
|
||||||
nodes: Node<InvocationValue>[];
|
nodes: Node<InvocationValue>[];
|
||||||
|
@ -17,6 +17,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
|||||||
ClipField: 'clip',
|
ClipField: 'clip',
|
||||||
VaeField: 'vae',
|
VaeField: 'vae',
|
||||||
model: 'model',
|
model: 'model',
|
||||||
|
vae_model: 'vae_model',
|
||||||
|
lora_model: 'lora_model',
|
||||||
array: 'array',
|
array: 'array',
|
||||||
item: 'item',
|
item: 'item',
|
||||||
ColorField: 'color',
|
ColorField: 'color',
|
||||||
@ -116,6 +118,18 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
|||||||
title: 'Model',
|
title: 'Model',
|
||||||
description: 'Models are models.',
|
description: 'Models are models.',
|
||||||
},
|
},
|
||||||
|
vae_model: {
|
||||||
|
color: 'teal',
|
||||||
|
colorCssVar: getColorTokenCssVariable('teal'),
|
||||||
|
title: 'VAE',
|
||||||
|
description: 'Models are models.',
|
||||||
|
},
|
||||||
|
lora_model: {
|
||||||
|
color: 'teal',
|
||||||
|
colorCssVar: getColorTokenCssVariable('teal'),
|
||||||
|
title: 'LoRA',
|
||||||
|
description: 'Models are models.',
|
||||||
|
},
|
||||||
array: {
|
array: {
|
||||||
color: 'gray',
|
color: 'gray',
|
||||||
colorCssVar: getColorTokenCssVariable('gray'),
|
colorCssVar: getColorTokenCssVariable('gray'),
|
||||||
|
@ -64,6 +64,8 @@ export type FieldType =
|
|||||||
| 'vae'
|
| 'vae'
|
||||||
| 'control'
|
| 'control'
|
||||||
| 'model'
|
| 'model'
|
||||||
|
| 'vae_model'
|
||||||
|
| 'lora_model'
|
||||||
| 'array'
|
| 'array'
|
||||||
| 'item'
|
| 'item'
|
||||||
| 'color'
|
| 'color'
|
||||||
@ -91,6 +93,8 @@ export type InputFieldValue =
|
|||||||
| ControlInputFieldValue
|
| ControlInputFieldValue
|
||||||
| EnumInputFieldValue
|
| EnumInputFieldValue
|
||||||
| ModelInputFieldValue
|
| ModelInputFieldValue
|
||||||
|
| VaeModelInputFieldValue
|
||||||
|
| LoRAModelInputFieldValue
|
||||||
| ArrayInputFieldValue
|
| ArrayInputFieldValue
|
||||||
| ItemInputFieldValue
|
| ItemInputFieldValue
|
||||||
| ColorInputFieldValue
|
| ColorInputFieldValue
|
||||||
@ -116,6 +120,8 @@ export type InputFieldTemplate =
|
|||||||
| ControlInputFieldTemplate
|
| ControlInputFieldTemplate
|
||||||
| EnumInputFieldTemplate
|
| EnumInputFieldTemplate
|
||||||
| ModelInputFieldTemplate
|
| ModelInputFieldTemplate
|
||||||
|
| VaeModelInputFieldTemplate
|
||||||
|
| LoRAModelInputFieldTemplate
|
||||||
| ArrayInputFieldTemplate
|
| ArrayInputFieldTemplate
|
||||||
| ItemInputFieldTemplate
|
| ItemInputFieldTemplate
|
||||||
| ColorInputFieldTemplate
|
| ColorInputFieldTemplate
|
||||||
@ -228,6 +234,16 @@ export type ModelInputFieldValue = FieldValueBase & {
|
|||||||
value?: string;
|
value?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type VaeModelInputFieldValue = FieldValueBase & {
|
||||||
|
type: 'vae_model';
|
||||||
|
value?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type LoRAModelInputFieldValue = FieldValueBase & {
|
||||||
|
type: 'lora_model';
|
||||||
|
value?: string;
|
||||||
|
};
|
||||||
|
|
||||||
export type ArrayInputFieldValue = FieldValueBase & {
|
export type ArrayInputFieldValue = FieldValueBase & {
|
||||||
type: 'array';
|
type: 'array';
|
||||||
value?: (string | number)[];
|
value?: (string | number)[];
|
||||||
@ -305,6 +321,21 @@ export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'conditioning';
|
type: 'conditioning';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'unet';
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ClipInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'clip';
|
||||||
|
};
|
||||||
|
|
||||||
|
export type VaeInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'vae';
|
||||||
|
};
|
||||||
|
|
||||||
export type ControlInputFieldTemplate = InputFieldTemplateBase & {
|
export type ControlInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: undefined;
|
default: undefined;
|
||||||
type: 'control';
|
type: 'control';
|
||||||
@ -322,6 +353,16 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'model';
|
type: 'model';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: string;
|
||||||
|
type: 'vae_model';
|
||||||
|
};
|
||||||
|
|
||||||
|
export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: string;
|
||||||
|
type: 'lora_model';
|
||||||
|
};
|
||||||
|
|
||||||
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: [];
|
default: [];
|
||||||
type: 'array';
|
type: 'array';
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { filter } from 'lodash-es';
|
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
|
||||||
import { CollectInvocation, ControlNetInvocation } from 'services/api/types';
|
import { CollectInvocation, ControlNetInvocation } from 'services/api/types';
|
||||||
import { NonNullableGraph } from '../types/types';
|
import { NonNullableGraph } from '../types/types';
|
||||||
import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
|
import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
|
||||||
@ -11,13 +11,7 @@ export const addControlNetToLinearGraph = (
|
|||||||
): void => {
|
): void => {
|
||||||
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
|
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
|
||||||
|
|
||||||
const validControlNets = filter(
|
const validControlNets = getValidControlNets(controlNets);
|
||||||
controlNets,
|
|
||||||
(c) =>
|
|
||||||
c.isEnabled &&
|
|
||||||
(Boolean(c.processedControlImage) ||
|
|
||||||
(c.processorType === 'none' && Boolean(c.controlImage)))
|
|
||||||
);
|
|
||||||
|
|
||||||
if (isControlNetEnabled && Boolean(validControlNets.length)) {
|
if (isControlNetEnabled && Boolean(validControlNets.length)) {
|
||||||
if (validControlNets.length > 1) {
|
if (validControlNets.length > 1) {
|
||||||
|
@ -3,27 +3,29 @@ import { OpenAPIV3 } from 'openapi-types';
|
|||||||
import { FIELD_TYPE_MAP } from '../types/constants';
|
import { FIELD_TYPE_MAP } from '../types/constants';
|
||||||
import { isSchemaObject } from '../types/typeGuards';
|
import { isSchemaObject } from '../types/typeGuards';
|
||||||
import {
|
import {
|
||||||
BooleanInputFieldTemplate,
|
|
||||||
EnumInputFieldTemplate,
|
|
||||||
FloatInputFieldTemplate,
|
|
||||||
ImageInputFieldTemplate,
|
|
||||||
IntegerInputFieldTemplate,
|
|
||||||
LatentsInputFieldTemplate,
|
|
||||||
ConditioningInputFieldTemplate,
|
|
||||||
UNetInputFieldTemplate,
|
|
||||||
ClipInputFieldTemplate,
|
|
||||||
VaeInputFieldTemplate,
|
|
||||||
ControlInputFieldTemplate,
|
|
||||||
StringInputFieldTemplate,
|
|
||||||
ModelInputFieldTemplate,
|
|
||||||
ArrayInputFieldTemplate,
|
ArrayInputFieldTemplate,
|
||||||
ItemInputFieldTemplate,
|
BooleanInputFieldTemplate,
|
||||||
|
ClipInputFieldTemplate,
|
||||||
ColorInputFieldTemplate,
|
ColorInputFieldTemplate,
|
||||||
InputFieldTemplateBase,
|
ConditioningInputFieldTemplate,
|
||||||
OutputFieldTemplate,
|
ControlInputFieldTemplate,
|
||||||
TypeHints,
|
EnumInputFieldTemplate,
|
||||||
FieldType,
|
FieldType,
|
||||||
|
FloatInputFieldTemplate,
|
||||||
ImageCollectionInputFieldTemplate,
|
ImageCollectionInputFieldTemplate,
|
||||||
|
ImageInputFieldTemplate,
|
||||||
|
InputFieldTemplateBase,
|
||||||
|
IntegerInputFieldTemplate,
|
||||||
|
ItemInputFieldTemplate,
|
||||||
|
LatentsInputFieldTemplate,
|
||||||
|
LoRAModelInputFieldTemplate,
|
||||||
|
ModelInputFieldTemplate,
|
||||||
|
OutputFieldTemplate,
|
||||||
|
StringInputFieldTemplate,
|
||||||
|
TypeHints,
|
||||||
|
UNetInputFieldTemplate,
|
||||||
|
VaeInputFieldTemplate,
|
||||||
|
VaeModelInputFieldTemplate,
|
||||||
} from '../types/types';
|
} from '../types/types';
|
||||||
|
|
||||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||||
@ -175,6 +177,36 @@ const buildModelInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildVaeModelInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): VaeModelInputFieldTemplate => {
|
||||||
|
const template: VaeModelInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'vae_model',
|
||||||
|
inputRequirement: 'always',
|
||||||
|
inputKind: 'direct',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildLoRAModelInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): LoRAModelInputFieldTemplate => {
|
||||||
|
const template: LoRAModelInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'lora_model',
|
||||||
|
inputRequirement: 'always',
|
||||||
|
inputKind: 'direct',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildImageInputFieldTemplate = ({
|
const buildImageInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -441,6 +473,12 @@ export const buildInputFieldTemplate = (
|
|||||||
if (['model'].includes(fieldType)) {
|
if (['model'].includes(fieldType)) {
|
||||||
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
if (['vae_model'].includes(fieldType)) {
|
||||||
|
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
|
}
|
||||||
|
if (['lora_model'].includes(fieldType)) {
|
||||||
|
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
|
}
|
||||||
if (['enum'].includes(fieldType)) {
|
if (['enum'].includes(fieldType)) {
|
||||||
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
@ -75,6 +75,14 @@ export const buildInputFieldValue = (
|
|||||||
if (template.type === 'model') {
|
if (template.type === 'model') {
|
||||||
fieldValue.value = undefined;
|
fieldValue.value = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (template.type === 'vae_model') {
|
||||||
|
fieldValue.value = undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (template.type === 'lora_model') {
|
||||||
|
fieldValue.value = undefined;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return fieldValue;
|
return fieldValue;
|
||||||
|
@ -0,0 +1,148 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { forEach, size } from 'lodash-es';
|
||||||
|
import { LoraLoaderInvocation } from 'services/api/types';
|
||||||
|
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
|
||||||
|
import {
|
||||||
|
LORA_LOADER,
|
||||||
|
MAIN_MODEL_LOADER,
|
||||||
|
NEGATIVE_CONDITIONING,
|
||||||
|
POSITIVE_CONDITIONING,
|
||||||
|
} from './constants';
|
||||||
|
|
||||||
|
export const addLoRAsToGraph = (
|
||||||
|
graph: NonNullableGraph,
|
||||||
|
state: RootState,
|
||||||
|
baseNodeId: string
|
||||||
|
): void => {
|
||||||
|
/**
|
||||||
|
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
||||||
|
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
|
||||||
|
* or to the inference/conditioning nodes.
|
||||||
|
*
|
||||||
|
* So we need to inject a LoRA chain into the graph.
|
||||||
|
*/
|
||||||
|
|
||||||
|
const { loras } = state.lora;
|
||||||
|
const loraCount = size(loras);
|
||||||
|
|
||||||
|
if (loraCount > 0) {
|
||||||
|
// remove any existing connections from main model loader, we need to insert the lora nodes
|
||||||
|
graph.edges = graph.edges.filter(
|
||||||
|
(e) =>
|
||||||
|
!(
|
||||||
|
e.source.node_id === MAIN_MODEL_LOADER &&
|
||||||
|
['unet', 'clip'].includes(e.source.field)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// we need to remember the last lora so we can chain from it
|
||||||
|
let lastLoraNodeId = '';
|
||||||
|
let currentLoraIndex = 0;
|
||||||
|
|
||||||
|
forEach(loras, (lora) => {
|
||||||
|
const { id, name, weight } = lora;
|
||||||
|
const loraField = modelIdToLoRAModelField(id);
|
||||||
|
const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace(
|
||||||
|
'.',
|
||||||
|
'_'
|
||||||
|
)}`;
|
||||||
|
|
||||||
|
const loraLoaderNode: LoraLoaderInvocation = {
|
||||||
|
type: 'lora_loader',
|
||||||
|
id: currentLoraNodeId,
|
||||||
|
lora: loraField,
|
||||||
|
weight,
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
||||||
|
|
||||||
|
if (currentLoraIndex === 0) {
|
||||||
|
// first lora = start the lora chain, attach directly to model loader
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MAIN_MODEL_LOADER,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: currentLoraNodeId,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MAIN_MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: currentLoraNodeId,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// we are in the middle of the lora chain, instead connect to the previous lora
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: lastLoraNodeId,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: currentLoraNodeId,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: lastLoraNodeId,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: currentLoraNodeId,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (currentLoraIndex === loraCount - 1) {
|
||||||
|
// final lora, end the lora chain - we need to connect up to inference and conditioning nodes
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: currentLoraNodeId,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: baseNodeId,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: currentLoraNodeId,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: currentLoraNodeId,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// increment the lora for the next one in the chain
|
||||||
|
lastLoraNodeId = currentLoraNodeId;
|
||||||
|
currentLoraIndex += 1;
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,68 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
|
||||||
|
import {
|
||||||
|
IMAGE_TO_IMAGE_GRAPH,
|
||||||
|
IMAGE_TO_LATENTS,
|
||||||
|
INPAINT,
|
||||||
|
INPAINT_GRAPH,
|
||||||
|
LATENTS_TO_IMAGE,
|
||||||
|
MAIN_MODEL_LOADER,
|
||||||
|
TEXT_TO_IMAGE_GRAPH,
|
||||||
|
VAE_LOADER,
|
||||||
|
} from './constants';
|
||||||
|
|
||||||
|
export const addVAEToGraph = (
|
||||||
|
graph: NonNullableGraph,
|
||||||
|
state: RootState
|
||||||
|
): void => {
|
||||||
|
const { vae: vaeId } = state.generation;
|
||||||
|
const vae_model = modelIdToVAEModelField(vaeId);
|
||||||
|
|
||||||
|
if (vaeId !== 'auto') {
|
||||||
|
graph.nodes[VAE_LOADER] = {
|
||||||
|
type: 'vae_loader',
|
||||||
|
id: VAE_LOADER,
|
||||||
|
vae_model,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (graph.id === INPAINT_GRAPH) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
@ -1,31 +1,27 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import {
|
import {
|
||||||
ImageDTO,
|
ImageDTO,
|
||||||
ImageResizeInvocation,
|
ImageResizeInvocation,
|
||||||
ImageToLatentsInvocation,
|
ImageToLatentsInvocation,
|
||||||
RandomIntInvocation,
|
|
||||||
RangeOfSizeInvocation,
|
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
import { log } from 'app/logging/useLogger';
|
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||||
|
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||||
|
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||||
|
import { addVAEToGraph } from './addVAEToGraph';
|
||||||
import {
|
import {
|
||||||
ITERATE,
|
IMAGE_TO_IMAGE_GRAPH,
|
||||||
|
IMAGE_TO_LATENTS,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
PIPELINE_MODEL_LOADER,
|
LATENTS_TO_LATENTS,
|
||||||
|
MAIN_MODEL_LOADER,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
RANDOM_INT,
|
|
||||||
RANGE_OF_SIZE,
|
|
||||||
IMAGE_TO_IMAGE_GRAPH,
|
|
||||||
IMAGE_TO_LATENTS,
|
|
||||||
LATENTS_TO_LATENTS,
|
|
||||||
RESIZE,
|
RESIZE,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { set } from 'lodash-es';
|
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
|
||||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
|
||||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -52,7 +48,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
// The bounding box determines width and height, not the width and height params
|
// The bounding box determines width and height, not the width and height params
|
||||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
const model = modelIdToPipelineModelField(modelId);
|
const model = modelIdToMainModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
@ -81,9 +77,9 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
type: 'noise',
|
type: 'noise',
|
||||||
id: NOISE,
|
id: NOISE,
|
||||||
},
|
},
|
||||||
[PIPELINE_MODEL_LOADER]: {
|
[MAIN_MODEL_LOADER]: {
|
||||||
type: 'pipeline_model_loader',
|
type: 'main_model_loader',
|
||||||
id: PIPELINE_MODEL_LOADER,
|
id: MAIN_MODEL_LOADER,
|
||||||
model,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
@ -110,7 +106,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
edges: [
|
edges: [
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -120,7 +116,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -128,16 +124,6 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: LATENTS_TO_LATENTS,
|
node_id: LATENTS_TO_LATENTS,
|
||||||
@ -170,17 +156,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: IMAGE_TO_LATENTS,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -277,6 +253,11 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS);
|
||||||
|
|
||||||
|
// Add VAE
|
||||||
|
addVAEToGraph(graph, state);
|
||||||
|
|
||||||
// add dynamic prompts, mutating `graph`
|
// add dynamic prompts, mutating `graph`
|
||||||
addDynamicPromptsToGraph(graph, state);
|
addDynamicPromptsToGraph(graph, state);
|
||||||
|
|
||||||
|
@ -1,23 +1,25 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import {
|
import {
|
||||||
ImageDTO,
|
ImageDTO,
|
||||||
InpaintInvocation,
|
InpaintInvocation,
|
||||||
RandomIntInvocation,
|
RandomIntInvocation,
|
||||||
RangeOfSizeInvocation,
|
RangeOfSizeInvocation,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||||
import { log } from 'app/logging/useLogger';
|
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||||
|
import { addVAEToGraph } from './addVAEToGraph';
|
||||||
import {
|
import {
|
||||||
|
INPAINT,
|
||||||
|
INPAINT_GRAPH,
|
||||||
ITERATE,
|
ITERATE,
|
||||||
PIPELINE_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
RANDOM_INT,
|
RANDOM_INT,
|
||||||
RANGE_OF_SIZE,
|
RANGE_OF_SIZE,
|
||||||
INPAINT_GRAPH,
|
|
||||||
INPAINT,
|
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -55,7 +57,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
// We may need to set the inpaint width and height to scale the image
|
// We may need to set the inpaint width and height to scale the image
|
||||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
||||||
|
|
||||||
const model = modelIdToPipelineModelField(modelId);
|
const model = modelIdToMainModelField(modelId);
|
||||||
|
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: INPAINT_GRAPH,
|
id: INPAINT_GRAPH,
|
||||||
@ -101,9 +103,9 @@ export const buildCanvasInpaintGraph = (
|
|||||||
id: NEGATIVE_CONDITIONING,
|
id: NEGATIVE_CONDITIONING,
|
||||||
prompt: negativePrompt,
|
prompt: negativePrompt,
|
||||||
},
|
},
|
||||||
[PIPELINE_MODEL_LOADER]: {
|
[MAIN_MODEL_LOADER]: {
|
||||||
type: 'pipeline_model_loader',
|
type: 'main_model_loader',
|
||||||
id: PIPELINE_MODEL_LOADER,
|
id: MAIN_MODEL_LOADER,
|
||||||
model,
|
model,
|
||||||
},
|
},
|
||||||
[RANGE_OF_SIZE]: {
|
[RANGE_OF_SIZE]: {
|
||||||
@ -142,7 +144,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -152,7 +154,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -162,7 +164,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -170,16 +172,6 @@ export const buildCanvasInpaintGraph = (
|
|||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: INPAINT,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: RANGE_OF_SIZE,
|
node_id: RANGE_OF_SIZE,
|
||||||
@ -203,6 +195,11 @@ export const buildCanvasInpaintGraph = (
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
addLoRAsToGraph(graph, state, INPAINT);
|
||||||
|
|
||||||
|
// Add VAE
|
||||||
|
addVAEToGraph(graph, state);
|
||||||
|
|
||||||
// handle seed
|
// handle seed
|
||||||
if (shouldRandomizeSeed) {
|
if (shouldRandomizeSeed) {
|
||||||
// Random int node to generate the starting seed
|
// Random int node to generate the starting seed
|
||||||
|
@ -1,21 +1,19 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||||
|
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||||
|
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||||
|
import { addVAEToGraph } from './addVAEToGraph';
|
||||||
import {
|
import {
|
||||||
ITERATE,
|
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
PIPELINE_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
RANDOM_INT,
|
|
||||||
RANGE_OF_SIZE,
|
|
||||||
TEXT_TO_IMAGE_GRAPH,
|
TEXT_TO_IMAGE_GRAPH,
|
||||||
TEXT_TO_LATENTS,
|
TEXT_TO_LATENTS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
|
||||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
|
||||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Text to Image graph.
|
* Builds the Canvas tab's Text to Image graph.
|
||||||
@ -38,7 +36,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
// The bounding box determines width and height, not the width and height params
|
// The bounding box determines width and height, not the width and height params
|
||||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
const model = modelIdToPipelineModelField(modelId);
|
const model = modelIdToMainModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
@ -76,9 +74,9 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
},
|
},
|
||||||
[PIPELINE_MODEL_LOADER]: {
|
[MAIN_MODEL_LOADER]: {
|
||||||
type: 'pipeline_model_loader',
|
type: 'main_model_loader',
|
||||||
id: PIPELINE_MODEL_LOADER,
|
id: MAIN_MODEL_LOADER,
|
||||||
model,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
@ -109,7 +107,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -119,7 +117,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -129,7 +127,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -147,16 +145,6 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
field: 'latents',
|
field: 'latents',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: NOISE,
|
node_id: NOISE,
|
||||||
@ -170,6 +158,11 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS);
|
||||||
|
|
||||||
|
// Add VAE
|
||||||
|
addVAEToGraph(graph, state);
|
||||||
|
|
||||||
// add dynamic prompts, mutating `graph`
|
// add dynamic prompts, mutating `graph`
|
||||||
addDynamicPromptsToGraph(graph, state);
|
addDynamicPromptsToGraph(graph, state);
|
||||||
|
|
||||||
|
@ -1,28 +1,30 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import {
|
import {
|
||||||
ImageCollectionInvocation,
|
ImageCollectionInvocation,
|
||||||
ImageResizeInvocation,
|
ImageResizeInvocation,
|
||||||
ImageToLatentsInvocation,
|
ImageToLatentsInvocation,
|
||||||
IterateInvocation,
|
IterateInvocation,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
import { log } from 'app/logging/useLogger';
|
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||||
|
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||||
|
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||||
|
import { addVAEToGraph } from './addVAEToGraph';
|
||||||
import {
|
import {
|
||||||
|
IMAGE_COLLECTION,
|
||||||
|
IMAGE_COLLECTION_ITERATE,
|
||||||
|
IMAGE_TO_IMAGE_GRAPH,
|
||||||
|
IMAGE_TO_LATENTS,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
PIPELINE_MODEL_LOADER,
|
LATENTS_TO_LATENTS,
|
||||||
|
MAIN_MODEL_LOADER,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
IMAGE_TO_IMAGE_GRAPH,
|
|
||||||
IMAGE_TO_LATENTS,
|
|
||||||
LATENTS_TO_LATENTS,
|
|
||||||
RESIZE,
|
RESIZE,
|
||||||
IMAGE_COLLECTION,
|
|
||||||
IMAGE_COLLECTION_ITERATE,
|
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
|
||||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
|
||||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -69,7 +71,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
throw new Error('No initial image found in state');
|
throw new Error('No initial image found in state');
|
||||||
}
|
}
|
||||||
|
|
||||||
const model = modelIdToPipelineModelField(modelId);
|
const model = modelIdToMainModelField(modelId);
|
||||||
|
|
||||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
@ -89,9 +91,9 @@ export const buildLinearImageToImageGraph = (
|
|||||||
type: 'noise',
|
type: 'noise',
|
||||||
id: NOISE,
|
id: NOISE,
|
||||||
},
|
},
|
||||||
[PIPELINE_MODEL_LOADER]: {
|
[MAIN_MODEL_LOADER]: {
|
||||||
type: 'pipeline_model_loader',
|
type: 'main_model_loader',
|
||||||
id: PIPELINE_MODEL_LOADER,
|
id: MAIN_MODEL_LOADER,
|
||||||
model,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
@ -118,7 +120,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
edges: [
|
edges: [
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -128,7 +130,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -136,16 +138,6 @@ export const buildLinearImageToImageGraph = (
|
|||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: LATENTS_TO_LATENTS,
|
node_id: LATENTS_TO_LATENTS,
|
||||||
@ -176,19 +168,10 @@ export const buildLinearImageToImageGraph = (
|
|||||||
field: 'noise',
|
field: 'noise',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: IMAGE_TO_LATENTS,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -323,6 +306,11 @@ export const buildLinearImageToImageGraph = (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS);
|
||||||
|
|
||||||
|
// Add VAE
|
||||||
|
addVAEToGraph(graph, state);
|
||||||
|
|
||||||
// add dynamic prompts, mutating `graph`
|
// add dynamic prompts, mutating `graph`
|
||||||
addDynamicPromptsToGraph(graph, state);
|
addDynamicPromptsToGraph(graph, state);
|
||||||
|
|
||||||
|
@ -1,17 +1,19 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||||
|
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||||
|
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||||
|
import { addVAEToGraph } from './addVAEToGraph';
|
||||||
import {
|
import {
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
PIPELINE_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
TEXT_TO_IMAGE_GRAPH,
|
TEXT_TO_IMAGE_GRAPH,
|
||||||
TEXT_TO_LATENTS,
|
TEXT_TO_LATENTS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
|
||||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
|
||||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
|
||||||
|
|
||||||
export const buildLinearTextToImageGraph = (
|
export const buildLinearTextToImageGraph = (
|
||||||
state: RootState
|
state: RootState
|
||||||
@ -27,7 +29,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
height,
|
height,
|
||||||
} = state.generation;
|
} = state.generation;
|
||||||
|
|
||||||
const model = modelIdToPipelineModelField(modelId);
|
const model = modelIdToMainModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
@ -65,9 +67,9 @@ export const buildLinearTextToImageGraph = (
|
|||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
},
|
},
|
||||||
[PIPELINE_MODEL_LOADER]: {
|
[MAIN_MODEL_LOADER]: {
|
||||||
type: 'pipeline_model_loader',
|
type: 'main_model_loader',
|
||||||
id: PIPELINE_MODEL_LOADER,
|
id: MAIN_MODEL_LOADER,
|
||||||
model,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
@ -98,7 +100,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -108,7 +110,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -118,7 +120,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -136,16 +138,6 @@ export const buildLinearTextToImageGraph = (
|
|||||||
field: 'latents',
|
field: 'latents',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: NOISE,
|
node_id: NOISE,
|
||||||
@ -159,6 +151,11 @@ export const buildLinearTextToImageGraph = (
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS);
|
||||||
|
|
||||||
|
// Add Custom VAE Support
|
||||||
|
addVAEToGraph(graph, state);
|
||||||
|
|
||||||
// add dynamic prompts, mutating `graph`
|
// add dynamic prompts, mutating `graph`
|
||||||
addDynamicPromptsToGraph(graph, state);
|
addDynamicPromptsToGraph(graph, state);
|
||||||
|
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
import { Graph } from 'services/api/types';
|
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
|
||||||
import { cloneDeep, omit, reduce } from 'lodash-es';
|
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { InputFieldValue } from 'features/nodes/types/types';
|
import { InputFieldValue } from 'features/nodes/types/types';
|
||||||
|
import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||||
|
import { Graph } from 'services/api/types';
|
||||||
import { AnyInvocation } from 'services/events/types';
|
import { AnyInvocation } from 'services/events/types';
|
||||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
|
||||||
|
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||||
|
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* We need to do special handling for some fields
|
* We need to do special handling for some fields
|
||||||
@ -27,7 +29,19 @@ export const parseFieldValue = (field: InputFieldValue) => {
|
|||||||
|
|
||||||
if (field.type === 'model') {
|
if (field.type === 'model') {
|
||||||
if (field.value) {
|
if (field.value) {
|
||||||
return modelIdToPipelineModelField(field.value);
|
return modelIdToMainModelField(field.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (field.type === 'vae_model') {
|
||||||
|
if (field.value) {
|
||||||
|
return modelIdToVAEModelField(field.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (field.type === 'lora_model') {
|
||||||
|
if (field.value) {
|
||||||
|
return modelIdToLoRAModelField(field.value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,7 +7,9 @@ export const NOISE = 'noise';
|
|||||||
export const RANDOM_INT = 'rand_int';
|
export const RANDOM_INT = 'rand_int';
|
||||||
export const RANGE_OF_SIZE = 'range_of_size';
|
export const RANGE_OF_SIZE = 'range_of_size';
|
||||||
export const ITERATE = 'iterate';
|
export const ITERATE = 'iterate';
|
||||||
export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader';
|
export const MAIN_MODEL_LOADER = 'main_model_loader';
|
||||||
|
export const VAE_LOADER = 'vae_loader';
|
||||||
|
export const LORA_LOADER = 'lora_loader';
|
||||||
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
||||||
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
||||||
export const RESIZE = 'resize_image';
|
export const RESIZE = 'resize_image';
|
||||||
|
@ -0,0 +1,12 @@
|
|||||||
|
import { BaseModelType, LoRAModelField } from 'services/api/types';
|
||||||
|
|
||||||
|
export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => {
|
||||||
|
const [base_model, model_type, model_name] = loraId.split('/');
|
||||||
|
|
||||||
|
const field: LoRAModelField = {
|
||||||
|
base_model: base_model as BaseModelType,
|
||||||
|
model_name,
|
||||||
|
};
|
||||||
|
|
||||||
|
return field;
|
||||||
|
};
|
@ -0,0 +1,16 @@
|
|||||||
|
import { BaseModelType, MainModelField } from 'services/api/types';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Crudely converts a model id to a main model field
|
||||||
|
* TODO: Make better
|
||||||
|
*/
|
||||||
|
export const modelIdToMainModelField = (modelId: string): MainModelField => {
|
||||||
|
const [base_model, model_type, model_name] = modelId.split('/');
|
||||||
|
|
||||||
|
const field: MainModelField = {
|
||||||
|
base_model: base_model as BaseModelType,
|
||||||
|
model_name,
|
||||||
|
};
|
||||||
|
|
||||||
|
return field;
|
||||||
|
};
|
@ -1,18 +0,0 @@
|
|||||||
import { BaseModelType, PipelineModelField } from 'services/api/types';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Crudely converts a model id to a pipeline model field
|
|
||||||
* TODO: Make better
|
|
||||||
*/
|
|
||||||
export const modelIdToPipelineModelField = (
|
|
||||||
modelId: string
|
|
||||||
): PipelineModelField => {
|
|
||||||
const [base_model, model_type, model_name] = modelId.split('/');
|
|
||||||
|
|
||||||
const field: PipelineModelField = {
|
|
||||||
base_model: base_model as BaseModelType,
|
|
||||||
model_name,
|
|
||||||
};
|
|
||||||
|
|
||||||
return field;
|
|
||||||
};
|
|
@ -0,0 +1,16 @@
|
|||||||
|
import { BaseModelType, VAEModelField } from 'services/api/types';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Crudely converts a model id to a main model field
|
||||||
|
* TODO: Make better
|
||||||
|
*/
|
||||||
|
export const modelIdToVAEModelField = (modelId: string): VAEModelField => {
|
||||||
|
const [base_model, model_type, model_name] = modelId.split('/');
|
||||||
|
|
||||||
|
const field: VAEModelField = {
|
||||||
|
base_model: base_model as BaseModelType,
|
||||||
|
model_name,
|
||||||
|
};
|
||||||
|
|
||||||
|
return field;
|
||||||
|
};
|
@ -1,20 +1,15 @@
|
|||||||
import { Flex, useDisclosure } from '@chakra-ui/react';
|
import { Flex } from '@chakra-ui/react';
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import IAICollapse from 'common/components/IAICollapse';
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import ParamBoundingBoxWidth from './ParamBoundingBoxWidth';
|
import { useTranslation } from 'react-i18next';
|
||||||
import ParamBoundingBoxHeight from './ParamBoundingBoxHeight';
|
import ParamBoundingBoxHeight from './ParamBoundingBoxHeight';
|
||||||
|
import ParamBoundingBoxWidth from './ParamBoundingBoxWidth';
|
||||||
|
|
||||||
const ParamBoundingBoxCollapse = () => {
|
const ParamBoundingBoxCollapse = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { isOpen, onToggle } = useDisclosure();
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICollapse
|
<IAICollapse label={t('parameters.boundingBoxHeader')}>
|
||||||
label={t('parameters.boundingBoxHeader')}
|
|
||||||
isOpen={isOpen}
|
|
||||||
onToggle={onToggle}
|
|
||||||
>
|
|
||||||
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
||||||
<ParamBoundingBoxWidth />
|
<ParamBoundingBoxWidth />
|
||||||
<ParamBoundingBoxHeight />
|
<ParamBoundingBoxHeight />
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Flex, useDisclosure } from '@chakra-ui/react';
|
import { Flex } from '@chakra-ui/react';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
@ -6,19 +6,14 @@ import IAICollapse from 'common/components/IAICollapse';
|
|||||||
import ParamInfillMethod from './ParamInfillMethod';
|
import ParamInfillMethod from './ParamInfillMethod';
|
||||||
import ParamInfillTilesize from './ParamInfillTilesize';
|
import ParamInfillTilesize from './ParamInfillTilesize';
|
||||||
import ParamScaleBeforeProcessing from './ParamScaleBeforeProcessing';
|
import ParamScaleBeforeProcessing from './ParamScaleBeforeProcessing';
|
||||||
import ParamScaledWidth from './ParamScaledWidth';
|
|
||||||
import ParamScaledHeight from './ParamScaledHeight';
|
import ParamScaledHeight from './ParamScaledHeight';
|
||||||
|
import ParamScaledWidth from './ParamScaledWidth';
|
||||||
|
|
||||||
const ParamInfillCollapse = () => {
|
const ParamInfillCollapse = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { isOpen, onToggle } = useDisclosure();
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICollapse
|
<IAICollapse label={t('parameters.infillScalingHeader')}>
|
||||||
label={t('parameters.infillScalingHeader')}
|
|
||||||
isOpen={isOpen}
|
|
||||||
onToggle={onToggle}
|
|
||||||
>
|
|
||||||
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
||||||
<ParamInfillMethod />
|
<ParamInfillMethod />
|
||||||
<ParamInfillTilesize />
|
<ParamInfillTilesize />
|
||||||
|
@ -1,22 +1,16 @@
|
|||||||
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
import ParamSeamBlur from './ParamSeamBlur';
|
import ParamSeamBlur from './ParamSeamBlur';
|
||||||
import ParamSeamSize from './ParamSeamSize';
|
import ParamSeamSize from './ParamSeamSize';
|
||||||
import ParamSeamSteps from './ParamSeamSteps';
|
import ParamSeamSteps from './ParamSeamSteps';
|
||||||
import ParamSeamStrength from './ParamSeamStrength';
|
import ParamSeamStrength from './ParamSeamStrength';
|
||||||
import { useDisclosure } from '@chakra-ui/react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import IAICollapse from 'common/components/IAICollapse';
|
|
||||||
import { memo } from 'react';
|
|
||||||
|
|
||||||
const ParamSeamCorrectionCollapse = () => {
|
const ParamSeamCorrectionCollapse = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { isOpen, onToggle } = useDisclosure();
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICollapse
|
<IAICollapse label={t('parameters.seamCorrectionHeader')}>
|
||||||
label={t('parameters.seamCorrectionHeader')}
|
|
||||||
isOpen={isOpen}
|
|
||||||
onToggle={onToggle}
|
|
||||||
>
|
|
||||||
<ParamSeamSize />
|
<ParamSeamSize />
|
||||||
<ParamSeamBlur />
|
<ParamSeamBlur />
|
||||||
<ParamSeamStrength />
|
<ParamSeamStrength />
|
||||||
|
@ -1,41 +1,45 @@
|
|||||||
import { Divider, Flex } from '@chakra-ui/react';
|
import { Divider, Flex } from '@chakra-ui/react';
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import IAICollapse from 'common/components/IAICollapse';
|
|
||||||
import { Fragment, memo, useCallback } from 'react';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
|
import ControlNet from 'features/controlNet/components/ControlNet';
|
||||||
|
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
|
||||||
import {
|
import {
|
||||||
controlNetAdded,
|
controlNetAdded,
|
||||||
controlNetSelector,
|
controlNetSelector,
|
||||||
isControlNetEnabledToggled,
|
|
||||||
} from 'features/controlNet/store/controlNetSlice';
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
|
||||||
import { map } from 'lodash-es';
|
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import { map } from 'lodash-es';
|
||||||
import ControlNet from 'features/controlNet/components/ControlNet';
|
import { Fragment, memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
controlNetSelector,
|
controlNetSelector,
|
||||||
(controlNet) => {
|
(controlNet) => {
|
||||||
const { controlNets, isEnabled } = controlNet;
|
const { controlNets, isEnabled } = controlNet;
|
||||||
|
|
||||||
return { controlNetsArray: map(controlNets), isEnabled };
|
const validControlNets = getValidControlNets(controlNets);
|
||||||
|
|
||||||
|
const activeLabel =
|
||||||
|
isEnabled && validControlNets.length > 0
|
||||||
|
? `${validControlNets.length} Active`
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
return { controlNetsArray: map(controlNets), activeLabel };
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const ParamControlNetCollapse = () => {
|
const ParamControlNetCollapse = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { controlNetsArray, isEnabled } = useAppSelector(selector);
|
const { controlNetsArray, activeLabel } = useAppSelector(selector);
|
||||||
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
|
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const handleClickControlNetToggle = useCallback(() => {
|
|
||||||
dispatch(isControlNetEnabledToggled());
|
|
||||||
}, [dispatch]);
|
|
||||||
|
|
||||||
const handleClickedAddControlNet = useCallback(() => {
|
const handleClickedAddControlNet = useCallback(() => {
|
||||||
dispatch(controlNetAdded({ controlNetId: uuidv4() }));
|
dispatch(controlNetAdded({ controlNetId: uuidv4() }));
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
@ -45,13 +49,9 @@ const ParamControlNetCollapse = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICollapse
|
<IAICollapse label="ControlNet" activeLabel={activeLabel}>
|
||||||
label={'ControlNet'}
|
|
||||||
isOpen={isEnabled}
|
|
||||||
onToggle={handleClickControlNetToggle}
|
|
||||||
withSwitch
|
|
||||||
>
|
|
||||||
<Flex sx={{ flexDir: 'column', gap: 3 }}>
|
<Flex sx={{ flexDir: 'column', gap: 3 }}>
|
||||||
|
<ParamControlNetFeatureToggle />
|
||||||
{controlNetsArray.map((c, i) => (
|
{controlNetsArray.map((c, i) => (
|
||||||
<Fragment key={c.controlNetId}>
|
<Fragment key={c.controlNetId}>
|
||||||
{i > 0 && <Divider />}
|
{i > 0 && <Divider />}
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAINumberInput from 'common/components/IAINumberInput';
|
import IAINumberInput from 'common/components/IAINumberInput';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
@ -27,7 +28,8 @@ const selector = createSelector(
|
|||||||
shouldUseSliders,
|
shouldUseSliders,
|
||||||
shift,
|
shift,
|
||||||
};
|
};
|
||||||
}
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const ParamCFGScale = () => {
|
const ParamCFGScale = () => {
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider';
|
import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider';
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
import { setHeight } from 'features/parameters/store/generationSlice';
|
import { setHeight } from 'features/parameters/store/generationSlice';
|
||||||
@ -25,7 +26,8 @@ const selector = createSelector(
|
|||||||
inputMax,
|
inputMax,
|
||||||
step,
|
step,
|
||||||
};
|
};
|
||||||
}
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
type ParamHeightProps = Omit<
|
type ParamHeightProps = Omit<
|
||||||
|
@ -1,37 +1,38 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAINumberInput from 'common/components/IAINumberInput';
|
import IAINumberInput from 'common/components/IAINumberInput';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
|
||||||
import { setIterations } from 'features/parameters/store/generationSlice';
|
import { setIterations } from 'features/parameters/store/generationSlice';
|
||||||
import { configSelector } from 'features/system/store/configSelectors';
|
|
||||||
import { hotkeysSelector } from 'features/ui/store/hotkeysSlice';
|
|
||||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
const selector = createSelector([stateSelector], (state) => {
|
const selector = createSelector(
|
||||||
const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
|
[stateSelector],
|
||||||
state.config.sd.iterations;
|
(state) => {
|
||||||
const { iterations } = state.generation;
|
const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
|
||||||
const { shouldUseSliders } = state.ui;
|
state.config.sd.iterations;
|
||||||
const isDisabled =
|
const { iterations } = state.generation;
|
||||||
state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial;
|
const { shouldUseSliders } = state.ui;
|
||||||
|
const isDisabled =
|
||||||
|
state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial;
|
||||||
|
|
||||||
const step = state.hotkeys.shift ? fineStep : coarseStep;
|
const step = state.hotkeys.shift ? fineStep : coarseStep;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
iterations,
|
iterations,
|
||||||
initial,
|
initial,
|
||||||
min,
|
min,
|
||||||
sliderMax,
|
sliderMax,
|
||||||
inputMax,
|
inputMax,
|
||||||
step,
|
step,
|
||||||
shouldUseSliders,
|
shouldUseSliders,
|
||||||
isDisabled,
|
isDisabled,
|
||||||
};
|
};
|
||||||
});
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
const ParamIterations = () => {
|
const ParamIterations = () => {
|
||||||
const {
|
const {
|
||||||
|
@ -1,19 +1,19 @@
|
|||||||
import { Box, Flex } from '@chakra-ui/react';
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
import ModelSelect from 'features/system/components/ModelSelect';
|
import ModelSelect from 'features/system/components/ModelSelect';
|
||||||
|
import VAESelect from 'features/system/components/VAESelect';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import ParamScheduler from './ParamScheduler';
|
|
||||||
|
|
||||||
const ParamSchedulerAndModel = () => {
|
const ParamModelandVAE = () => {
|
||||||
return (
|
return (
|
||||||
<Flex gap={3} w="full">
|
<Flex gap={3} w="full">
|
||||||
<Box w="25rem">
|
|
||||||
<ParamScheduler />
|
|
||||||
</Box>
|
|
||||||
<Box w="full">
|
<Box w="full">
|
||||||
<ModelSelect />
|
<ModelSelect />
|
||||||
</Box>
|
</Box>
|
||||||
|
<Box w="full">
|
||||||
|
<VAESelect />
|
||||||
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(ParamSchedulerAndModel);
|
export default memo(ParamModelandVAE);
|
@ -1,5 +1,6 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAINumberInput from 'common/components/IAINumberInput';
|
import IAINumberInput from 'common/components/IAINumberInput';
|
||||||
|
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
@ -33,7 +34,8 @@ const selector = createSelector(
|
|||||||
step,
|
step,
|
||||||
shouldUseSliders,
|
shouldUseSliders,
|
||||||
};
|
};
|
||||||
}
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const ParamSteps = () => {
|
const ParamSteps = () => {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { IAIFullSliderProps } from 'common/components/IAISlider';
|
import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider';
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
import { setWidth } from 'features/parameters/store/generationSlice';
|
import { setWidth } from 'features/parameters/store/generationSlice';
|
||||||
import { configSelector } from 'features/system/store/configSelectors';
|
import { configSelector } from 'features/system/store/configSelectors';
|
||||||
@ -26,7 +26,8 @@ const selector = createSelector(
|
|||||||
inputMax,
|
inputMax,
|
||||||
step,
|
step,
|
||||||
};
|
};
|
||||||
}
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
type ParamWidthProps = Omit<IAIFullSliderProps, 'label' | 'value' | 'onChange'>;
|
type ParamWidthProps = Omit<IAIFullSliderProps, 'label' | 'value' | 'onChange'>;
|
||||||
|
@ -1,37 +1,39 @@
|
|||||||
import { Flex } from '@chakra-ui/react';
|
import { Flex } from '@chakra-ui/react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { RootState } from 'app/store/store';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAICollapse from 'common/components/IAICollapse';
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
import { memo } from 'react';
|
|
||||||
import { ParamHiresStrength } from './ParamHiresStrength';
|
|
||||||
import { setHiresFix } from 'features/parameters/store/postprocessingSlice';
|
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { ParamHiresStrength } from './ParamHiresStrength';
|
||||||
|
import { ParamHiresToggle } from './ParamHiresToggle';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => {
|
||||||
|
const activeLabel = state.postprocessing.hiresFix ? 'Enabled' : undefined;
|
||||||
|
|
||||||
|
return { activeLabel };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
const ParamHiresCollapse = () => {
|
const ParamHiresCollapse = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const hiresFix = useAppSelector(
|
const { activeLabel } = useAppSelector(selector);
|
||||||
(state: RootState) => state.postprocessing.hiresFix
|
|
||||||
);
|
|
||||||
|
|
||||||
const isHiresEnabled = useFeatureStatus('hires').isFeatureEnabled;
|
const isHiresEnabled = useFeatureStatus('hires').isFeatureEnabled;
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const handleToggle = () => dispatch(setHiresFix(!hiresFix));
|
|
||||||
|
|
||||||
if (!isHiresEnabled) {
|
if (!isHiresEnabled) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICollapse
|
<IAICollapse label={t('parameters.hiresOptim')} activeLabel={activeLabel}>
|
||||||
label={t('parameters.hiresOptim')}
|
|
||||||
isOpen={hiresFix}
|
|
||||||
onToggle={handleToggle}
|
|
||||||
withSwitch
|
|
||||||
>
|
|
||||||
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
||||||
|
<ParamHiresToggle />
|
||||||
<ParamHiresStrength />
|
<ParamHiresStrength />
|
||||||
</Flex>
|
</Flex>
|
||||||
</IAICollapse>
|
</IAICollapse>
|
||||||
|
@ -23,7 +23,6 @@ export const ParamHiresToggle = () => {
|
|||||||
return (
|
return (
|
||||||
<IAISwitch
|
<IAISwitch
|
||||||
label={t('parameters.hiresOptim')}
|
label={t('parameters.hiresOptim')}
|
||||||
fontSize="md"
|
|
||||||
isChecked={hiresFix}
|
isChecked={hiresFix}
|
||||||
onChange={handleChangeHiresFix}
|
onChange={handleChangeHiresFix}
|
||||||
/>
|
/>
|
||||||
|
@ -1,27 +1,33 @@
|
|||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { Flex } from '@chakra-ui/react';
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAICollapse from 'common/components/IAICollapse';
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
import ParamPerlinNoise from './ParamPerlinNoise';
|
|
||||||
import ParamNoiseThreshold from './ParamNoiseThreshold';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { setShouldUseNoiseSettings } from 'features/parameters/store/generationSlice';
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import ParamNoiseThreshold from './ParamNoiseThreshold';
|
||||||
|
import { ParamNoiseToggle } from './ParamNoiseToggle';
|
||||||
|
import ParamPerlinNoise from './ParamPerlinNoise';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => {
|
||||||
|
const { shouldUseNoiseSettings } = state.generation;
|
||||||
|
return {
|
||||||
|
activeLabel: shouldUseNoiseSettings ? 'Enabled' : undefined,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
const ParamNoiseCollapse = () => {
|
const ParamNoiseCollapse = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const isNoiseEnabled = useFeatureStatus('noise').isFeatureEnabled;
|
const isNoiseEnabled = useFeatureStatus('noise').isFeatureEnabled;
|
||||||
|
|
||||||
const shouldUseNoiseSettings = useAppSelector(
|
const { activeLabel } = useAppSelector(selector);
|
||||||
(state: RootState) => state.generation.shouldUseNoiseSettings
|
|
||||||
);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const handleToggle = () =>
|
|
||||||
dispatch(setShouldUseNoiseSettings(!shouldUseNoiseSettings));
|
|
||||||
|
|
||||||
if (!isNoiseEnabled) {
|
if (!isNoiseEnabled) {
|
||||||
return null;
|
return null;
|
||||||
@ -30,11 +36,10 @@ const ParamNoiseCollapse = () => {
|
|||||||
return (
|
return (
|
||||||
<IAICollapse
|
<IAICollapse
|
||||||
label={t('parameters.noiseSettings')}
|
label={t('parameters.noiseSettings')}
|
||||||
isOpen={shouldUseNoiseSettings}
|
activeLabel={activeLabel}
|
||||||
onToggle={handleToggle}
|
|
||||||
withSwitch
|
|
||||||
>
|
>
|
||||||
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
||||||
|
<ParamNoiseToggle />
|
||||||
<ParamPerlinNoise />
|
<ParamPerlinNoise />
|
||||||
<ParamNoiseThreshold />
|
<ParamNoiseThreshold />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -1,18 +1,31 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { setThreshold } from 'features/parameters/store/generationSlice';
|
import { setThreshold } from 'features/parameters/store/generationSlice';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => {
|
||||||
|
const { shouldUseNoiseSettings, threshold } = state.generation;
|
||||||
|
return {
|
||||||
|
isDisabled: !shouldUseNoiseSettings,
|
||||||
|
threshold,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
export default function ParamNoiseThreshold() {
|
export default function ParamNoiseThreshold() {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const threshold = useAppSelector(
|
const { threshold, isDisabled } = useAppSelector(selector);
|
||||||
(state: RootState) => state.generation.threshold
|
|
||||||
);
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAISlider
|
<IAISlider
|
||||||
|
isDisabled={isDisabled}
|
||||||
label={t('parameters.noiseThreshold')}
|
label={t('parameters.noiseThreshold')}
|
||||||
min={0}
|
min={0}
|
||||||
max={20}
|
max={20}
|
||||||
|
@ -0,0 +1,27 @@
|
|||||||
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { setShouldUseNoiseSettings } from 'features/parameters/store/generationSlice';
|
||||||
|
import { ChangeEvent } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
export const ParamNoiseToggle = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const shouldUseNoiseSettings = useAppSelector(
|
||||||
|
(state: RootState) => state.generation.shouldUseNoiseSettings
|
||||||
|
);
|
||||||
|
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleChange = (e: ChangeEvent<HTMLInputElement>) =>
|
||||||
|
dispatch(setShouldUseNoiseSettings(e.target.checked));
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAISwitch
|
||||||
|
label="Enable Noise Settings"
|
||||||
|
isChecked={shouldUseNoiseSettings}
|
||||||
|
onChange={handleChange}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
@ -1,16 +1,31 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { setPerlin } from 'features/parameters/store/generationSlice';
|
import { setPerlin } from 'features/parameters/store/generationSlice';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => {
|
||||||
|
const { shouldUseNoiseSettings, perlin } = state.generation;
|
||||||
|
return {
|
||||||
|
isDisabled: !shouldUseNoiseSettings,
|
||||||
|
perlin,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
export default function ParamPerlinNoise() {
|
export default function ParamPerlinNoise() {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const perlin = useAppSelector((state: RootState) => state.generation.perlin);
|
const { perlin, isDisabled } = useAppSelector(selector);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAISlider
|
<IAISlider
|
||||||
|
isDisabled={isDisabled}
|
||||||
label={t('parameters.perlinNoise')}
|
label={t('parameters.perlinNoise')}
|
||||||
min={0}
|
min={0}
|
||||||
max={1}
|
max={1}
|
||||||
|
@ -1,36 +1,46 @@
|
|||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { Box, Flex } from '@chakra-ui/react';
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
import IAICollapse from 'common/components/IAICollapse';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { setSeamless } from 'features/parameters/store/generationSlice';
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
import ParamSeamlessXAxis from './ParamSeamlessXAxis';
|
import ParamSeamlessXAxis from './ParamSeamlessXAxis';
|
||||||
import ParamSeamlessYAxis from './ParamSeamlessYAxis';
|
import ParamSeamlessYAxis from './ParamSeamlessYAxis';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
|
||||||
|
const getActiveLabel = (seamlessXAxis: boolean, seamlessYAxis: boolean) => {
|
||||||
|
if (seamlessXAxis && seamlessYAxis) {
|
||||||
|
return 'X & Y';
|
||||||
|
}
|
||||||
|
|
||||||
|
if (seamlessXAxis) {
|
||||||
|
return 'X';
|
||||||
|
}
|
||||||
|
|
||||||
|
if (seamlessYAxis) {
|
||||||
|
return 'Y';
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
generationSelector,
|
generationSelector,
|
||||||
(generation) => {
|
(generation) => {
|
||||||
const { shouldUseSeamless, seamlessXAxis, seamlessYAxis } = generation;
|
const { seamlessXAxis, seamlessYAxis } = generation;
|
||||||
|
|
||||||
return { shouldUseSeamless, seamlessXAxis, seamlessYAxis };
|
const activeLabel = getActiveLabel(seamlessXAxis, seamlessYAxis);
|
||||||
|
return { activeLabel };
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const ParamSeamlessCollapse = () => {
|
const ParamSeamlessCollapse = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { shouldUseSeamless } = useAppSelector(selector);
|
const { activeLabel } = useAppSelector(selector);
|
||||||
|
|
||||||
const isSeamlessEnabled = useFeatureStatus('seamless').isFeatureEnabled;
|
const isSeamlessEnabled = useFeatureStatus('seamless').isFeatureEnabled;
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const handleToggle = () => dispatch(setSeamless(!shouldUseSeamless));
|
|
||||||
|
|
||||||
if (!isSeamlessEnabled) {
|
if (!isSeamlessEnabled) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
@ -38,9 +48,7 @@ const ParamSeamlessCollapse = () => {
|
|||||||
return (
|
return (
|
||||||
<IAICollapse
|
<IAICollapse
|
||||||
label={t('parameters.seamlessTiling')}
|
label={t('parameters.seamlessTiling')}
|
||||||
isOpen={shouldUseSeamless}
|
activeLabel={activeLabel}
|
||||||
onToggle={handleToggle}
|
|
||||||
withSwitch
|
|
||||||
>
|
>
|
||||||
<Flex sx={{ gap: 5 }}>
|
<Flex sx={{ gap: 5 }}>
|
||||||
<Box flexGrow={1}>
|
<Box flexGrow={1}>
|
||||||
|
@ -1,39 +1,39 @@
|
|||||||
import { memo } from 'react';
|
|
||||||
import { Flex } from '@chakra-ui/react';
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { memo } from 'react';
|
||||||
import ParamSymmetryHorizontal from './ParamSymmetryHorizontal';
|
import ParamSymmetryHorizontal from './ParamSymmetryHorizontal';
|
||||||
import ParamSymmetryVertical from './ParamSymmetryVertical';
|
import ParamSymmetryVertical from './ParamSymmetryVertical';
|
||||||
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAICollapse from 'common/components/IAICollapse';
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { setShouldUseSymmetry } from 'features/parameters/store/generationSlice';
|
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import ParamSymmetryToggle from './ParamSymmetryToggle';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => ({
|
||||||
|
activeLabel: state.generation.shouldUseSymmetry ? 'Enabled' : undefined,
|
||||||
|
}),
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
const ParamSymmetryCollapse = () => {
|
const ParamSymmetryCollapse = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const shouldUseSymmetry = useAppSelector(
|
const { activeLabel } = useAppSelector(selector);
|
||||||
(state: RootState) => state.generation.shouldUseSymmetry
|
|
||||||
);
|
|
||||||
|
|
||||||
const isSymmetryEnabled = useFeatureStatus('symmetry').isFeatureEnabled;
|
const isSymmetryEnabled = useFeatureStatus('symmetry').isFeatureEnabled;
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const handleToggle = () => dispatch(setShouldUseSymmetry(!shouldUseSymmetry));
|
|
||||||
|
|
||||||
if (!isSymmetryEnabled) {
|
if (!isSymmetryEnabled) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICollapse
|
<IAICollapse label={t('parameters.symmetry')} activeLabel={activeLabel}>
|
||||||
label={t('parameters.symmetry')}
|
|
||||||
isOpen={shouldUseSymmetry}
|
|
||||||
onToggle={handleToggle}
|
|
||||||
withSwitch
|
|
||||||
>
|
|
||||||
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
||||||
|
<ParamSymmetryToggle />
|
||||||
<ParamSymmetryHorizontal />
|
<ParamSymmetryHorizontal />
|
||||||
<ParamSymmetryVertical />
|
<ParamSymmetryVertical />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -12,6 +12,7 @@ export default function ParamSymmetryToggle() {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<IAISwitch
|
<IAISwitch
|
||||||
|
label="Enable Symmetry"
|
||||||
isChecked={shouldUseSymmetry}
|
isChecked={shouldUseSymmetry}
|
||||||
onChange={(e) => dispatch(setShouldUseSymmetry(e.target.checked))}
|
onChange={(e) => dispatch(setShouldUseSymmetry(e.target.checked))}
|
||||||
/>
|
/>
|
||||||
|
@ -1,39 +1,42 @@
|
|||||||
import ParamVariationWeights from './ParamVariationWeights';
|
|
||||||
import ParamVariationAmount from './ParamVariationAmount';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { setShouldGenerateVariations } from 'features/parameters/store/generationSlice';
|
|
||||||
import { Flex } from '@chakra-ui/react';
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAICollapse from 'common/components/IAICollapse';
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
import { memo } from 'react';
|
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import ParamVariationAmount from './ParamVariationAmount';
|
||||||
|
import { ParamVariationToggle } from './ParamVariationToggle';
|
||||||
|
import ParamVariationWeights from './ParamVariationWeights';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => {
|
||||||
|
const activeLabel = state.generation.shouldGenerateVariations
|
||||||
|
? 'Enabled'
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
return { activeLabel };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
const ParamVariationCollapse = () => {
|
const ParamVariationCollapse = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const shouldGenerateVariations = useAppSelector(
|
const { activeLabel } = useAppSelector(selector);
|
||||||
(state: RootState) => state.generation.shouldGenerateVariations
|
|
||||||
);
|
|
||||||
|
|
||||||
const isVariationEnabled = useFeatureStatus('variation').isFeatureEnabled;
|
const isVariationEnabled = useFeatureStatus('variation').isFeatureEnabled;
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const handleToggle = () =>
|
|
||||||
dispatch(setShouldGenerateVariations(!shouldGenerateVariations));
|
|
||||||
|
|
||||||
if (!isVariationEnabled) {
|
if (!isVariationEnabled) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICollapse
|
<IAICollapse label={t('parameters.variations')} activeLabel={activeLabel}>
|
||||||
label={t('parameters.variations')}
|
|
||||||
isOpen={shouldGenerateVariations}
|
|
||||||
onToggle={handleToggle}
|
|
||||||
withSwitch
|
|
||||||
>
|
|
||||||
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
||||||
|
<ParamVariationToggle />
|
||||||
<ParamVariationAmount />
|
<ParamVariationAmount />
|
||||||
<ParamVariationWeights />
|
<ParamVariationWeights />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -0,0 +1,27 @@
|
|||||||
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { setShouldGenerateVariations } from 'features/parameters/store/generationSlice';
|
||||||
|
import { ChangeEvent } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
export const ParamVariationToggle = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const shouldGenerateVariations = useAppSelector(
|
||||||
|
(state: RootState) => state.generation.shouldGenerateVariations
|
||||||
|
);
|
||||||
|
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleChange = (e: ChangeEvent<HTMLInputElement>) =>
|
||||||
|
dispatch(setShouldGenerateVariations(e.target.checked));
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAISwitch
|
||||||
|
label="Enable Variations"
|
||||||
|
isChecked={shouldGenerateVariations}
|
||||||
|
onChange={handleChange}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
@ -14,6 +14,7 @@ import {
|
|||||||
SeedParam,
|
SeedParam,
|
||||||
StepsParam,
|
StepsParam,
|
||||||
StrengthParam,
|
StrengthParam,
|
||||||
|
VAEParam,
|
||||||
WidthParam,
|
WidthParam,
|
||||||
} from './parameterZodSchemas';
|
} from './parameterZodSchemas';
|
||||||
|
|
||||||
@ -47,7 +48,7 @@ export interface GenerationState {
|
|||||||
horizontalSymmetrySteps: number;
|
horizontalSymmetrySteps: number;
|
||||||
verticalSymmetrySteps: number;
|
verticalSymmetrySteps: number;
|
||||||
model: ModelParam;
|
model: ModelParam;
|
||||||
shouldUseSeamless: boolean;
|
vae: VAEParam;
|
||||||
seamlessXAxis: boolean;
|
seamlessXAxis: boolean;
|
||||||
seamlessYAxis: boolean;
|
seamlessYAxis: boolean;
|
||||||
}
|
}
|
||||||
@ -81,9 +82,9 @@ export const initialGenerationState: GenerationState = {
|
|||||||
horizontalSymmetrySteps: 0,
|
horizontalSymmetrySteps: 0,
|
||||||
verticalSymmetrySteps: 0,
|
verticalSymmetrySteps: 0,
|
||||||
model: '',
|
model: '',
|
||||||
shouldUseSeamless: false,
|
vae: '',
|
||||||
seamlessXAxis: true,
|
seamlessXAxis: false,
|
||||||
seamlessYAxis: true,
|
seamlessYAxis: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
const initialState: GenerationState = initialGenerationState;
|
const initialState: GenerationState = initialGenerationState;
|
||||||
@ -141,9 +142,6 @@ export const generationSlice = createSlice({
|
|||||||
setImg2imgStrength: (state, action: PayloadAction<number>) => {
|
setImg2imgStrength: (state, action: PayloadAction<number>) => {
|
||||||
state.img2imgStrength = action.payload;
|
state.img2imgStrength = action.payload;
|
||||||
},
|
},
|
||||||
setSeamless: (state, action: PayloadAction<boolean>) => {
|
|
||||||
state.shouldUseSeamless = action.payload;
|
|
||||||
},
|
|
||||||
setSeamlessXAxis: (state, action: PayloadAction<boolean>) => {
|
setSeamlessXAxis: (state, action: PayloadAction<boolean>) => {
|
||||||
state.seamlessXAxis = action.payload;
|
state.seamlessXAxis = action.payload;
|
||||||
},
|
},
|
||||||
@ -216,6 +214,9 @@ export const generationSlice = createSlice({
|
|||||||
modelSelected: (state, action: PayloadAction<string>) => {
|
modelSelected: (state, action: PayloadAction<string>) => {
|
||||||
state.model = action.payload;
|
state.model = action.payload;
|
||||||
},
|
},
|
||||||
|
vaeSelected: (state, action: PayloadAction<string>) => {
|
||||||
|
state.vae = action.payload;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(configChanged, (state, action) => {
|
builder.addCase(configChanged, (state, action) => {
|
||||||
@ -260,8 +261,8 @@ export const {
|
|||||||
setVerticalSymmetrySteps,
|
setVerticalSymmetrySteps,
|
||||||
initialImageChanged,
|
initialImageChanged,
|
||||||
modelSelected,
|
modelSelected,
|
||||||
|
vaeSelected,
|
||||||
setShouldUseNoiseSettings,
|
setShouldUseNoiseSettings,
|
||||||
setSeamless,
|
|
||||||
setSeamlessXAxis,
|
setSeamlessXAxis,
|
||||||
setSeamlessYAxis,
|
setSeamlessYAxis,
|
||||||
} = generationSlice.actions;
|
} = generationSlice.actions;
|
||||||
|
@ -135,6 +135,15 @@ export const zModel = z.string();
|
|||||||
* Type alias for model parameter, inferred from its zod schema
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
*/
|
*/
|
||||||
export type ModelParam = z.infer<typeof zModel>;
|
export type ModelParam = z.infer<typeof zModel>;
|
||||||
|
/**
|
||||||
|
* Zod schema for VAE parameter
|
||||||
|
* TODO: Make this a dynamically generated enum?
|
||||||
|
*/
|
||||||
|
export const zVAE = z.string();
|
||||||
|
/**
|
||||||
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
|
*/
|
||||||
|
export type VAEParam = z.infer<typeof zVAE>;
|
||||||
/**
|
/**
|
||||||
* Validates/type-guards a value as a model parameter
|
* Validates/type-guards a value as a model parameter
|
||||||
*/
|
*/
|
||||||
|
@ -1,125 +0,0 @@
|
|||||||
import {
|
|
||||||
Button,
|
|
||||||
Flex,
|
|
||||||
Modal,
|
|
||||||
ModalBody,
|
|
||||||
ModalCloseButton,
|
|
||||||
ModalContent,
|
|
||||||
ModalFooter,
|
|
||||||
ModalHeader,
|
|
||||||
ModalOverlay,
|
|
||||||
Text,
|
|
||||||
useDisclosure,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
|
||||||
|
|
||||||
import { FaArrowLeft, FaPlus } from 'react-icons/fa';
|
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
import type { RootState } from 'app/store/store';
|
|
||||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
|
||||||
import AddCheckpointModel from './AddCheckpointModel';
|
|
||||||
import AddDiffusersModel from './AddDiffusersModel';
|
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
|
||||||
|
|
||||||
function AddModelBox({
|
|
||||||
text,
|
|
||||||
onClick,
|
|
||||||
}: {
|
|
||||||
text: string;
|
|
||||||
onClick?: () => void;
|
|
||||||
}) {
|
|
||||||
return (
|
|
||||||
<Flex
|
|
||||||
position="relative"
|
|
||||||
width="50%"
|
|
||||||
height={40}
|
|
||||||
justifyContent="center"
|
|
||||||
alignItems="center"
|
|
||||||
onClick={onClick}
|
|
||||||
as={Button}
|
|
||||||
>
|
|
||||||
<Text fontWeight="bold">{text}</Text>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
export default function AddModel() {
|
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
|
||||||
|
|
||||||
const addNewModelUIOption = useAppSelector(
|
|
||||||
(state: RootState) => state.ui.addNewModelUIOption
|
|
||||||
);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const addModelModalClose = () => {
|
|
||||||
onClose();
|
|
||||||
dispatch(setAddNewModelUIOption(null));
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<>
|
|
||||||
<IAIButton
|
|
||||||
aria-label={t('modelManager.addNewModel')}
|
|
||||||
tooltip={t('modelManager.addNewModel')}
|
|
||||||
onClick={onOpen}
|
|
||||||
size="sm"
|
|
||||||
>
|
|
||||||
<Flex columnGap={2} alignItems="center">
|
|
||||||
<FaPlus />
|
|
||||||
{t('modelManager.addNew')}
|
|
||||||
</Flex>
|
|
||||||
</IAIButton>
|
|
||||||
|
|
||||||
<Modal
|
|
||||||
isOpen={isOpen}
|
|
||||||
onClose={addModelModalClose}
|
|
||||||
size="3xl"
|
|
||||||
closeOnOverlayClick={false}
|
|
||||||
>
|
|
||||||
<ModalOverlay />
|
|
||||||
<ModalContent margin="auto">
|
|
||||||
<ModalHeader>{t('modelManager.addNewModel')} </ModalHeader>
|
|
||||||
{addNewModelUIOption !== null && (
|
|
||||||
<IAIIconButton
|
|
||||||
aria-label={t('common.back')}
|
|
||||||
tooltip={t('common.back')}
|
|
||||||
onClick={() => dispatch(setAddNewModelUIOption(null))}
|
|
||||||
position="absolute"
|
|
||||||
variant="ghost"
|
|
||||||
zIndex={1}
|
|
||||||
size="sm"
|
|
||||||
insetInlineEnd={12}
|
|
||||||
top={2}
|
|
||||||
icon={<FaArrowLeft />}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
<ModalCloseButton />
|
|
||||||
<ModalBody>
|
|
||||||
{addNewModelUIOption == null && (
|
|
||||||
<Flex columnGap={4}>
|
|
||||||
<AddModelBox
|
|
||||||
text={t('modelManager.addCheckpointModel')}
|
|
||||||
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
|
|
||||||
/>
|
|
||||||
<AddModelBox
|
|
||||||
text={t('modelManager.addDiffuserModel')}
|
|
||||||
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
|
|
||||||
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
|
|
||||||
</ModalBody>
|
|
||||||
<ModalFooter />
|
|
||||||
</ModalContent>
|
|
||||||
</Modal>
|
|
||||||
</>
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,339 +0,0 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
|
||||||
import IAIInput from 'common/components/IAIInput';
|
|
||||||
import IAINumberInput from 'common/components/IAINumberInput';
|
|
||||||
import { useEffect, useState } from 'react';
|
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
|
|
||||||
import {
|
|
||||||
Flex,
|
|
||||||
FormControl,
|
|
||||||
FormLabel,
|
|
||||||
HStack,
|
|
||||||
Text,
|
|
||||||
VStack,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
|
|
||||||
// import { addNewModel } from 'app/socketio/actions';
|
|
||||||
import { Field, Formik } from 'formik';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
import type { InvokeModelConfigProps } from 'app/types/invokeai';
|
|
||||||
import type { RootState } from 'app/store/store';
|
|
||||||
import type { FieldInputProps, FormikProps } from 'formik';
|
|
||||||
import { isEqual, pickBy } from 'lodash-es';
|
|
||||||
import ModelConvert from './ModelConvert';
|
|
||||||
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
|
|
||||||
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
|
|
||||||
import IAIForm from 'common/components/IAIForm';
|
|
||||||
|
|
||||||
const selector = createSelector(
|
|
||||||
[systemSelector],
|
|
||||||
(system) => {
|
|
||||||
const { openModel, model_list } = system;
|
|
||||||
return {
|
|
||||||
model_list,
|
|
||||||
openModel,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const MIN_MODEL_SIZE = 64;
|
|
||||||
const MAX_MODEL_SIZE = 2048;
|
|
||||||
|
|
||||||
export default function CheckpointModelEdit() {
|
|
||||||
const { openModel, model_list } = useAppSelector(selector);
|
|
||||||
const isProcessing = useAppSelector(
|
|
||||||
(state: RootState) => state.system.isProcessing
|
|
||||||
);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const [editModelFormValues, setEditModelFormValues] =
|
|
||||||
useState<InvokeModelConfigProps>({
|
|
||||||
name: '',
|
|
||||||
description: '',
|
|
||||||
config: 'configs/stable-diffusion/v1-inference.yaml',
|
|
||||||
weights: '',
|
|
||||||
vae: '',
|
|
||||||
width: 512,
|
|
||||||
height: 512,
|
|
||||||
default: false,
|
|
||||||
format: 'ckpt',
|
|
||||||
});
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (openModel) {
|
|
||||||
const retrievedModel = pickBy(model_list, (_val, key) => {
|
|
||||||
return isEqual(key, openModel);
|
|
||||||
});
|
|
||||||
setEditModelFormValues({
|
|
||||||
name: openModel,
|
|
||||||
description: retrievedModel[openModel]?.description,
|
|
||||||
config: retrievedModel[openModel]?.config,
|
|
||||||
weights: retrievedModel[openModel]?.weights,
|
|
||||||
vae: retrievedModel[openModel]?.vae,
|
|
||||||
width: retrievedModel[openModel]?.width,
|
|
||||||
height: retrievedModel[openModel]?.height,
|
|
||||||
default: retrievedModel[openModel]?.default,
|
|
||||||
format: 'ckpt',
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, [model_list, openModel]);
|
|
||||||
|
|
||||||
const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
|
|
||||||
dispatch(
|
|
||||||
addNewModel({
|
|
||||||
...values,
|
|
||||||
width: Number(values.width),
|
|
||||||
height: Number(values.height),
|
|
||||||
})
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
return openModel ? (
|
|
||||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
|
||||||
<Flex alignItems="center" gap={4} justifyContent="space-between">
|
|
||||||
<Text fontSize="lg" fontWeight="bold">
|
|
||||||
{openModel}
|
|
||||||
</Text>
|
|
||||||
<ModelConvert model={openModel} />
|
|
||||||
</Flex>
|
|
||||||
<Flex
|
|
||||||
flexDirection="column"
|
|
||||||
maxHeight={window.innerHeight - 270}
|
|
||||||
overflowY="scroll"
|
|
||||||
paddingInlineEnd={8}
|
|
||||||
>
|
|
||||||
<Formik
|
|
||||||
enableReinitialize={true}
|
|
||||||
initialValues={editModelFormValues}
|
|
||||||
onSubmit={editModelFormSubmitHandler}
|
|
||||||
>
|
|
||||||
{({ handleSubmit, errors, touched }) => (
|
|
||||||
<IAIForm onSubmit={handleSubmit}>
|
|
||||||
<VStack rowGap={2} alignItems="start">
|
|
||||||
{/* Description */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.description && touched.description}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="description" fontSize="sm">
|
|
||||||
{t('modelManager.description')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="description"
|
|
||||||
name="description"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.description && touched.description ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.description}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.descriptionValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Config */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.config && touched.config}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="config" fontSize="sm">
|
|
||||||
{t('modelManager.config')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="config"
|
|
||||||
name="config"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.config && touched.config ? (
|
|
||||||
<IAIFormErrorMessage>{errors.config}</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.configValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Weights */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.weights && touched.weights}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="config" fontSize="sm">
|
|
||||||
{t('modelManager.modelLocation')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="weights"
|
|
||||||
name="weights"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.weights && touched.weights ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.weights}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.modelLocationValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* VAE */}
|
|
||||||
<FormControl isInvalid={!!errors.vae && touched.vae}>
|
|
||||||
<FormLabel htmlFor="vae" fontSize="sm">
|
|
||||||
{t('modelManager.vaeLocation')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="vae"
|
|
||||||
name="vae"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.vae && touched.vae ? (
|
|
||||||
<IAIFormErrorMessage>{errors.vae}</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.vaeLocationValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
<HStack width="100%">
|
|
||||||
{/* Width */}
|
|
||||||
<FormControl isInvalid={!!errors.width && touched.width}>
|
|
||||||
<FormLabel htmlFor="width" fontSize="sm">
|
|
||||||
{t('modelManager.width')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field id="width" name="width">
|
|
||||||
{({
|
|
||||||
field,
|
|
||||||
form,
|
|
||||||
}: {
|
|
||||||
field: FieldInputProps<number>;
|
|
||||||
form: FormikProps<InvokeModelConfigProps>;
|
|
||||||
}) => (
|
|
||||||
<IAINumberInput
|
|
||||||
id="width"
|
|
||||||
name="width"
|
|
||||||
min={MIN_MODEL_SIZE}
|
|
||||||
max={MAX_MODEL_SIZE}
|
|
||||||
step={64}
|
|
||||||
value={form.values.width}
|
|
||||||
onChange={(value) =>
|
|
||||||
form.setFieldValue(field.name, Number(value))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Field>
|
|
||||||
|
|
||||||
{!!errors.width && touched.width ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.width}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.widthValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Height */}
|
|
||||||
<FormControl isInvalid={!!errors.height && touched.height}>
|
|
||||||
<FormLabel htmlFor="height" fontSize="sm">
|
|
||||||
{t('modelManager.height')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field id="height" name="height">
|
|
||||||
{({
|
|
||||||
field,
|
|
||||||
form,
|
|
||||||
}: {
|
|
||||||
field: FieldInputProps<number>;
|
|
||||||
form: FormikProps<InvokeModelConfigProps>;
|
|
||||||
}) => (
|
|
||||||
<IAINumberInput
|
|
||||||
id="height"
|
|
||||||
name="height"
|
|
||||||
min={MIN_MODEL_SIZE}
|
|
||||||
max={MAX_MODEL_SIZE}
|
|
||||||
step={64}
|
|
||||||
value={form.values.height}
|
|
||||||
onChange={(value) =>
|
|
||||||
form.setFieldValue(field.name, Number(value))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Field>
|
|
||||||
|
|
||||||
{!!errors.height && touched.height ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.height}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.heightValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
</HStack>
|
|
||||||
|
|
||||||
<IAIButton
|
|
||||||
type="submit"
|
|
||||||
className="modal-close-btn"
|
|
||||||
isLoading={isProcessing}
|
|
||||||
>
|
|
||||||
{t('modelManager.updateModel')}
|
|
||||||
</IAIButton>
|
|
||||||
</VStack>
|
|
||||||
</IAIForm>
|
|
||||||
)}
|
|
||||||
</Formik>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
) : (
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
width: '100%',
|
|
||||||
justifyContent: 'center',
|
|
||||||
alignItems: 'center',
|
|
||||||
borderRadius: 'base',
|
|
||||||
bg: 'base.900',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Text fontWeight={500}>Pick A Model To Edit</Text>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,281 +0,0 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
|
||||||
import IAIInput from 'common/components/IAIInput';
|
|
||||||
import { useEffect, useState } from 'react';
|
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
|
|
||||||
import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react';
|
|
||||||
|
|
||||||
// import { addNewModel } from 'app/socketio/actions';
|
|
||||||
import { Field, Formik } from 'formik';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
import type { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
|
|
||||||
import type { RootState } from 'app/store/store';
|
|
||||||
import { isEqual, pickBy } from 'lodash-es';
|
|
||||||
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
|
|
||||||
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
|
|
||||||
import IAIForm from 'common/components/IAIForm';
|
|
||||||
|
|
||||||
const selector = createSelector(
|
|
||||||
[systemSelector],
|
|
||||||
(system) => {
|
|
||||||
const { openModel, model_list } = system;
|
|
||||||
return {
|
|
||||||
model_list,
|
|
||||||
openModel,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
export default function DiffusersModelEdit() {
|
|
||||||
const { openModel, model_list } = useAppSelector(selector);
|
|
||||||
const isProcessing = useAppSelector(
|
|
||||||
(state: RootState) => state.system.isProcessing
|
|
||||||
);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const [editModelFormValues, setEditModelFormValues] =
|
|
||||||
useState<InvokeDiffusersModelConfigProps>({
|
|
||||||
name: '',
|
|
||||||
description: '',
|
|
||||||
repo_id: '',
|
|
||||||
path: '',
|
|
||||||
vae: { repo_id: '', path: '' },
|
|
||||||
default: false,
|
|
||||||
format: 'diffusers',
|
|
||||||
});
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (openModel) {
|
|
||||||
const retrievedModel = pickBy(model_list, (_val, key) => {
|
|
||||||
return isEqual(key, openModel);
|
|
||||||
});
|
|
||||||
|
|
||||||
setEditModelFormValues({
|
|
||||||
name: openModel,
|
|
||||||
description: retrievedModel[openModel]?.description,
|
|
||||||
path:
|
|
||||||
retrievedModel[openModel]?.path &&
|
|
||||||
retrievedModel[openModel]?.path !== 'None'
|
|
||||||
? retrievedModel[openModel]?.path
|
|
||||||
: '',
|
|
||||||
repo_id:
|
|
||||||
retrievedModel[openModel]?.repo_id &&
|
|
||||||
retrievedModel[openModel]?.repo_id !== 'None'
|
|
||||||
? retrievedModel[openModel]?.repo_id
|
|
||||||
: '',
|
|
||||||
vae: {
|
|
||||||
repo_id: retrievedModel[openModel]?.vae?.repo_id
|
|
||||||
? retrievedModel[openModel]?.vae?.repo_id
|
|
||||||
: '',
|
|
||||||
path: retrievedModel[openModel]?.vae?.path
|
|
||||||
? retrievedModel[openModel]?.vae?.path
|
|
||||||
: '',
|
|
||||||
},
|
|
||||||
default: retrievedModel[openModel]?.default,
|
|
||||||
format: 'diffusers',
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, [model_list, openModel]);
|
|
||||||
|
|
||||||
const editModelFormSubmitHandler = (
|
|
||||||
values: InvokeDiffusersModelConfigProps
|
|
||||||
) => {
|
|
||||||
const diffusersModelToEdit = values;
|
|
||||||
|
|
||||||
if (values.path === '') delete diffusersModelToEdit.path;
|
|
||||||
if (values.repo_id === '') delete diffusersModelToEdit.repo_id;
|
|
||||||
if (values.vae.path === '') delete diffusersModelToEdit.vae.path;
|
|
||||||
if (values.vae.repo_id === '') delete diffusersModelToEdit.vae.repo_id;
|
|
||||||
|
|
||||||
dispatch(addNewModel(values));
|
|
||||||
};
|
|
||||||
|
|
||||||
return openModel ? (
|
|
||||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
|
||||||
<Flex alignItems="center">
|
|
||||||
<Text fontSize="lg" fontWeight="bold">
|
|
||||||
{openModel}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
<Flex flexDirection="column" overflowY="scroll" paddingInlineEnd={8}>
|
|
||||||
<Formik
|
|
||||||
enableReinitialize={true}
|
|
||||||
initialValues={editModelFormValues}
|
|
||||||
onSubmit={editModelFormSubmitHandler}
|
|
||||||
>
|
|
||||||
{({ handleSubmit, errors, touched }) => (
|
|
||||||
<IAIForm onSubmit={handleSubmit}>
|
|
||||||
<VStack rowGap={2} alignItems="start">
|
|
||||||
{/* Description */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.description && touched.description}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="description" fontSize="sm">
|
|
||||||
{t('modelManager.description')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="description"
|
|
||||||
name="description"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.description && touched.description ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.description}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.descriptionValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Path */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.path && touched.path}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="path" fontSize="sm">
|
|
||||||
{t('modelManager.modelLocation')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="path"
|
|
||||||
name="path"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.path && touched.path ? (
|
|
||||||
<IAIFormErrorMessage>{errors.path}</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.modelLocationValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Repo ID */}
|
|
||||||
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
|
|
||||||
<FormLabel htmlFor="repo_id" fontSize="sm">
|
|
||||||
{t('modelManager.repo_id')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="repo_id"
|
|
||||||
name="repo_id"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.repo_id && touched.repo_id ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.repo_id}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.repoIDValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* VAE Path */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.vae?.path && touched.vae?.path}
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="vae.path" fontSize="sm">
|
|
||||||
{t('modelManager.vaeLocation')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="vae.path"
|
|
||||||
name="vae.path"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.vae?.path && touched.vae?.path ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.vae?.path}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.vaeLocationValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* VAE Repo ID */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
|
|
||||||
{t('modelManager.vaeRepoID')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="vae.repo_id"
|
|
||||||
name="vae.repo_id"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.vae?.repo_id}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.vaeRepoIDValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
<IAIButton
|
|
||||||
type="submit"
|
|
||||||
className="modal-close-btn"
|
|
||||||
isLoading={isProcessing}
|
|
||||||
>
|
|
||||||
{t('modelManager.updateModel')}
|
|
||||||
</IAIButton>
|
|
||||||
</VStack>
|
|
||||||
</IAIForm>
|
|
||||||
)}
|
|
||||||
</Formik>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
) : (
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
width: '100%',
|
|
||||||
justifyContent: 'center',
|
|
||||||
alignItems: 'center',
|
|
||||||
borderRadius: 'base',
|
|
||||||
bg: 'base.900',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user