mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
resolve conflicts; blackify
This commit is contained in:
commit
38c1436f02
20
.github/workflows/pyflakes.yml
vendored
20
.github/workflows/pyflakes.yml
vendored
@ -1,20 +0,0 @@
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- development
|
||||
- 'release-candidate-*'
|
||||
|
||||
jobs:
|
||||
pyflakes:
|
||||
name: runner / pyflakes
|
||||
if: github.event.pull_request.draft == false
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: pyflakes
|
||||
uses: reviewdog/action-pyflakes@v1
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
reporter: github-pr-review
|
7
.github/workflows/style-checks.yml
vendored
7
.github/workflows/style-checks.yml
vendored
@ -18,8 +18,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies with pip
|
||||
run: |
|
||||
pip install black flake8 Flake8-pyproject isort
|
||||
pip install ruff
|
||||
|
||||
- run: isort --check-only .
|
||||
- run: black --check .
|
||||
- run: flake8
|
||||
- run: ruff check --output-format=github .
|
||||
- run: ruff format --check .
|
||||
|
@ -137,7 +137,7 @@ def dest_path(dest=None) -> Path:
|
||||
path_completer = PathCompleter(
|
||||
only_directories=True,
|
||||
expanduser=True,
|
||||
get_paths=lambda: [browse_start],
|
||||
get_paths=lambda: [browse_start], # noqa: B023
|
||||
# get_paths=lambda: [".."].extend(list(browse_start.iterdir()))
|
||||
)
|
||||
|
||||
@ -149,7 +149,7 @@ def dest_path(dest=None) -> Path:
|
||||
completer=path_completer,
|
||||
default=str(browse_start) + os.sep,
|
||||
vi_mode=True,
|
||||
complete_while_typing=True
|
||||
complete_while_typing=True,
|
||||
# Test that this is not needed on Windows
|
||||
# complete_style=CompleteStyle.READLINE_LIKE,
|
||||
)
|
||||
|
@ -28,7 +28,7 @@ class FastAPIEventService(EventServiceBase):
|
||||
self.__queue.put(None)
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
self.__queue.put(dict(event_name=event_name, payload=payload))
|
||||
self.__queue.put({"event_name": event_name, "payload": payload})
|
||||
|
||||
async def __dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
|
@ -8,12 +8,20 @@ from typing import List, Optional
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_records import DuplicateModelException, InvalidModelException, UnknownModelException
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
)
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
@ -28,26 +36,32 @@ class ModelsList(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
ModelsListValidator = TypeAdapter(ModelsList)
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/",
|
||||
operation_id="list_model_records",
|
||||
)
|
||||
async def list_model_records(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
base_models: Optional[List[BaseModelType]] = Query(
|
||||
default=None, description="Base models to include"
|
||||
),
|
||||
model_type: Optional[ModelType] = Query(
|
||||
default=None, description="The type of model to get"
|
||||
),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
models = list()
|
||||
found_models: list[AnyModelConfig] = []
|
||||
if base_models:
|
||||
for base_model in base_models:
|
||||
models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(
|
||||
base_model=base_model, model_type=model_type
|
||||
)
|
||||
)
|
||||
else:
|
||||
models.extend(record_store.search_by_attr(model_type=model_type))
|
||||
return ModelsList(models=models)
|
||||
found_models.extend(record_store.search_by_attr(model_type=model_type))
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/i/{key}",
|
||||
@ -83,13 +97,16 @@ async def get_model_record(
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
|
||||
info: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type")
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
try:
|
||||
model_response = record_store.update_model(key, config=info)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
@ -101,7 +118,10 @@ async def update_model_record(
|
||||
@model_records_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="del_model_record",
|
||||
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
|
||||
responses={
|
||||
204: {"description": "Model deleted successfully"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=204,
|
||||
)
|
||||
async def del_model_record(
|
||||
@ -125,13 +145,17 @@ async def del_model_record(
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
201: {"description": "The model added successfully"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
409: {
|
||||
"description": "There is already a model corresponding to this path or repo_id"
|
||||
},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def add_model_record(
|
||||
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
|
||||
config: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type")
|
||||
]
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model using the configuration information appropriate for its type.
|
||||
|
@ -54,7 +54,7 @@ async def list_models(
|
||||
) -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
if base_models and len(base_models) > 0:
|
||||
models_raw = list()
|
||||
models_raw = []
|
||||
for base_model in base_models:
|
||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||
else:
|
||||
|
@ -132,7 +132,7 @@ def custom_openapi() -> dict[str, Any]:
|
||||
# Add all outputs
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
output_types = set()
|
||||
output_type_titles = dict()
|
||||
output_type_titles = {}
|
||||
for invoker in all_invocations:
|
||||
output_type = signature(invoker.invoke).return_annotation
|
||||
output_types.add(output_type)
|
||||
@ -173,12 +173,12 @@ def custom_openapi() -> dict[str, Any]:
|
||||
# print(f"Config with name {name} already defined")
|
||||
continue
|
||||
|
||||
openapi_schema["components"]["schemas"][name] = dict(
|
||||
title=name,
|
||||
description="An enumeration.",
|
||||
type="string",
|
||||
enum=list(v.value for v in model_config_format_enum),
|
||||
)
|
||||
openapi_schema["components"]["schemas"][name] = {
|
||||
"title": name,
|
||||
"description": "An enumeration.",
|
||||
"type": "string",
|
||||
"enum": [v.value for v in model_config_format_enum],
|
||||
}
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
@ -25,4 +25,4 @@ spec.loader.exec_module(module)
|
||||
|
||||
# add core nodes to __all__
|
||||
python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
|
||||
__all__ = list(f.stem for f in python_files) # type: ignore
|
||||
__all__ = [f.stem for f in python_files] # type: ignore
|
||||
|
@ -236,35 +236,35 @@ def InputField(
|
||||
Ignored for non-collection fields.
|
||||
"""
|
||||
|
||||
json_schema_extra_: dict[str, Any] = dict(
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
item_default=item_default,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
_field_kind="input",
|
||||
)
|
||||
json_schema_extra_: dict[str, Any] = {
|
||||
"input": input,
|
||||
"ui_type": ui_type,
|
||||
"ui_component": ui_component,
|
||||
"ui_hidden": ui_hidden,
|
||||
"ui_order": ui_order,
|
||||
"item_default": item_default,
|
||||
"ui_choice_labels": ui_choice_labels,
|
||||
"_field_kind": "input",
|
||||
}
|
||||
|
||||
field_args = dict(
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
title=title,
|
||||
description=description,
|
||||
pattern=pattern,
|
||||
strict=strict,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
)
|
||||
field_args = {
|
||||
"default": default,
|
||||
"default_factory": default_factory,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"pattern": pattern,
|
||||
"strict": strict,
|
||||
"gt": gt,
|
||||
"ge": ge,
|
||||
"lt": lt,
|
||||
"le": le,
|
||||
"multiple_of": multiple_of,
|
||||
"allow_inf_nan": allow_inf_nan,
|
||||
"max_digits": max_digits,
|
||||
"decimal_places": decimal_places,
|
||||
"min_length": min_length,
|
||||
"max_length": max_length,
|
||||
}
|
||||
|
||||
"""
|
||||
Invocation definitions have their fields typed correctly for their `invoke()` functions.
|
||||
@ -299,24 +299,24 @@ def InputField(
|
||||
|
||||
# because we are manually making fields optional, we need to store the original required bool for reference later
|
||||
if default is PydanticUndefined and default_factory is PydanticUndefined:
|
||||
json_schema_extra_.update(dict(orig_required=True))
|
||||
json_schema_extra_.update({"orig_required": True})
|
||||
else:
|
||||
json_schema_extra_.update(dict(orig_required=False))
|
||||
json_schema_extra_.update({"orig_required": False})
|
||||
|
||||
# make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
||||
if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
|
||||
default_ = None if default is PydanticUndefined else default
|
||||
provided_args.update(dict(default=default_))
|
||||
provided_args.update({"default": default_})
|
||||
if default is not PydanticUndefined:
|
||||
# before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
|
||||
json_schema_extra_.update(dict(default=default))
|
||||
json_schema_extra_.update(dict(orig_default=default))
|
||||
json_schema_extra_.update({"default": default})
|
||||
json_schema_extra_.update({"orig_default": default})
|
||||
elif default is not PydanticUndefined and default_factory is PydanticUndefined:
|
||||
default_ = default
|
||||
provided_args.update(dict(default=default_))
|
||||
json_schema_extra_.update(dict(orig_default=default_))
|
||||
provided_args.update({"default": default_})
|
||||
json_schema_extra_.update({"orig_default": default_})
|
||||
elif default_factory is not PydanticUndefined:
|
||||
provided_args.update(dict(default_factory=default_factory))
|
||||
provided_args.update({"default_factory": default_factory})
|
||||
# TODO: cannot serialize default_factory...
|
||||
# json_schema_extra_.update(dict(orig_default_factory=default_factory))
|
||||
|
||||
@ -383,12 +383,12 @@ def OutputField(
|
||||
decimal_places=decimal_places,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
json_schema_extra=dict(
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
_field_kind="output",
|
||||
),
|
||||
json_schema_extra={
|
||||
"ui_type": ui_type,
|
||||
"ui_hidden": ui_hidden,
|
||||
"ui_order": ui_order,
|
||||
"_field_kind": "output",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -460,14 +460,14 @@ class BaseInvocationOutput(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
return map(lambda i: get_type(i), BaseInvocationOutput.get_outputs())
|
||||
return (get_type(i) for i in BaseInvocationOutput.get_outputs())
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
# Because we use a pydantic Literal field with default value for the invocation type,
|
||||
# it will be typed as optional in the OpenAPI schema. Make it required manually.
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"] = []
|
||||
schema["required"].extend(["type"])
|
||||
|
||||
model_config = ConfigDict(
|
||||
@ -527,16 +527,11 @@ class BaseInvocation(ABC, BaseModel):
|
||||
@classmethod
|
||||
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
|
||||
# Get the type strings out of the literals and into a dictionary
|
||||
return dict(
|
||||
map(
|
||||
lambda i: (get_type(i), i),
|
||||
BaseInvocation.get_invocations(),
|
||||
)
|
||||
)
|
||||
return {get_type(i): i for i in BaseInvocation.get_invocations()}
|
||||
|
||||
@classmethod
|
||||
def get_invocation_types(cls) -> Iterable[str]:
|
||||
return map(lambda i: get_type(i), BaseInvocation.get_invocations())
|
||||
return (get_type(i) for i in BaseInvocation.get_invocations())
|
||||
|
||||
@classmethod
|
||||
def get_output_type(cls) -> BaseInvocationOutput:
|
||||
@ -555,7 +550,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
if uiconfig and hasattr(uiconfig, "version"):
|
||||
schema["version"] = uiconfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"] = []
|
||||
schema["required"].extend(["type", "id"])
|
||||
|
||||
@abstractmethod
|
||||
@ -609,15 +604,15 @@ class BaseInvocation(ABC, BaseModel):
|
||||
id: str = Field(
|
||||
default_factory=uuid_string,
|
||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
||||
json_schema_extra=dict(_field_kind="internal"),
|
||||
json_schema_extra={"_field_kind": "internal"},
|
||||
)
|
||||
is_intermediate: bool = Field(
|
||||
default=False,
|
||||
description="Whether or not this is an intermediate invocation.",
|
||||
json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"),
|
||||
json_schema_extra={"ui_type": UIType.IsIntermediate, "_field_kind": "internal"},
|
||||
)
|
||||
use_cache: bool = Field(
|
||||
default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal")
|
||||
default=True, description="Whether or not to use the cache", json_schema_extra={"_field_kind": "internal"}
|
||||
)
|
||||
|
||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||
@ -651,7 +646,7 @@ class _Model(BaseModel):
|
||||
|
||||
|
||||
# Get all pydantic model attrs, methods, etc
|
||||
RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model())))
|
||||
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
|
||||
|
||||
|
||||
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
||||
@ -666,9 +661,7 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
|
||||
|
||||
field_kind = (
|
||||
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
|
||||
field.json_schema_extra.get("_field_kind", None)
|
||||
if field.json_schema_extra
|
||||
else None
|
||||
field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None
|
||||
)
|
||||
|
||||
# must have a field_kind
|
||||
@ -729,7 +722,7 @@ def invocation(
|
||||
# Add OpenAPI schema extras
|
||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
||||
cls.UIConfig = type(uiconf_name, (UIConfigBase,), {})
|
||||
if title is not None:
|
||||
cls.UIConfig.title = title
|
||||
if tags is not None:
|
||||
@ -756,7 +749,7 @@ def invocation(
|
||||
|
||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||
invocation_type_field = Field(
|
||||
title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal")
|
||||
title="type", default=invocation_type, json_schema_extra={"_field_kind": "internal"}
|
||||
)
|
||||
|
||||
docstring = cls.__doc__
|
||||
@ -802,7 +795,7 @@ def invocation_output(
|
||||
# Add the output type to the model.
|
||||
|
||||
output_type_annotation = Literal[output_type] # type: ignore
|
||||
output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal"))
|
||||
output_type_field = Field(title="type", default=output_type, json_schema_extra={"_field_kind": "internal"})
|
||||
|
||||
docstring = cls.__doc__
|
||||
cls = create_model(
|
||||
@ -834,7 +827,7 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField)
|
||||
|
||||
class WithWorkflow(BaseModel):
|
||||
workflow: Optional[WorkflowField] = Field(
|
||||
default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal")
|
||||
default=None, description=FieldDescriptions.workflow, json_schema_extra={"_field_kind": "internal"}
|
||||
)
|
||||
|
||||
|
||||
@ -852,5 +845,5 @@ MetadataFieldValidator = TypeAdapter(MetadataField)
|
||||
|
||||
class WithMetadata(BaseModel):
|
||||
metadata: Optional[MetadataField] = Field(
|
||||
default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal")
|
||||
default=None, description=FieldDescriptions.metadata, json_schema_extra={"_field_kind": "internal"}
|
||||
)
|
||||
|
@ -131,7 +131,7 @@ def prepare_faces_list(
|
||||
deduped_faces: list[FaceResultData] = []
|
||||
|
||||
if len(face_result_list) == 0:
|
||||
return list()
|
||||
return []
|
||||
|
||||
for candidate in face_result_list:
|
||||
should_add = True
|
||||
@ -210,7 +210,7 @@ def generate_face_box_mask(
|
||||
# Check if any face is detected.
|
||||
if results.multi_face_landmarks: # type: ignore # this are via protobuf and not typed
|
||||
# Search for the face_id in the detected faces.
|
||||
for face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
|
||||
for _face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
|
||||
# Get the bounding box of the face mesh.
|
||||
x_coordinates = [landmark.x for landmark in face_landmarks.landmark]
|
||||
y_coordinates = [landmark.y for landmark in face_landmarks.landmark]
|
||||
|
@ -77,7 +77,7 @@ if choose_torch_device() == torch.device("mps"):
|
||||
|
||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
|
||||
|
||||
|
||||
@invocation_output("scheduler_output")
|
||||
@ -1105,7 +1105,7 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
latents_b = context.services.latents.get(self.latents_b.latents_name)
|
||||
|
||||
if latents_a.shape != latents_b.shape:
|
||||
raise "Latents to blend must be the same size."
|
||||
raise Exception("Latents to blend must be the same size.")
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
|
@ -145,17 +145,17 @@ INTEGER_OPERATIONS = Literal[
|
||||
]
|
||||
|
||||
|
||||
INTEGER_OPERATIONS_LABELS = dict(
|
||||
ADD="Add A+B",
|
||||
SUB="Subtract A-B",
|
||||
MUL="Multiply A*B",
|
||||
DIV="Divide A/B",
|
||||
EXP="Exponentiate A^B",
|
||||
MOD="Modulus A%B",
|
||||
ABS="Absolute Value of A",
|
||||
MIN="Minimum(A,B)",
|
||||
MAX="Maximum(A,B)",
|
||||
)
|
||||
INTEGER_OPERATIONS_LABELS = {
|
||||
"ADD": "Add A+B",
|
||||
"SUB": "Subtract A-B",
|
||||
"MUL": "Multiply A*B",
|
||||
"DIV": "Divide A/B",
|
||||
"EXP": "Exponentiate A^B",
|
||||
"MOD": "Modulus A%B",
|
||||
"ABS": "Absolute Value of A",
|
||||
"MIN": "Minimum(A,B)",
|
||||
"MAX": "Maximum(A,B)",
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -231,17 +231,17 @@ FLOAT_OPERATIONS = Literal[
|
||||
]
|
||||
|
||||
|
||||
FLOAT_OPERATIONS_LABELS = dict(
|
||||
ADD="Add A+B",
|
||||
SUB="Subtract A-B",
|
||||
MUL="Multiply A*B",
|
||||
DIV="Divide A/B",
|
||||
EXP="Exponentiate A^B",
|
||||
ABS="Absolute Value of A",
|
||||
SQRT="Square Root of A",
|
||||
MIN="Minimum(A,B)",
|
||||
MAX="Maximum(A,B)",
|
||||
)
|
||||
FLOAT_OPERATIONS_LABELS = {
|
||||
"ADD": "Add A+B",
|
||||
"SUB": "Subtract A-B",
|
||||
"MUL": "Multiply A*B",
|
||||
"DIV": "Divide A/B",
|
||||
"EXP": "Exponentiate A^B",
|
||||
"ABS": "Absolute Value of A",
|
||||
"SQRT": "Square Root of A",
|
||||
"MIN": "Minimum(A,B)",
|
||||
"MAX": "Maximum(A,B)",
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -266,7 +266,7 @@ class FloatMathInvocation(BaseInvocation):
|
||||
raise ValueError("Cannot divide by zero")
|
||||
elif info.data["operation"] == "EXP" and info.data["a"] == 0 and v < 0:
|
||||
raise ValueError("Cannot raise zero to a negative power")
|
||||
elif info.data["operation"] == "EXP" and type(info.data["a"] ** v) is complex:
|
||||
elif info.data["operation"] == "EXP" and isinstance(info.data["a"] ** v, complex):
|
||||
raise ValueError("Root operation resulted in a complex number")
|
||||
return v
|
||||
|
||||
|
@ -54,7 +54,7 @@ ORT_TO_NP_TYPE = {
|
||||
"tensor(double)": np.float64,
|
||||
}
|
||||
|
||||
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
||||
PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())]
|
||||
|
||||
|
||||
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
|
||||
@ -252,7 +252,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
scheduler.set_timesteps(self.steps)
|
||||
latents = latents * np.float64(scheduler.init_noise_sigma)
|
||||
|
||||
extra_step_kwargs = dict()
|
||||
extra_step_kwargs = {}
|
||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
eta=0.0,
|
||||
|
@ -100,7 +100,7 @@ EASING_FUNCTIONS_MAP = {
|
||||
"BounceInOut": BounceEaseInOut,
|
||||
}
|
||||
|
||||
EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
||||
EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
@ -161,7 +161,7 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
easing_class = EASING_FUNCTIONS_MAP[self.easing]
|
||||
if log_diagnostics:
|
||||
context.services.logger.debug("easing class: " + str(easing_class))
|
||||
easing_list = list()
|
||||
easing_list = []
|
||||
if self.mirror: # "expected" mirroring
|
||||
# if number of steps is even, squeeze duration down to (number_of_steps)/2
|
||||
# and create reverse copy of list to append
|
||||
@ -178,7 +178,7 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
end=self.end_value,
|
||||
duration=base_easing_duration - 1,
|
||||
)
|
||||
base_easing_vals = list()
|
||||
base_easing_vals = []
|
||||
for step_index in range(base_easing_duration):
|
||||
easing_val = easing_function.ease(step_index)
|
||||
base_easing_vals.append(easing_val)
|
||||
|
@ -139,7 +139,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@ -167,7 +167,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
image_names = list(map(lambda r: r[0], result))
|
||||
image_names = [r[0] for r in result]
|
||||
return image_names
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
|
@ -199,7 +199,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
# Get the total number of boards
|
||||
self._cursor.execute(
|
||||
@ -236,7 +236,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
return boards
|
||||
|
||||
|
@ -55,7 +55,7 @@ class InvokeAISettings(BaseSettings):
|
||||
"""
|
||||
cls = self.__class__
|
||||
type = get_args(get_type_hints(cls)["type"])[0]
|
||||
field_dict = dict({type: dict()})
|
||||
field_dict = {type: {}}
|
||||
for name, field in self.model_fields.items():
|
||||
if name in cls._excluded_from_yaml():
|
||||
continue
|
||||
@ -64,7 +64,7 @@ class InvokeAISettings(BaseSettings):
|
||||
)
|
||||
value = getattr(self, name)
|
||||
if category not in field_dict[type]:
|
||||
field_dict[type][category] = dict()
|
||||
field_dict[type][category] = {}
|
||||
# keep paths as strings to make it easier to read
|
||||
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
||||
conf = OmegaConf.create(field_dict)
|
||||
@ -89,7 +89,7 @@ class InvokeAISettings(BaseSettings):
|
||||
# create an upcase version of the environment in
|
||||
# order to achieve case-insensitive environment
|
||||
# variables (the way Windows does)
|
||||
upcase_environ = dict()
|
||||
upcase_environ = {}
|
||||
for key, value in os.environ.items():
|
||||
upcase_environ[key.upper()] = value
|
||||
|
||||
|
@ -188,18 +188,18 @@ DEFAULT_MAX_VRAM = 0.5
|
||||
|
||||
|
||||
class Categories(object):
|
||||
WebServer = dict(category="Web Server")
|
||||
Features = dict(category="Features")
|
||||
Paths = dict(category="Paths")
|
||||
Logging = dict(category="Logging")
|
||||
Development = dict(category="Development")
|
||||
Other = dict(category="Other")
|
||||
ModelCache = dict(category="Model Cache")
|
||||
Device = dict(category="Device")
|
||||
Generation = dict(category="Generation")
|
||||
Queue = dict(category="Queue")
|
||||
Nodes = dict(category="Nodes")
|
||||
MemoryPerformance = dict(category="Memory/Performance")
|
||||
WebServer = {"category": "Web Server"}
|
||||
Features = {"category": "Features"}
|
||||
Paths = {"category": "Paths"}
|
||||
Logging = {"category": "Logging"}
|
||||
Development = {"category": "Development"}
|
||||
Other = {"category": "Other"}
|
||||
ModelCache = {"category": "Model Cache"}
|
||||
Device = {"category": "Device"}
|
||||
Generation = {"category": "Generation"}
|
||||
Queue = {"category": "Queue"}
|
||||
Nodes = {"category": "Nodes"}
|
||||
MemoryPerformance = {"category": "Memory/Performance"}
|
||||
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
@ -482,7 +482,7 @@ def _find_root() -> Path:
|
||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ["INVOKEAI_ROOT"])
|
||||
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
|
||||
elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]):
|
||||
root = (venv.parent).resolve()
|
||||
else:
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
|
@ -27,7 +27,7 @@ class EventServiceBase:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.queue_event,
|
||||
payload=dict(event=event_name, data=payload),
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
# Define events here for every event in the system.
|
||||
@ -48,18 +48,18 @@ class EventServiceBase:
|
||||
"""Emitted when there is generation progress"""
|
||||
self.__emit_queue_event(
|
||||
event_name="generator_progress",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node_id=node.get("id"),
|
||||
source_node_id=source_node_id,
|
||||
progress_image=progress_image.model_dump() if progress_image is not None else None,
|
||||
step=step,
|
||||
order=order,
|
||||
total_steps=total_steps,
|
||||
),
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node_id": node.get("id"),
|
||||
"source_node_id": source_node_id,
|
||||
"progress_image": progress_image.model_dump() if progress_image is not None else None,
|
||||
"step": step,
|
||||
"order": order,
|
||||
"total_steps": total_steps,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_invocation_complete(
|
||||
@ -75,15 +75,15 @@ class EventServiceBase:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_complete",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
result=result,
|
||||
),
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
"result": result,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_invocation_error(
|
||||
@ -100,16 +100,16 @@ class EventServiceBase:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_error",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
error_type=error_type,
|
||||
error=error,
|
||||
),
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_invocation_started(
|
||||
@ -124,14 +124,14 @@ class EventServiceBase:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_started",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
),
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_graph_execution_complete(
|
||||
@ -140,12 +140,12 @@ class EventServiceBase:
|
||||
"""Emitted when a session has completed all invocations"""
|
||||
self.__emit_queue_event(
|
||||
event_name="graph_execution_state_complete",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
),
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_load_started(
|
||||
@ -162,16 +162,16 @@ class EventServiceBase:
|
||||
"""Emitted when a model is requested"""
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_started",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
),
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_name": model_name,
|
||||
"base_model": base_model,
|
||||
"model_type": model_type,
|
||||
"submodel": submodel,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_load_completed(
|
||||
@ -189,19 +189,19 @@ class EventServiceBase:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_completed",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
hash=model_info.hash,
|
||||
location=str(model_info.location),
|
||||
precision=str(model_info.precision),
|
||||
),
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_name": model_name,
|
||||
"base_model": base_model,
|
||||
"model_type": model_type,
|
||||
"submodel": submodel,
|
||||
"hash": model_info.hash,
|
||||
"location": str(model_info.location),
|
||||
"precision": str(model_info.precision),
|
||||
},
|
||||
)
|
||||
|
||||
def emit_session_retrieval_error(
|
||||
@ -216,14 +216,14 @@ class EventServiceBase:
|
||||
"""Emitted when session retrieval fails"""
|
||||
self.__emit_queue_event(
|
||||
event_name="session_retrieval_error",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
error_type=error_type,
|
||||
error=error,
|
||||
),
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_invocation_retrieval_error(
|
||||
@ -239,15 +239,15 @@ class EventServiceBase:
|
||||
"""Emitted when invocation retrieval fails"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_retrieval_error",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node_id=node_id,
|
||||
error_type=error_type,
|
||||
error=error,
|
||||
),
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node_id": node_id,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_session_canceled(
|
||||
@ -260,12 +260,12 @@ class EventServiceBase:
|
||||
"""Emitted when a session is canceled"""
|
||||
self.__emit_queue_event(
|
||||
event_name="session_canceled",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
),
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_queue_item_status_changed(
|
||||
@ -277,39 +277,39 @@ class EventServiceBase:
|
||||
"""Emitted when a queue item's status changes"""
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_item_status_changed",
|
||||
payload=dict(
|
||||
queue_id=queue_status.queue_id,
|
||||
queue_item=dict(
|
||||
queue_id=session_queue_item.queue_id,
|
||||
item_id=session_queue_item.item_id,
|
||||
status=session_queue_item.status,
|
||||
batch_id=session_queue_item.batch_id,
|
||||
session_id=session_queue_item.session_id,
|
||||
error=session_queue_item.error,
|
||||
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
||||
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
||||
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||
),
|
||||
batch_status=batch_status.model_dump(),
|
||||
queue_status=queue_status.model_dump(),
|
||||
),
|
||||
payload={
|
||||
"queue_id": queue_status.queue_id,
|
||||
"queue_item": {
|
||||
"queue_id": session_queue_item.queue_id,
|
||||
"item_id": session_queue_item.item_id,
|
||||
"status": session_queue_item.status,
|
||||
"batch_id": session_queue_item.batch_id,
|
||||
"session_id": session_queue_item.session_id,
|
||||
"error": session_queue_item.error,
|
||||
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
||||
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
||||
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||
},
|
||||
"batch_status": batch_status.model_dump(),
|
||||
"queue_status": queue_status.model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
|
||||
"""Emitted when a batch is enqueued"""
|
||||
self.__emit_queue_event(
|
||||
event_name="batch_enqueued",
|
||||
payload=dict(
|
||||
queue_id=enqueue_result.queue_id,
|
||||
batch_id=enqueue_result.batch.batch_id,
|
||||
enqueued=enqueue_result.enqueued,
|
||||
),
|
||||
payload={
|
||||
"queue_id": enqueue_result.queue_id,
|
||||
"batch_id": enqueue_result.batch.batch_id,
|
||||
"enqueued": enqueue_result.enqueued,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_queue_cleared(self, queue_id: str) -> None:
|
||||
"""Emitted when the queue is cleared"""
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_cleared",
|
||||
payload=dict(queue_id=queue_id),
|
||||
payload={"queue_id": queue_id},
|
||||
)
|
||||
|
@ -25,7 +25,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
__invoker: Invoker
|
||||
|
||||
def __init__(self, output_folder: Union[str, Path]):
|
||||
self.__cache = dict()
|
||||
self.__cache = {}
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
|
||||
|
@ -90,25 +90,23 @@ class ImageRecordDeleteException(Exception):
|
||||
|
||||
|
||||
IMAGE_DTO_COLS = ", ".join(
|
||||
list(
|
||||
map(
|
||||
lambda c: "images." + c,
|
||||
[
|
||||
"image_name",
|
||||
"image_origin",
|
||||
"image_category",
|
||||
"width",
|
||||
"height",
|
||||
"session_id",
|
||||
"node_id",
|
||||
"is_intermediate",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"deleted_at",
|
||||
"starred",
|
||||
],
|
||||
)
|
||||
)
|
||||
[
|
||||
"images." + c
|
||||
for c in [
|
||||
"image_name",
|
||||
"image_origin",
|
||||
"image_category",
|
||||
"width",
|
||||
"height",
|
||||
"session_id",
|
||||
"node_id",
|
||||
"is_intermediate",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"deleted_at",
|
||||
"starred",
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
@ -263,7 +263,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = list(map(lambda c: c.value, set(categories)))
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
|
||||
@ -307,7 +307,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
# Build the list of images, deserializing each row
|
||||
self._cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
@ -386,7 +386,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
image_names = list(map(lambda r: r[0], result))
|
||||
image_names = [r[0] for r in result]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
|
@ -21,8 +21,8 @@ class ImageServiceABC(ABC):
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = list()
|
||||
self._on_deleted_callbacks = list()
|
||||
self._on_changed_callbacks = []
|
||||
self._on_deleted_callbacks = []
|
||||
|
||||
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
|
||||
"""Register a callback for when an image is changed"""
|
||||
|
@ -217,18 +217,16 @@ class ImageService(ImageServiceABC):
|
||||
board_id,
|
||||
)
|
||||
|
||||
image_dtos = list(
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
image_record=r,
|
||||
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
|
||||
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
|
||||
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
||||
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
|
||||
),
|
||||
results.items,
|
||||
image_dtos = [
|
||||
image_record_to_dto(
|
||||
image_record=r,
|
||||
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
|
||||
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
|
||||
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
||||
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
|
||||
)
|
||||
)
|
||||
for r in results.items
|
||||
]
|
||||
|
||||
return OffsetPaginatedResults[ImageDTO](
|
||||
items=image_dtos,
|
||||
|
@ -1,5 +1,5 @@
|
||||
from abc import ABC
|
||||
|
||||
|
||||
class InvocationProcessorABC(ABC):
|
||||
class InvocationProcessorABC(ABC): # noqa: B024
|
||||
pass
|
||||
|
@ -26,7 +26,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
self.__invoker_thread = Thread(
|
||||
name="invoker_processor",
|
||||
target=self.__process,
|
||||
kwargs=dict(stop_event=self.__stop_event),
|
||||
kwargs={"stop_event": self.__stop_event},
|
||||
)
|
||||
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
|
||||
self.__invoker_thread.start()
|
||||
|
@ -14,7 +14,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
|
||||
|
||||
def __init__(self):
|
||||
self.__queue = Queue()
|
||||
self.__cancellations = dict()
|
||||
self.__cancellations = {}
|
||||
|
||||
def get(self) -> InvocationQueueItem:
|
||||
item = self.__queue.get()
|
||||
|
@ -122,7 +122,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
def log_stats(self):
|
||||
completed = set()
|
||||
errored = set()
|
||||
for graph_id, node_log in self._stats.items():
|
||||
for graph_id, _node_log in self._stats.items():
|
||||
try:
|
||||
current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
|
||||
except Exception:
|
||||
@ -142,7 +142,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
cache_stats = self._cache_stats[graph_id]
|
||||
hwm = cache_stats.high_watermark / GIG
|
||||
tot = cache_stats.cache_size / GIG
|
||||
loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
|
||||
loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GIG
|
||||
|
||||
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
|
||||
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
|
||||
|
@ -15,8 +15,8 @@ class ItemStorageABC(ABC, Generic[T]):
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = list()
|
||||
self._on_deleted_callbacks = list()
|
||||
self._on_changed_callbacks = []
|
||||
self._on_deleted_callbacks = []
|
||||
|
||||
"""Base item storage class"""
|
||||
|
||||
|
@ -112,7 +112,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
items = [self._parse_item(r[0]) for r in result]
|
||||
|
||||
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
||||
count = self._cursor.fetchone()[0]
|
||||
@ -132,7 +132,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
items = [self._parse_item(r[0]) for r in result]
|
||||
|
||||
self._cursor.execute(
|
||||
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
|
||||
|
@ -13,8 +13,8 @@ class LatentsStorageBase(ABC):
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = list()
|
||||
self._on_deleted_callbacks = list()
|
||||
self._on_changed_callbacks = []
|
||||
self._on_deleted_callbacks = []
|
||||
|
||||
@abstractmethod
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
|
@ -19,7 +19,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self.__underlying_storage = underlying_storage
|
||||
self.__cache = dict()
|
||||
self.__cache = {}
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = max_cache_size
|
||||
|
||||
|
@ -33,9 +33,11 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self.__thread = Thread(
|
||||
name="session_processor",
|
||||
target=self.__process,
|
||||
kwargs=dict(
|
||||
stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event
|
||||
),
|
||||
kwargs={
|
||||
"stop_event": self.__stop_event,
|
||||
"poll_now_event": self.__poll_now_event,
|
||||
"resume_event": self.__resume_event,
|
||||
},
|
||||
)
|
||||
self.__thread.start()
|
||||
|
||||
|
@ -129,12 +129,12 @@ class Batch(BaseModel):
|
||||
return v
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra=dict(
|
||||
required=[
|
||||
json_schema_extra={
|
||||
"required": [
|
||||
"graph",
|
||||
"runs",
|
||||
]
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@ -191,8 +191,8 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
return SessionQueueItemDTO(**queue_item_dict)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra=dict(
|
||||
required=[
|
||||
json_schema_extra={
|
||||
"required": [
|
||||
"item_id",
|
||||
"status",
|
||||
"batch_id",
|
||||
@ -203,7 +203,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@ -222,8 +222,8 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
|
||||
return SessionQueueItem(**queue_item_dict)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra=dict(
|
||||
required=[
|
||||
json_schema_extra={
|
||||
"required": [
|
||||
"item_id",
|
||||
"status",
|
||||
"batch_id",
|
||||
@ -235,7 +235,7 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@ -355,7 +355,7 @@ def create_session_nfv_tuples(
|
||||
for item in batch_datum.items
|
||||
]
|
||||
node_field_values_to_zip.append(node_field_values)
|
||||
data.append(list(zip(*node_field_values_to_zip))) # type: ignore [arg-type]
|
||||
data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type]
|
||||
|
||||
# create generator to yield session,nfv tuples
|
||||
count = 0
|
||||
@ -383,7 +383,7 @@ def calc_session_count(batch: Batch) -> int:
|
||||
for batch_datum in batch_datum_list:
|
||||
batch_data_items = range(len(batch_datum.items))
|
||||
to_zip.append(batch_data_items)
|
||||
data.append(list(zip(*to_zip)))
|
||||
data.append(list(zip(*to_zip, strict=True)))
|
||||
data_product = list(product(*data))
|
||||
return len(data_product) * batch.runs
|
||||
|
||||
|
@ -78,7 +78,7 @@ def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[Li
|
||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||
|
||||
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
||||
graphs: list[LibraryGraph] = list()
|
||||
graphs: list[LibraryGraph] = []
|
||||
|
||||
text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
|
||||
|
@ -352,7 +352,7 @@ class Graph(BaseModel):
|
||||
|
||||
# Validate that all node ids are unique
|
||||
node_ids = [n.id for n in self.nodes.values()]
|
||||
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
|
||||
duplicate_node_ids = {node_id for node_id in node_ids if node_ids.count(node_id) >= 2}
|
||||
if duplicate_node_ids:
|
||||
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
|
||||
|
||||
@ -616,7 +616,7 @@ class Graph(BaseModel):
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
||||
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = list()
|
||||
edges = []
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
|
||||
@ -658,7 +658,7 @@ class Graph(BaseModel):
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
||||
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = list()
|
||||
edges = []
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
|
||||
@ -680,8 +680,8 @@ class Graph(BaseModel):
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
|
||||
outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
|
||||
inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_path, "item")]
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@ -694,7 +694,7 @@ class Graph(BaseModel):
|
||||
|
||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
|
||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
||||
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||
|
||||
# Input type must be a list
|
||||
if get_origin(input_field) != list:
|
||||
@ -713,8 +713,8 @@ class Graph(BaseModel):
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
|
||||
outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
|
||||
inputs = [e.source for e in self._get_input_edges(node_path, "item")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@ -722,18 +722,16 @@ class Graph(BaseModel):
|
||||
outputs.append(new_output)
|
||||
|
||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||
input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
|
||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
||||
input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
|
||||
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||
|
||||
# Validate that all inputs are derived from or match a single type
|
||||
input_field_types = set(
|
||||
[
|
||||
t
|
||||
for input_field in input_fields
|
||||
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
|
||||
if t != NoneType
|
||||
]
|
||||
) # Get unique types
|
||||
input_field_types = {
|
||||
t
|
||||
for input_field in input_fields
|
||||
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
|
||||
if t != NoneType
|
||||
} # Get unique types
|
||||
type_tree = nx.DiGraph()
|
||||
type_tree.add_nodes_from(input_field_types)
|
||||
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
|
||||
@ -761,15 +759,15 @@ class Graph(BaseModel):
|
||||
"""Returns a NetworkX DiGraph representing the layout of this graph"""
|
||||
# TODO: Cache this?
|
||||
g = nx.DiGraph()
|
||||
g.add_nodes_from([n for n in self.nodes.keys()])
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
g.add_nodes_from(list(self.nodes.keys()))
|
||||
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
|
||||
return g
|
||||
|
||||
def nx_graph_with_data(self) -> nx.DiGraph:
|
||||
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
|
||||
g = nx.DiGraph()
|
||||
g.add_nodes_from([n for n in self.nodes.items()])
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
g.add_nodes_from(list(self.nodes.items()))
|
||||
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
|
||||
return g
|
||||
|
||||
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
|
||||
@ -791,7 +789,7 @@ class Graph(BaseModel):
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
|
||||
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
|
||||
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
|
||||
return g
|
||||
|
||||
@ -843,8 +841,8 @@ class GraphExecutionState(BaseModel):
|
||||
return v
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra=dict(
|
||||
required=[
|
||||
json_schema_extra={
|
||||
"required": [
|
||||
"id",
|
||||
"graph",
|
||||
"execution_graph",
|
||||
@ -855,7 +853,7 @@ class GraphExecutionState(BaseModel):
|
||||
"prepared_source_mapping",
|
||||
"source_prepared_mapping",
|
||||
]
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
def next(self) -> Optional[BaseInvocation]:
|
||||
@ -895,7 +893,7 @@ class GraphExecutionState(BaseModel):
|
||||
source_node = self.prepared_source_mapping[node_id]
|
||||
prepared_nodes = self.source_prepared_mapping[source_node]
|
||||
|
||||
if all([n in self.executed for n in prepared_nodes]):
|
||||
if all(n in self.executed for n in prepared_nodes):
|
||||
self.executed.add(source_node)
|
||||
self.executed_history.append(source_node)
|
||||
|
||||
@ -930,7 +928,7 @@ class GraphExecutionState(BaseModel):
|
||||
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
|
||||
self_iteration_count = len(input_collection)
|
||||
|
||||
new_nodes: list[str] = list()
|
||||
new_nodes: list[str] = []
|
||||
if self_iteration_count == 0:
|
||||
# TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
|
||||
return new_nodes
|
||||
@ -940,7 +938,7 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Create new edges for this iteration
|
||||
# For collect nodes, this may contain multiple inputs to the same field
|
||||
new_edges: list[Edge] = list()
|
||||
new_edges: list[Edge] = []
|
||||
for edge in input_edges:
|
||||
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
|
||||
new_edge = Edge(
|
||||
@ -1034,7 +1032,7 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Create execution nodes
|
||||
next_node = self.graph.get_node(next_node_id)
|
||||
new_node_ids = list()
|
||||
new_node_ids = []
|
||||
if isinstance(next_node, CollectInvocation):
|
||||
# Collapse all iterator input mappings and create a single execution node for the collect invocation
|
||||
all_iteration_mappings = list(
|
||||
@ -1055,7 +1053,10 @@ class GraphExecutionState(BaseModel):
|
||||
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
|
||||
# TODO: Handle a node mapping to none
|
||||
eg = self.execution_graph.nx_graph_flat()
|
||||
prepared_parent_mappings = [[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] # type: ignore
|
||||
prepared_parent_mappings = [
|
||||
[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
|
||||
for it in iterator_node_prepared_combinations
|
||||
] # type: ignore
|
||||
|
||||
# Create execution node for each iteration
|
||||
for iteration_mappings in prepared_parent_mappings:
|
||||
@ -1121,7 +1122,7 @@ class GraphExecutionState(BaseModel):
|
||||
for edge in input_edges
|
||||
if edge.destination.field == "item"
|
||||
]
|
||||
setattr(node, "collection", output_collection)
|
||||
node.collection = output_collection
|
||||
else:
|
||||
for edge in input_edges:
|
||||
output_value = getattr(self.results[edge.source.node_id], edge.source.field)
|
||||
@ -1201,7 +1202,7 @@ class LibraryGraph(BaseModel):
|
||||
|
||||
@field_validator("exposed_inputs", "exposed_outputs")
|
||||
def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
|
||||
if len(v) != len(set(i.alias for i in v)):
|
||||
if len(v) != len({i.alias for i in v}):
|
||||
raise ValueError("Duplicate exposed alias")
|
||||
return v
|
||||
|
||||
|
@ -59,7 +59,7 @@ def thin_one_time(x, kernels):
|
||||
|
||||
def lvmin_thin(x, prunings=True):
|
||||
y = x
|
||||
for i in range(32):
|
||||
for _i in range(32):
|
||||
y, is_done = thin_one_time(y, lvmin_kernels)
|
||||
if is_done:
|
||||
break
|
||||
|
@ -21,11 +21,11 @@ def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
||||
|
||||
# sanity check make sure the graph is at least reasonably shaped
|
||||
if (
|
||||
type(graph) is not dict
|
||||
not isinstance(graph, dict)
|
||||
or "nodes" not in graph
|
||||
or type(graph["nodes"]) is not dict
|
||||
or not isinstance(graph["nodes"], dict)
|
||||
or "edges" not in graph
|
||||
or type(graph["edges"]) is not list
|
||||
or not isinstance(graph["edges"], list)
|
||||
):
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
@ -88,7 +88,7 @@ class PromptFormatter:
|
||||
t2i = self.t2i
|
||||
opt = self.opt
|
||||
|
||||
switches = list()
|
||||
switches = []
|
||||
switches.append(f'"{opt.prompt}"')
|
||||
switches.append(f"-s{opt.steps or t2i.steps}")
|
||||
switches.append(f"-W{opt.width or t2i.width}")
|
||||
|
@ -88,7 +88,7 @@ class Txt2Mask(object):
|
||||
provided image and returns a SegmentedGrayscale object in which the brighter
|
||||
pixels indicate where the object is inferred to be.
|
||||
"""
|
||||
if type(image) is str:
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image).convert("RGB")
|
||||
|
||||
image = ImageOps.exif_transpose(image)
|
||||
|
@ -40,7 +40,7 @@ class InitImageResizer:
|
||||
(rw, rh) = (int(scale * im.width), int(scale * im.height))
|
||||
|
||||
# round everything to multiples of 64
|
||||
width, height, rw, rh = map(lambda x: x - x % 64, (width, height, rw, rh))
|
||||
width, height, rw, rh = (x - x % 64 for x in (width, height, rw, rh))
|
||||
|
||||
# no resize necessary, but return a copy
|
||||
if im.width == width and im.height == height:
|
||||
|
@ -197,7 +197,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
||||
|
||||
def download_conversion_models():
|
||||
target_dir = config.models_path / "core/convert"
|
||||
kwargs = dict() # for future use
|
||||
kwargs = {} # for future use
|
||||
try:
|
||||
logger.info("Downloading core tokenizers and text encoders")
|
||||
|
||||
@ -252,26 +252,26 @@ def download_conversion_models():
|
||||
def download_realesrgan():
|
||||
logger.info("Installing ESRGAN Upscaling models...")
|
||||
URLs = [
|
||||
dict(
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||
description="RealESRGAN_x4plus.pth",
|
||||
),
|
||||
dict(
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||
description="RealESRGAN_x4plus_anime_6B.pth",
|
||||
),
|
||||
dict(
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
description="ESRGAN_SRx4_DF2KOST_official.pth",
|
||||
),
|
||||
dict(
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||
description="RealESRGAN_x2plus.pth",
|
||||
),
|
||||
{
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||
"description": "RealESRGAN_x4plus.pth",
|
||||
},
|
||||
{
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||
"description": "RealESRGAN_x4plus_anime_6B.pth",
|
||||
},
|
||||
{
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
"dest": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
"description": "ESRGAN_SRx4_DF2KOST_official.pth",
|
||||
},
|
||||
{
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
"dest": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||
"description": "RealESRGAN_x2plus.pth",
|
||||
},
|
||||
]
|
||||
for model in URLs:
|
||||
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
|
||||
@ -680,7 +680,7 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
||||
if program_opts.default_only
|
||||
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
|
||||
if program_opts.yes_to_all
|
||||
else list(),
|
||||
else [],
|
||||
)
|
||||
|
||||
|
||||
|
@ -123,8 +123,6 @@ class MigrateTo3(object):
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
for f in files:
|
||||
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
|
||||
# let them be copied as part of a tree copy operation
|
||||
@ -143,8 +141,6 @@ class MigrateTo3(object):
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
def migrate_support_models(self):
|
||||
"""
|
||||
@ -182,10 +178,10 @@ class MigrateTo3(object):
|
||||
"""
|
||||
|
||||
dest_directory = self.dest_models
|
||||
kwargs = dict(
|
||||
cache_dir=self.root_directory / "models/hub",
|
||||
kwargs = {
|
||||
"cache_dir": self.root_directory / "models/hub",
|
||||
# local_files_only = True
|
||||
)
|
||||
}
|
||||
try:
|
||||
logger.info("Migrating core tokenizers and text encoders")
|
||||
target_dir = dest_directory / "core" / "convert"
|
||||
@ -316,11 +312,11 @@ class MigrateTo3(object):
|
||||
dest_dir = self.dest_models
|
||||
|
||||
cache = self.root_directory / "models/hub"
|
||||
kwargs = dict(
|
||||
cache_dir=cache,
|
||||
safety_checker=None,
|
||||
kwargs = {
|
||||
"cache_dir": cache,
|
||||
"safety_checker": None,
|
||||
# local_files_only = True,
|
||||
)
|
||||
}
|
||||
|
||||
owner, repo_name = repo_id.split("/")
|
||||
model_name = model_name or repo_name
|
||||
|
@ -120,7 +120,7 @@ class ModelInstall(object):
|
||||
be treated uniformly. It also sorts the models alphabetically
|
||||
by their name, to improve the display somewhat.
|
||||
"""
|
||||
model_dict = dict()
|
||||
model_dict = {}
|
||||
|
||||
# first populate with the entries in INITIAL_MODELS.yaml
|
||||
for key, value in self.datasets.items():
|
||||
@ -134,7 +134,7 @@ class ModelInstall(object):
|
||||
model_dict[key] = model_info
|
||||
|
||||
# supplement with entries in models.yaml
|
||||
installed_models = [x for x in self.mgr.list_models()]
|
||||
installed_models = list(self.mgr.list_models())
|
||||
|
||||
for md in installed_models:
|
||||
base = md["base_model"]
|
||||
@ -176,7 +176,7 @@ class ModelInstall(object):
|
||||
# logic here a little reversed to maintain backward compatibility
|
||||
def starter_models(self, all_models: bool = False) -> Set[str]:
|
||||
models = set()
|
||||
for key, value in self.datasets.items():
|
||||
for key, _value in self.datasets.items():
|
||||
name, base, model_type = ModelManager.parse_key(key)
|
||||
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
|
||||
models.add(key)
|
||||
@ -184,7 +184,7 @@ class ModelInstall(object):
|
||||
|
||||
def recommended_models(self) -> Set[str]:
|
||||
starters = self.starter_models(all_models=True)
|
||||
return set([x for x in starters if self.datasets[x].get("recommended", False)])
|
||||
return {x for x in starters if self.datasets[x].get("recommended", False)}
|
||||
|
||||
def default_model(self) -> str:
|
||||
starters = self.starter_models()
|
||||
@ -234,7 +234,7 @@ class ModelInstall(object):
|
||||
"""
|
||||
|
||||
if not models_installed:
|
||||
models_installed = dict()
|
||||
models_installed = {}
|
||||
|
||||
model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
|
||||
|
||||
@ -252,16 +252,14 @@ class ModelInstall(object):
|
||||
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any(
|
||||
[
|
||||
(path / x).exists()
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"pytorch_lora_weights.safetensors",
|
||||
}
|
||||
]
|
||||
(path / x).exists()
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"pytorch_lora_weights.safetensors",
|
||||
}
|
||||
):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
|
||||
|
||||
@ -433,17 +431,17 @@ class ModelInstall(object):
|
||||
|
||||
rel_path = self.relative_to_root(path, self.config.models_path)
|
||||
|
||||
attributes = dict(
|
||||
path=str(rel_path),
|
||||
description=str(description),
|
||||
model_format=info.format,
|
||||
)
|
||||
attributes = {
|
||||
"path": str(rel_path),
|
||||
"description": str(description),
|
||||
"model_format": info.format,
|
||||
}
|
||||
legacy_conf = None
|
||||
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
|
||||
attributes.update(
|
||||
dict(
|
||||
variant=info.variant_type,
|
||||
)
|
||||
{
|
||||
"variant": info.variant_type,
|
||||
}
|
||||
)
|
||||
if info.format == "checkpoint":
|
||||
try:
|
||||
@ -474,7 +472,7 @@ class ModelInstall(object):
|
||||
)
|
||||
|
||||
if legacy_conf:
|
||||
attributes.update(dict(config=str(legacy_conf)))
|
||||
attributes.update({"config": str(legacy_conf)})
|
||||
return attributes
|
||||
|
||||
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
|
||||
@ -519,7 +517,7 @@ class ModelInstall(object):
|
||||
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
|
||||
_, name = repo_id.split("/")
|
||||
location = staging / name
|
||||
paths = list()
|
||||
paths = []
|
||||
for filename in files:
|
||||
filePath = Path(filename)
|
||||
p = hf_download_with_resume(
|
||||
|
@ -130,7 +130,9 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
assert ip_adapter_image_prompt_embeds is not None
|
||||
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
|
||||
|
||||
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
|
||||
for ipa_embed, ipa_weights, scale in zip(
|
||||
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
|
||||
):
|
||||
# The batch dimensions should match.
|
||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The token_len dimensions should match.
|
||||
|
@ -56,7 +56,7 @@ class PerceiverAttention(nn.Module):
|
||||
x = self.norm1(x)
|
||||
latents = self.norm2(latents)
|
||||
|
||||
b, l, _ = latents.shape
|
||||
b, L, _ = latents.shape
|
||||
|
||||
q = self.to_q(latents)
|
||||
kv_input = torch.cat((x, latents), dim=-2)
|
||||
@ -72,7 +72,7 @@ class PerceiverAttention(nn.Module):
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
out = weight @ v
|
||||
|
||||
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
||||
out = out.permute(0, 2, 1, 3).reshape(b, L, -1)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
|
@ -269,7 +269,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
resolution *= 2
|
||||
|
||||
up_block_types = []
|
||||
for i in range(len(block_out_channels)):
|
||||
for _i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
||||
up_block_types.append(block_type)
|
||||
resolution //= 2
|
||||
@ -1223,7 +1223,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint_path)
|
||||
if scan_result.infected_files != 0:
|
||||
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
|
||||
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
@ -1664,7 +1664,7 @@ def download_controlnet_from_original_ckpt(
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint_path)
|
||||
if scan_result.infected_files != 0:
|
||||
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
|
||||
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
|
@ -104,7 +104,7 @@ class ModelPatcher:
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
prefix: str,
|
||||
):
|
||||
original_weights = dict()
|
||||
original_weights = {}
|
||||
try:
|
||||
with torch.no_grad():
|
||||
for lora, lora_weight in loras:
|
||||
@ -242,7 +242,7 @@ class ModelPatcher:
|
||||
):
|
||||
skipped_layers = []
|
||||
try:
|
||||
for i in range(clip_skip):
|
||||
for _i in range(clip_skip):
|
||||
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
|
||||
|
||||
yield
|
||||
@ -324,7 +324,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
tokenizer: CLIPTokenizer
|
||||
|
||||
def __init__(self, tokenizer: CLIPTokenizer):
|
||||
self.pad_tokens = dict()
|
||||
self.pad_tokens = {}
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
||||
@ -385,10 +385,10 @@ class ONNXModelPatcher:
|
||||
if not isinstance(model, IAIOnnxRuntimeModel):
|
||||
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
||||
|
||||
orig_weights = dict()
|
||||
orig_weights = {}
|
||||
|
||||
try:
|
||||
blended_loras = dict()
|
||||
blended_loras = {}
|
||||
|
||||
for lora, lora_weight in loras:
|
||||
for layer_key, layer in lora.layers.items():
|
||||
@ -404,7 +404,7 @@ class ONNXModelPatcher:
|
||||
else:
|
||||
blended_loras[layer_key] = layer_weight
|
||||
|
||||
node_names = dict()
|
||||
node_names = {}
|
||||
for node in model.nodes.values():
|
||||
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
|
||||
|
||||
|
@ -66,11 +66,13 @@ class CacheStats(object):
|
||||
|
||||
class ModelLocker(object):
|
||||
"Forward declaration"
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelCache(object):
|
||||
"Forward declaration"
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -132,7 +134,7 @@ class ModelCache(object):
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
"""
|
||||
self.model_infos: Dict[str, ModelBase] = dict()
|
||||
self.model_infos: Dict[str, ModelBase] = {}
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self.precision: torch.dtype = precision
|
||||
@ -147,8 +149,8 @@ class ModelCache(object):
|
||||
# used for stats collection
|
||||
self.stats = None
|
||||
|
||||
self._cached_models = dict()
|
||||
self._cache_stack = list()
|
||||
self._cached_models = {}
|
||||
self._cache_stack = []
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
|
@ -26,5 +26,5 @@ def skip_torch_weight_init():
|
||||
|
||||
yield None
|
||||
finally:
|
||||
for torch_module, saved_function in zip(torch_modules, saved_functions):
|
||||
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
|
||||
torch_module.reset_parameters = saved_function
|
||||
|
@ -363,7 +363,7 @@ class ModelManager(object):
|
||||
else:
|
||||
return
|
||||
|
||||
self.models = dict()
|
||||
self.models = {}
|
||||
for model_key, model_config in config.items():
|
||||
if model_key.startswith("_"):
|
||||
continue
|
||||
@ -374,7 +374,7 @@ class ModelManager(object):
|
||||
self.models[model_key] = model_class.create_config(**model_config)
|
||||
|
||||
# check config version number and update on disk/RAM if necessary
|
||||
self.cache_keys = dict()
|
||||
self.cache_keys = {}
|
||||
|
||||
# add controlnet, lora and textual_inversion models from disk
|
||||
self.scan_models_directory()
|
||||
@ -655,7 +655,7 @@ class ModelManager(object):
|
||||
"""
|
||||
# TODO: redo
|
||||
for model_dict in self.list_models():
|
||||
for model_name, model_info in model_dict.items():
|
||||
for _model_name, model_info in model_dict.items():
|
||||
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
|
||||
print(line)
|
||||
|
||||
@ -902,7 +902,7 @@ class ModelManager(object):
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
"""
|
||||
data_to_save = dict()
|
||||
data_to_save = {}
|
||||
data_to_save["__metadata__"] = self.config_meta.model_dump()
|
||||
|
||||
for model_key, model_config in self.models.items():
|
||||
@ -1034,7 +1034,7 @@ class ModelManager(object):
|
||||
self.ignore = ignore
|
||||
|
||||
def on_search_started(self):
|
||||
self.new_models_found = dict()
|
||||
self.new_models_found = {}
|
||||
|
||||
def on_model_found(self, model: Path):
|
||||
if model not in self.ignore:
|
||||
@ -1106,7 +1106,7 @@ class ModelManager(object):
|
||||
# avoid circular import here
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
|
||||
successfully_installed = dict()
|
||||
successfully_installed = {}
|
||||
|
||||
installer = ModelInstall(
|
||||
config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self
|
||||
|
@ -92,7 +92,7 @@ class ModelMerger(object):
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
model_paths = list()
|
||||
model_paths = []
|
||||
config = self.manager.app_config
|
||||
base_model = BaseModelType(base_model)
|
||||
vae = None
|
||||
@ -124,13 +124,13 @@ class ModelMerger(object):
|
||||
dump_path = (dump_path / merged_model_name).as_posix()
|
||||
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
attributes = dict(
|
||||
path=dump_path,
|
||||
description=f"Merge of models {', '.join(model_names)}",
|
||||
model_format="diffusers",
|
||||
variant=ModelVariantType.Normal.value,
|
||||
vae=vae,
|
||||
)
|
||||
attributes = {
|
||||
"path": dump_path,
|
||||
"description": f"Merge of models {', '.join(model_names)}",
|
||||
"model_format": "diffusers",
|
||||
"variant": ModelVariantType.Normal.value,
|
||||
"vae": vae,
|
||||
}
|
||||
return self.manager.add_model(
|
||||
merged_model_name,
|
||||
base_model=base_model,
|
||||
|
@ -237,7 +237,7 @@ class ModelProbe(object):
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
raise "The model {model_name} is potentially infected by malware. Aborting import."
|
||||
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||
|
||||
|
||||
# ##################################################3
|
||||
|
@ -59,7 +59,7 @@ class ModelSearch(ABC):
|
||||
for root, dirs, files in os.walk(path, followlinks=True):
|
||||
if str(Path(root).name).startswith("."):
|
||||
self._pruned_paths.add(root)
|
||||
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
||||
if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
|
||||
continue
|
||||
|
||||
self._items_scanned += len(dirs) + len(files)
|
||||
@ -69,16 +69,14 @@ class ModelSearch(ABC):
|
||||
self._scanned_dirs.add(path)
|
||||
continue
|
||||
if any(
|
||||
[
|
||||
(path / x).exists()
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
}
|
||||
]
|
||||
(path / x).exists()
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
}
|
||||
):
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
|
@ -97,8 +97,8 @@ MODEL_CLASSES = {
|
||||
# },
|
||||
}
|
||||
|
||||
MODEL_CONFIGS = list()
|
||||
OPENAPI_MODEL_CONFIGS = list()
|
||||
MODEL_CONFIGS = []
|
||||
OPENAPI_MODEL_CONFIGS = []
|
||||
|
||||
|
||||
class OpenAPIModelInfoBase(BaseModel):
|
||||
@ -109,7 +109,7 @@ class OpenAPIModelInfoBase(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
for base_model, models in MODEL_CLASSES.items():
|
||||
for _base_model, models in MODEL_CLASSES.items():
|
||||
for model_type, model_class in models.items():
|
||||
model_configs = set(model_class._get_configs().values())
|
||||
model_configs.discard(None)
|
||||
@ -133,7 +133,7 @@ for base_model, models in MODEL_CLASSES.items():
|
||||
|
||||
|
||||
def get_model_config_enums():
|
||||
enums = list()
|
||||
enums = []
|
||||
|
||||
for model_config in MODEL_CONFIGS:
|
||||
if hasattr(inspect, "get_annotations"):
|
||||
|
@ -153,7 +153,7 @@ class ModelBase(metaclass=ABCMeta):
|
||||
|
||||
else:
|
||||
res_type = sys.modules["diffusers"]
|
||||
res_type = getattr(res_type, "pipelines")
|
||||
res_type = res_type.pipelines
|
||||
|
||||
for subtype in subtypes:
|
||||
res_type = getattr(res_type, subtype)
|
||||
@ -164,7 +164,7 @@ class ModelBase(metaclass=ABCMeta):
|
||||
with suppress(Exception):
|
||||
return cls.__configs
|
||||
|
||||
configs = dict()
|
||||
configs = {}
|
||||
for name in dir(cls):
|
||||
if name.startswith("__"):
|
||||
continue
|
||||
@ -246,8 +246,8 @@ class DiffusersModel(ModelBase):
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
self.child_types: Dict[str, Type] = dict()
|
||||
self.child_sizes: Dict[str, int] = dict()
|
||||
self.child_types: Dict[str, Type] = {}
|
||||
self.child_sizes: Dict[str, int] = {}
|
||||
|
||||
try:
|
||||
config_data = DiffusionPipeline.load_config(self.model_path)
|
||||
@ -326,8 +326,8 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
|
||||
all_files = os.listdir(model_path)
|
||||
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
|
||||
|
||||
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
|
||||
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
|
||||
fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
|
||||
bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
|
||||
other_files = set(all_files) - fp16_files - bit8_files
|
||||
|
||||
if variant is None:
|
||||
@ -413,7 +413,7 @@ def _calc_onnx_model_by_data(model) -> int:
|
||||
|
||||
|
||||
def _fast_safetensors_reader(path: str):
|
||||
checkpoint = dict()
|
||||
checkpoint = {}
|
||||
device = torch.device("meta")
|
||||
with open(path, "rb") as f:
|
||||
definition_len = int.from_bytes(f.read(8), "little")
|
||||
@ -483,7 +483,7 @@ class IAIOnnxRuntimeModel:
|
||||
class _tensor_access:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self.indexes = dict()
|
||||
self.indexes = {}
|
||||
for idx, obj in enumerate(self.model.proto.graph.initializer):
|
||||
self.indexes[obj.name] = idx
|
||||
|
||||
@ -524,7 +524,7 @@ class IAIOnnxRuntimeModel:
|
||||
|
||||
class _access_helper:
|
||||
def __init__(self, raw_proto):
|
||||
self.indexes = dict()
|
||||
self.indexes = {}
|
||||
self.raw_proto = raw_proto
|
||||
for idx, obj in enumerate(raw_proto):
|
||||
self.indexes[obj.name] = idx
|
||||
@ -549,7 +549,7 @@ class IAIOnnxRuntimeModel:
|
||||
return self.indexes.keys()
|
||||
|
||||
def values(self):
|
||||
return [obj for obj in self.raw_proto]
|
||||
return list(self.raw_proto)
|
||||
|
||||
def __init__(self, model_path: str, provider: Optional[str]):
|
||||
self.path = model_path
|
||||
|
@ -104,7 +104,7 @@ class ControlNetModel(ModelBase):
|
||||
return ControlNetModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]):
|
||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]):
|
||||
return ControlNetModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
@ -73,7 +73,7 @@ class LoRAModel(ModelBase):
|
||||
return LoRAModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||
return LoRAModelFormat.LyCORIS
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
@ -462,7 +462,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
# TODO: try revert if exception?
|
||||
for key, layer in self.layers.items():
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
@ -499,7 +499,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict = dict()
|
||||
new_state_dict = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
@ -545,7 +545,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
|
||||
model = cls(
|
||||
name=file_path.stem, # TODO:
|
||||
layers=dict(),
|
||||
layers={},
|
||||
)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
@ -593,12 +593,12 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: dict):
|
||||
state_dict_groupped = dict()
|
||||
state_dict_groupped = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
if stem not in state_dict_groupped:
|
||||
state_dict_groupped[stem] = dict()
|
||||
state_dict_groupped[stem] = {}
|
||||
state_dict_groupped[stem][leaf] = value
|
||||
|
||||
return state_dict_groupped
|
||||
|
@ -110,7 +110,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
return StableDiffusion1ModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(model_path):
|
||||
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||
return StableDiffusion1ModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||
@ -221,7 +221,7 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
return StableDiffusion2ModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(model_path):
|
||||
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||
return StableDiffusion2ModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||
|
@ -71,7 +71,7 @@ class TextualInversionModel(ModelBase):
|
||||
return None # diffusers-ti
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
|
||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]):
|
||||
return None
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
@ -89,7 +89,7 @@ class VaeModel(ModelBase):
|
||||
return VaeModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||
return VaeModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
@ -114,8 +114,10 @@ class ModelConfigBase(BaseModel):
|
||||
current_hash: Optional[str] = Field(
|
||||
description="current fasthash of model contents", default=None
|
||||
) # if model is converted or otherwise modified, this will hold updated hash
|
||||
description: Optional[str] = Field(None)
|
||||
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
|
||||
description: Optional[str] = Field(default=None)
|
||||
source: Optional[str] = Field(
|
||||
description="Model download source (URL or repo_id)", default=None
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
use_enum_values=False,
|
||||
@ -249,12 +251,19 @@ class T2IConfig(ModelConfigBase):
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
|
||||
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
|
||||
_ControlNetConfig = Annotated[
|
||||
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator="format")
|
||||
_ONNXConfig = Annotated[
|
||||
Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")
|
||||
]
|
||||
_ControlNetConfig = Annotated[
|
||||
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
|
||||
Field(discriminator="format"),
|
||||
]
|
||||
_VaeConfig = Annotated[
|
||||
Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")
|
||||
]
|
||||
_MainModelConfig = Annotated[
|
||||
Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")
|
||||
]
|
||||
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
|
||||
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
|
||||
|
||||
AnyModelConfig = Union[
|
||||
_MainModelConfig,
|
||||
|
@ -49,7 +49,7 @@ class FastModelHash(object):
|
||||
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
|
||||
components: Dict[str, str] = {}
|
||||
|
||||
for root, dirs, files in os.walk(model_location):
|
||||
for root, _dirs, files in os.walk(model_location):
|
||||
for file in files:
|
||||
# only tally tensor files because diffusers config files change slightly
|
||||
# depending on how the model was downloaded/converted.
|
||||
@ -61,6 +61,6 @@ class FastModelHash(object):
|
||||
|
||||
# hash all the model hashes together, using alphabetic file order
|
||||
md5 = hashlib.md5()
|
||||
for path, fast_hash in sorted(components.items()):
|
||||
for _path, fast_hash in sorted(components.items()):
|
||||
md5.update(fast_hash.encode("utf-8"))
|
||||
return md5.hexdigest()
|
||||
|
@ -7,9 +7,16 @@ from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
ModelRecordServiceSQL,
|
||||
)
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
@ -193,6 +193,7 @@ class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
||||
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
|
||||
after generation completes. Optional.
|
||||
"""
|
||||
|
||||
attention_map_saver: Optional[AttentionMapSaver]
|
||||
|
||||
|
||||
|
@ -54,13 +54,13 @@ class Context:
|
||||
self.clear_requests(cleanup=True)
|
||||
|
||||
def register_cross_attention_modules(self, model):
|
||||
for name, module in get_cross_attention_modules(model, CrossAttentionType.SELF):
|
||||
for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
|
||||
if name in self.self_cross_attention_module_identifiers:
|
||||
assert False, f"name {name} cannot appear more than once"
|
||||
raise AssertionError(f"name {name} cannot appear more than once")
|
||||
self.self_cross_attention_module_identifiers.append(name)
|
||||
for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
|
||||
for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
|
||||
if name in self.tokens_cross_attention_module_identifiers:
|
||||
assert False, f"name {name} cannot appear more than once"
|
||||
raise AssertionError(f"name {name} cannot appear more than once")
|
||||
self.tokens_cross_attention_module_identifiers.append(name)
|
||||
|
||||
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
@ -170,7 +170,7 @@ class Context:
|
||||
self.saved_cross_attention_maps = {}
|
||||
|
||||
def offload_saved_attention_slices_to_cpu(self):
|
||||
for key, map_dict in self.saved_cross_attention_maps.items():
|
||||
for _key, map_dict in self.saved_cross_attention_maps.items():
|
||||
for offset, slice in map_dict["slices"].items():
|
||||
map_dict[offset] = slice.to("cpu")
|
||||
|
||||
@ -433,7 +433,7 @@ def inject_attention_function(unet, context: Context):
|
||||
module.identifier = identifier
|
||||
try:
|
||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))
|
||||
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
|
||||
@ -445,7 +445,7 @@ def remove_attention_function(unet):
|
||||
cross_attention_modules = get_cross_attention_modules(
|
||||
unet, CrossAttentionType.TOKENS
|
||||
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||
for identifier, module in cross_attention_modules:
|
||||
for _identifier, module in cross_attention_modules:
|
||||
try:
|
||||
# clear wrangler callback
|
||||
module.set_attention_slice_wrangler(None)
|
||||
|
@ -56,7 +56,7 @@ class AttentionMapSaver:
|
||||
|
||||
merged = None
|
||||
|
||||
for key, maps in self.collated_maps.items():
|
||||
for _key, maps in self.collated_maps.items():
|
||||
# maps has shape [(H*W), N] for N tokens
|
||||
# but we want [N, H, W]
|
||||
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
|
||||
|
@ -123,7 +123,7 @@ class InvokeAIDiffuserComponent:
|
||||
# control_data should be type List[ControlNetData]
|
||||
# this loop covers both ControlNet (one ControlNetData in list)
|
||||
# and MultiControlNet (multiple ControlNetData in list)
|
||||
for i, control_datum in enumerate(control_data):
|
||||
for _i, control_datum in enumerate(control_data):
|
||||
control_mode = control_datum.control_mode
|
||||
# soft_injection and cfg_injection are the two ControlNet control_mode booleans
|
||||
# that are combined at higher level to make control_mode enum
|
||||
@ -214,7 +214,7 @@ class InvokeAIDiffuserComponent:
|
||||
# add controlnet outputs together if have multiple controlnets
|
||||
down_block_res_samples = [
|
||||
samples_prev + samples_curr
|
||||
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
|
||||
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples, strict=True)
|
||||
]
|
||||
mid_block_res_sample += mid_sample
|
||||
|
||||
@ -642,7 +642,9 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
deltas = None
|
||||
uncond_latents = None
|
||||
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
|
||||
weighted_cond_list = (
|
||||
c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
|
||||
)
|
||||
|
||||
# below is fugly omg
|
||||
conditionings = [uc] + [c for c, weight in weighted_cond_list]
|
||||
|
@ -16,28 +16,28 @@ from diffusers import (
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
|
||||
SCHEDULER_MAP = dict(
|
||||
ddim=(DDIMScheduler, dict()),
|
||||
ddpm=(DDPMScheduler, dict()),
|
||||
deis=(DEISMultistepScheduler, dict()),
|
||||
lms=(LMSDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
lms_k=(LMSDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
pndm=(PNDMScheduler, dict()),
|
||||
heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
euler_a=(EulerAncestralDiscreteScheduler, dict()),
|
||||
kdpm_2=(KDPM2DiscreteScheduler, dict()),
|
||||
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
|
||||
dpmpp_2s=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=False)),
|
||||
dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)),
|
||||
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
|
||||
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
|
||||
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type="sde-dpmsolver++")),
|
||||
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")),
|
||||
dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)),
|
||||
dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)),
|
||||
unipc=(UniPCMultistepScheduler, dict(cpu_only=True)),
|
||||
lcm=(LCMScheduler, dict()),
|
||||
)
|
||||
SCHEDULER_MAP = {
|
||||
"ddim": (DDIMScheduler, {}),
|
||||
"ddpm": (DDPMScheduler, {}),
|
||||
"deis": (DEISMultistepScheduler, {}),
|
||||
"lms": (LMSDiscreteScheduler, {"use_karras_sigmas": False}),
|
||||
"lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
|
||||
"pndm": (PNDMScheduler, {}),
|
||||
"heun": (HeunDiscreteScheduler, {"use_karras_sigmas": False}),
|
||||
"heun_k": (HeunDiscreteScheduler, {"use_karras_sigmas": True}),
|
||||
"euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
|
||||
"euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
|
||||
"euler_a": (EulerAncestralDiscreteScheduler, {}),
|
||||
"kdpm_2": (KDPM2DiscreteScheduler, {}),
|
||||
"kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
|
||||
"dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
|
||||
"dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
|
||||
"dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False}),
|
||||
"dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
|
||||
"dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
|
||||
"dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
|
||||
"dpmpp_sde": (DPMSolverSDEScheduler, {"use_karras_sigmas": False, "noise_sampler_seed": 0}),
|
||||
"dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}),
|
||||
"unipc": (UniPCMultistepScheduler, {"cpu_only": True}),
|
||||
"lcm": (LCMScheduler, {}),
|
||||
}
|
||||
|
@ -615,7 +615,7 @@ def do_textual_inversion_training(
|
||||
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
|
||||
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
|
||||
|
||||
pipeline_args = dict(local_files_only=True)
|
||||
pipeline_args = {"local_files_only": True}
|
||||
if tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
|
||||
else:
|
||||
|
@ -732,7 +732,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
|
||||
controlnet_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||
for down_block_res_sample, controlnet_block in zip(
|
||||
down_block_res_samples, self.controlnet_down_blocks, strict=True
|
||||
):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
@ -745,7 +747,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
down_block_res_samples = [
|
||||
sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=False)
|
||||
]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
|
@ -225,34 +225,34 @@ def basicConfig(**kwargs):
|
||||
|
||||
|
||||
_FACILITY_MAP = (
|
||||
dict(
|
||||
LOG_KERN=syslog.LOG_KERN,
|
||||
LOG_USER=syslog.LOG_USER,
|
||||
LOG_MAIL=syslog.LOG_MAIL,
|
||||
LOG_DAEMON=syslog.LOG_DAEMON,
|
||||
LOG_AUTH=syslog.LOG_AUTH,
|
||||
LOG_LPR=syslog.LOG_LPR,
|
||||
LOG_NEWS=syslog.LOG_NEWS,
|
||||
LOG_UUCP=syslog.LOG_UUCP,
|
||||
LOG_CRON=syslog.LOG_CRON,
|
||||
LOG_SYSLOG=syslog.LOG_SYSLOG,
|
||||
LOG_LOCAL0=syslog.LOG_LOCAL0,
|
||||
LOG_LOCAL1=syslog.LOG_LOCAL1,
|
||||
LOG_LOCAL2=syslog.LOG_LOCAL2,
|
||||
LOG_LOCAL3=syslog.LOG_LOCAL3,
|
||||
LOG_LOCAL4=syslog.LOG_LOCAL4,
|
||||
LOG_LOCAL5=syslog.LOG_LOCAL5,
|
||||
LOG_LOCAL6=syslog.LOG_LOCAL6,
|
||||
LOG_LOCAL7=syslog.LOG_LOCAL7,
|
||||
)
|
||||
{
|
||||
"LOG_KERN": syslog.LOG_KERN,
|
||||
"LOG_USER": syslog.LOG_USER,
|
||||
"LOG_MAIL": syslog.LOG_MAIL,
|
||||
"LOG_DAEMON": syslog.LOG_DAEMON,
|
||||
"LOG_AUTH": syslog.LOG_AUTH,
|
||||
"LOG_LPR": syslog.LOG_LPR,
|
||||
"LOG_NEWS": syslog.LOG_NEWS,
|
||||
"LOG_UUCP": syslog.LOG_UUCP,
|
||||
"LOG_CRON": syslog.LOG_CRON,
|
||||
"LOG_SYSLOG": syslog.LOG_SYSLOG,
|
||||
"LOG_LOCAL0": syslog.LOG_LOCAL0,
|
||||
"LOG_LOCAL1": syslog.LOG_LOCAL1,
|
||||
"LOG_LOCAL2": syslog.LOG_LOCAL2,
|
||||
"LOG_LOCAL3": syslog.LOG_LOCAL3,
|
||||
"LOG_LOCAL4": syslog.LOG_LOCAL4,
|
||||
"LOG_LOCAL5": syslog.LOG_LOCAL5,
|
||||
"LOG_LOCAL6": syslog.LOG_LOCAL6,
|
||||
"LOG_LOCAL7": syslog.LOG_LOCAL7,
|
||||
}
|
||||
if SYSLOG_AVAILABLE
|
||||
else dict()
|
||||
else {}
|
||||
)
|
||||
|
||||
_SOCK_MAP = dict(
|
||||
SOCK_STREAM=socket.SOCK_STREAM,
|
||||
SOCK_DGRAM=socket.SOCK_DGRAM,
|
||||
)
|
||||
_SOCK_MAP = {
|
||||
"SOCK_STREAM": socket.SOCK_STREAM,
|
||||
"SOCK_DGRAM": socket.SOCK_DGRAM,
|
||||
}
|
||||
|
||||
|
||||
class InvokeAIFormatter(logging.Formatter):
|
||||
@ -344,7 +344,7 @@ LOG_FORMATTERS = {
|
||||
|
||||
|
||||
class InvokeAILogger(object):
|
||||
loggers = dict()
|
||||
loggers = {}
|
||||
|
||||
@classmethod
|
||||
def get_logger(
|
||||
@ -364,7 +364,7 @@ class InvokeAILogger(object):
|
||||
@classmethod
|
||||
def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
|
||||
handler_strs = config.log_handlers
|
||||
handlers = list()
|
||||
handlers = []
|
||||
for handler in handler_strs:
|
||||
handler_name, *args = handler.split("=", 2)
|
||||
args = args[0] if len(args) > 0 else None
|
||||
@ -398,7 +398,7 @@ class InvokeAILogger(object):
|
||||
raise ValueError("syslog is not available on this system")
|
||||
if not args:
|
||||
args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514"
|
||||
syslog_args = dict()
|
||||
syslog_args = {}
|
||||
try:
|
||||
for a in args.split(","):
|
||||
arg_name, *arg_value = a.split(":", 2)
|
||||
@ -434,7 +434,7 @@ class InvokeAILogger(object):
|
||||
path = url.path
|
||||
port = url.port or 80
|
||||
|
||||
syslog_args = dict()
|
||||
syslog_args = {}
|
||||
for a in arg_list:
|
||||
arg_name, *arg_value = a.split(":", 2)
|
||||
if arg_name == "method":
|
||||
|
@ -29,7 +29,7 @@ def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
# xc a list of captions to plot
|
||||
b = len(xc)
|
||||
txts = list()
|
||||
txts = []
|
||||
for bi in range(b):
|
||||
txt = Image.new("RGB", wh, color="white")
|
||||
draw = ImageDraw.Draw(txt)
|
||||
@ -93,7 +93,7 @@ def instantiate_from_config(config, **kwargs):
|
||||
elif config == "__is_unconditional__":
|
||||
return None
|
||||
raise KeyError("Expected key `target` to instantiate.")
|
||||
return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
|
||||
return get_obj_from_str(config["target"])(**config.get("params", {}), **kwargs)
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
@ -231,11 +231,12 @@ def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10
|
||||
angles = 2 * math.pi * rand_val
|
||||
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)
|
||||
|
||||
tile_grads = (
|
||||
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
||||
.repeat_interleave(d[0], 0)
|
||||
.repeat_interleave(d[1], 1)
|
||||
)
|
||||
def tile_grads(slice1, slice2):
|
||||
return (
|
||||
gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
||||
.repeat_interleave(d[0], 0)
|
||||
.repeat_interleave(d[1], 1)
|
||||
)
|
||||
|
||||
def dot(grad, shift):
|
||||
return (
|
||||
|
@ -341,19 +341,19 @@ class InvokeAIMetadataParser:
|
||||
# this was more elegant as a case statement, but that's not available in python 3.9
|
||||
if old_scheduler is None:
|
||||
return None
|
||||
scheduler_map = dict(
|
||||
ddim="ddim",
|
||||
plms="pnmd",
|
||||
k_lms="lms",
|
||||
k_dpm_2="kdpm_2",
|
||||
k_dpm_2_a="kdpm_2_a",
|
||||
dpmpp_2="dpmpp_2s",
|
||||
k_dpmpp_2="dpmpp_2m",
|
||||
k_dpmpp_2_a=None, # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
|
||||
k_euler="euler",
|
||||
k_euler_a="euler_a",
|
||||
k_heun="heun",
|
||||
)
|
||||
scheduler_map = {
|
||||
"ddim": "ddim",
|
||||
"plms": "pnmd",
|
||||
"k_lms": "lms",
|
||||
"k_dpm_2": "kdpm_2",
|
||||
"k_dpm_2_a": "kdpm_2_a",
|
||||
"dpmpp_2": "dpmpp_2s",
|
||||
"k_dpmpp_2": "dpmpp_2m",
|
||||
"k_dpmpp_2_a": None, # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
|
||||
"k_euler": "euler",
|
||||
"k_euler_a": "euler_a",
|
||||
"k_heun": "heun",
|
||||
}
|
||||
return scheduler_map.get(old_scheduler)
|
||||
|
||||
def split_prompt(self, raw_prompt: str):
|
||||
|
@ -72,7 +72,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
|
||||
self.multipage = multipage
|
||||
self.subprocess = None
|
||||
super().__init__(parentApp=parentApp, name=name, *args, **keywords)
|
||||
super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad?
|
||||
|
||||
def create(self):
|
||||
self.keypress_timeout = 10
|
||||
@ -203,14 +203,14 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
)
|
||||
|
||||
# This restores the selected page on return from an installation
|
||||
for i in range(1, self.current_tab + 1):
|
||||
for _i in range(1, self.current_tab + 1):
|
||||
self.tabs.h_cursor_line_down(1)
|
||||
self._toggle_tables([self.current_tab])
|
||||
|
||||
############# diffusers tab ##########
|
||||
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
|
||||
"""Add widgets responsible for selecting diffusers models"""
|
||||
widgets = dict()
|
||||
widgets = {}
|
||||
models = self.all_models
|
||||
starters = self.starter_models
|
||||
starter_model_labels = self.model_labels
|
||||
@ -258,10 +258,12 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
model_type: ModelType,
|
||||
window_width: int = 120,
|
||||
install_prompt: str = None,
|
||||
exclude: set = set(),
|
||||
exclude: set = None,
|
||||
) -> dict[str, npyscreen.widget]:
|
||||
"""Generic code to create model selection widgets"""
|
||||
widgets = dict()
|
||||
if exclude is None:
|
||||
exclude = set()
|
||||
widgets = {}
|
||||
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
|
||||
model_labels = [self.model_labels[x] for x in model_list]
|
||||
|
||||
@ -366,13 +368,13 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
]
|
||||
|
||||
for group in widgets:
|
||||
for k, v in group.items():
|
||||
for _k, v in group.items():
|
||||
try:
|
||||
v.hidden = True
|
||||
v.editable = False
|
||||
except Exception:
|
||||
pass
|
||||
for k, v in widgets[selected_tab].items():
|
||||
for _k, v in widgets[selected_tab].items():
|
||||
try:
|
||||
v.hidden = False
|
||||
if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
|
||||
@ -391,7 +393,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
label_width = max([len(models[x].name) for x in models])
|
||||
description_width = window_width - label_width - checkbox_width - spacing_width
|
||||
|
||||
result = dict()
|
||||
result = {}
|
||||
for x in models.keys():
|
||||
description = models[x].description
|
||||
description = (
|
||||
@ -433,11 +435,11 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
parent_conn, child_conn = Pipe()
|
||||
p = Process(
|
||||
target=process_and_execute,
|
||||
kwargs=dict(
|
||||
opt=app.program_opts,
|
||||
selections=app.install_selections,
|
||||
conn_out=child_conn,
|
||||
),
|
||||
kwargs={
|
||||
"opt": app.program_opts,
|
||||
"selections": app.install_selections,
|
||||
"conn_out": child_conn,
|
||||
},
|
||||
)
|
||||
p.start()
|
||||
child_conn.close()
|
||||
@ -558,7 +560,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
for section in ui_sections:
|
||||
if "models_selected" not in section:
|
||||
continue
|
||||
selected = set([section["models"][x] for x in section["models_selected"].value])
|
||||
selected = {section["models"][x] for x in section["models_selected"].value}
|
||||
models_to_install = [x for x in selected if not self.all_models[x].installed]
|
||||
models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
|
||||
selections.remove_models.extend(models_to_remove)
|
||||
|
@ -11,6 +11,7 @@ import sys
|
||||
import textwrap
|
||||
from curses import BUTTON2_CLICKED, BUTTON3_CLICKED
|
||||
from shutil import get_terminal_size
|
||||
from typing import Optional
|
||||
|
||||
import npyscreen
|
||||
import npyscreen.wgmultiline as wgmultiline
|
||||
@ -243,7 +244,9 @@ class SelectColumnBase:
|
||||
|
||||
|
||||
class MultiSelectColumns(SelectColumnBase, npyscreen.MultiSelect):
|
||||
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
||||
def __init__(self, screen, columns: int = 1, values: Optional[list] = None, **keywords):
|
||||
if values is None:
|
||||
values = []
|
||||
self.columns = columns
|
||||
self.value_cnt = len(values)
|
||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||
@ -267,7 +270,9 @@ class SingleSelectWithChanged(npyscreen.SelectOne):
|
||||
class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged):
|
||||
"""Row of radio buttons. Spacebar to select."""
|
||||
|
||||
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
||||
def __init__(self, screen, columns: int = 1, values: list = None, **keywords):
|
||||
if values is None:
|
||||
values = []
|
||||
self.columns = columns
|
||||
self.value_cnt = len(values)
|
||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||
|
@ -275,14 +275,14 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
interp = self.interpolations[self.merge_method.value[0]]
|
||||
|
||||
bases = ["sd-1", "sd-2", "sdxl"]
|
||||
args = dict(
|
||||
model_names=models,
|
||||
base_model=BaseModelType(bases[self.base_select.value[0]]),
|
||||
alpha=self.alpha.value,
|
||||
interp=interp,
|
||||
force=self.force.value,
|
||||
merged_model_name=self.merged_model_name.value,
|
||||
)
|
||||
args = {
|
||||
"model_names": models,
|
||||
"base_model": BaseModelType(bases[self.base_select.value[0]]),
|
||||
"alpha": self.alpha.value,
|
||||
"interp": interp,
|
||||
"force": self.force.value,
|
||||
"merged_model_name": self.merged_model_name.value,
|
||||
}
|
||||
return args
|
||||
|
||||
def check_for_overwrite(self) -> bool:
|
||||
@ -297,7 +297,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
def validate_field_values(self) -> bool:
|
||||
bad_fields = []
|
||||
model_names = self.model_names
|
||||
selected_models = set((model_names[self.model1.value[0]], model_names[self.model2.value[0]]))
|
||||
selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]}
|
||||
if self.model3.value[0] > 0:
|
||||
selected_models.add(model_names[self.model3.value[0] - 1])
|
||||
if len(selected_models) < 2:
|
||||
|
@ -276,13 +276,13 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
|
||||
def get_model_names(self) -> Tuple[List[str], int]:
|
||||
conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
|
||||
model_names = [idx for idx in sorted(list(conf.keys())) if conf[idx].get("format", None) == "diffusers"]
|
||||
model_names = [idx for idx in sorted(conf.keys()) if conf[idx].get("format", None) == "diffusers"]
|
||||
defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]]
|
||||
default = defaults[0] if len(defaults) > 0 else 0
|
||||
return (model_names, default)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
args = dict()
|
||||
args = {}
|
||||
|
||||
# the choices
|
||||
args.update(
|
||||
|
@ -24,6 +24,7 @@ module.exports = {
|
||||
root: true,
|
||||
rules: {
|
||||
curly: 'error',
|
||||
'react/jsx-no-bind': ['error', { allowBind: true }],
|
||||
'react/jsx-curly-brace-presence': [
|
||||
'error',
|
||||
{ props: 'never', children: 'never' },
|
||||
|
@ -583,6 +583,7 @@
|
||||
"strength": "Image to image strength",
|
||||
"Threshold": "Noise Threshold",
|
||||
"variations": "Seed-weight pairs",
|
||||
"vae": "VAE",
|
||||
"width": "Width",
|
||||
"workflow": "Workflow"
|
||||
},
|
||||
|
@ -1487,5 +1487,18 @@
|
||||
"scheduler": "Campionatore",
|
||||
"recallParameters": "Richiama i parametri",
|
||||
"noRecallParameters": "Nessun parametro da richiamare trovato"
|
||||
},
|
||||
"hrf": {
|
||||
"enableHrf": "Abilita Correzione Alta Risoluzione",
|
||||
"upscaleMethod": "Metodo di ampliamento",
|
||||
"enableHrfTooltip": "Genera con una risoluzione iniziale inferiore, esegue l'ampliamento alla risoluzione di base, quindi esegue Immagine a Immagine.",
|
||||
"metadata": {
|
||||
"strength": "Forza della Correzione Alta Risoluzione",
|
||||
"enabled": "Correzione Alta Risoluzione Abilitata",
|
||||
"method": "Metodo della Correzione Alta Risoluzione"
|
||||
},
|
||||
"hrf": "Correzione Alta Risoluzione",
|
||||
"hrfStrength": "Forza della Correzione Alta Risoluzione",
|
||||
"strengthTooltip": "Valori più bassi comportano meno dettagli, il che può ridurre potenziali artefatti."
|
||||
}
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ import sdxlReducer from 'features/sdxl/store/sdxlSlice';
|
||||
import configReducer from 'features/system/store/configSlice';
|
||||
import systemReducer from 'features/system/store/systemSlice';
|
||||
import queueReducer from 'features/queue/store/queueSlice';
|
||||
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
|
||||
import modelmanagerReducer from 'features/modelManager/store/modelManagerSlice';
|
||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||
import uiReducer from 'features/ui/store/uiSlice';
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
|
@ -8,7 +8,14 @@ import {
|
||||
forwardRef,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { cloneElement, memo, ReactElement, ReactNode, useRef } from 'react';
|
||||
import {
|
||||
cloneElement,
|
||||
memo,
|
||||
ReactElement,
|
||||
ReactNode,
|
||||
useCallback,
|
||||
useRef,
|
||||
} from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import IAIButton from './IAIButton';
|
||||
|
||||
@ -38,15 +45,15 @@ const IAIAlertDialog = forwardRef((props: Props, ref) => {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const cancelRef = useRef<HTMLButtonElement | null>(null);
|
||||
|
||||
const handleAccept = () => {
|
||||
const handleAccept = useCallback(() => {
|
||||
acceptCallback();
|
||||
onClose();
|
||||
};
|
||||
}, [acceptCallback, onClose]);
|
||||
|
||||
const handleCancel = () => {
|
||||
const handleCancel = useCallback(() => {
|
||||
cancelCallback && cancelCallback();
|
||||
onClose();
|
||||
};
|
||||
}, [cancelCallback, onClose]);
|
||||
|
||||
return (
|
||||
<>
|
||||
|
@ -1,9 +1,12 @@
|
||||
import { Box, ChakraProps } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { ChakraProps, Flex } from '@chakra-ui/react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { RgbaColorPicker } from 'react-colorful';
|
||||
import { ColorPickerBaseProps, RgbaColor } from 'react-colorful/dist/types';
|
||||
import IAINumberInput from './IAINumberInput';
|
||||
|
||||
type IAIColorPickerProps = ColorPickerBaseProps<RgbaColor>;
|
||||
type IAIColorPickerProps = ColorPickerBaseProps<RgbaColor> & {
|
||||
withNumberInput?: boolean;
|
||||
};
|
||||
|
||||
const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
|
||||
width: 6,
|
||||
@ -11,17 +14,84 @@ const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
|
||||
borderColor: 'base.100',
|
||||
};
|
||||
|
||||
const sx = {
|
||||
const sx: ChakraProps['sx'] = {
|
||||
'.react-colorful__hue-pointer': colorPickerStyles,
|
||||
'.react-colorful__saturation-pointer': colorPickerStyles,
|
||||
'.react-colorful__alpha-pointer': colorPickerStyles,
|
||||
gap: 2,
|
||||
flexDir: 'column',
|
||||
};
|
||||
|
||||
const numberInputWidth: ChakraProps['w'] = '4.2rem';
|
||||
|
||||
const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||
const { color, onChange, withNumberInput, ...rest } = props;
|
||||
const handleChangeR = useCallback(
|
||||
(r: number) => onChange({ ...color, r }),
|
||||
[color, onChange]
|
||||
);
|
||||
const handleChangeG = useCallback(
|
||||
(g: number) => onChange({ ...color, g }),
|
||||
[color, onChange]
|
||||
);
|
||||
const handleChangeB = useCallback(
|
||||
(b: number) => onChange({ ...color, b }),
|
||||
[color, onChange]
|
||||
);
|
||||
const handleChangeA = useCallback(
|
||||
(a: number) => onChange({ ...color, a }),
|
||||
[color, onChange]
|
||||
);
|
||||
return (
|
||||
<Box sx={sx}>
|
||||
<RgbaColorPicker {...props} />
|
||||
</Box>
|
||||
<Flex sx={sx}>
|
||||
<RgbaColorPicker
|
||||
color={color}
|
||||
onChange={onChange}
|
||||
style={{ width: '100%' }}
|
||||
{...rest}
|
||||
/>
|
||||
{withNumberInput && (
|
||||
<Flex>
|
||||
<IAINumberInput
|
||||
value={color.r}
|
||||
onChange={handleChangeR}
|
||||
min={0}
|
||||
max={255}
|
||||
step={1}
|
||||
label="Red"
|
||||
w={numberInputWidth}
|
||||
/>
|
||||
<IAINumberInput
|
||||
value={color.g}
|
||||
onChange={handleChangeG}
|
||||
min={0}
|
||||
max={255}
|
||||
step={1}
|
||||
label="Green"
|
||||
w={numberInputWidth}
|
||||
/>
|
||||
<IAINumberInput
|
||||
value={color.b}
|
||||
onChange={handleChangeB}
|
||||
min={0}
|
||||
max={255}
|
||||
step={1}
|
||||
label="Blue"
|
||||
w={numberInputWidth}
|
||||
/>
|
||||
<IAINumberInput
|
||||
value={color.a}
|
||||
onChange={handleChangeA}
|
||||
step={0.1}
|
||||
min={0}
|
||||
max={1}
|
||||
label="Alpha"
|
||||
w={numberInputWidth}
|
||||
isInteger={false}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -1,43 +0,0 @@
|
||||
import { Box, Flex, Icon } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { FaExclamation } from 'react-icons/fa';
|
||||
|
||||
const IAIErrorLoadingImageFallback = () => {
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'relative',
|
||||
height: 'full',
|
||||
width: 'full',
|
||||
'::before': {
|
||||
content: "''",
|
||||
display: 'block',
|
||||
pt: '100%',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineStart: 0,
|
||||
height: 'full',
|
||||
width: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
borderRadius: 'base',
|
||||
bg: 'base.100',
|
||||
color: 'base.500',
|
||||
_dark: {
|
||||
color: 'base.700',
|
||||
bg: 'base.850',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Icon as={FaExclamation} boxSize={16} opacity={0.7} />
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAIErrorLoadingImageFallback);
|
@ -1,8 +0,0 @@
|
||||
import { chakra } from '@chakra-ui/react';
|
||||
|
||||
/**
|
||||
* Chakra-enabled <form />
|
||||
*/
|
||||
const IAIForm = chakra.form;
|
||||
|
||||
export default IAIForm;
|
@ -1,15 +0,0 @@
|
||||
import { FormErrorMessage, FormErrorMessageProps } from '@chakra-ui/react';
|
||||
import { ReactNode } from 'react';
|
||||
|
||||
type IAIFormErrorMessageProps = FormErrorMessageProps & {
|
||||
children: ReactNode | string;
|
||||
};
|
||||
|
||||
export default function IAIFormErrorMessage(props: IAIFormErrorMessageProps) {
|
||||
const { children, ...rest } = props;
|
||||
return (
|
||||
<FormErrorMessage color="error.400" {...rest}>
|
||||
{children}
|
||||
</FormErrorMessage>
|
||||
);
|
||||
}
|
@ -1,15 +0,0 @@
|
||||
import { FormHelperText, FormHelperTextProps } from '@chakra-ui/react';
|
||||
import { ReactNode } from 'react';
|
||||
|
||||
type IAIFormHelperTextProps = FormHelperTextProps & {
|
||||
children: ReactNode | string;
|
||||
};
|
||||
|
||||
export default function IAIFormHelperText(props: IAIFormHelperTextProps) {
|
||||
const { children, ...rest } = props;
|
||||
return (
|
||||
<FormHelperText margin={0} color="base.400" {...rest}>
|
||||
{children}
|
||||
</FormHelperText>
|
||||
);
|
||||
}
|
@ -1,25 +0,0 @@
|
||||
import { Flex, useColorMode } from '@chakra-ui/react';
|
||||
import { ReactElement } from 'react';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
export function IAIFormItemWrapper({
|
||||
children,
|
||||
}: {
|
||||
children: ReactElement | ReactElement[];
|
||||
}) {
|
||||
const { colorMode } = useColorMode();
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
padding: 4,
|
||||
rowGap: 4,
|
||||
borderRadius: 'base',
|
||||
width: 'full',
|
||||
bg: mode('base.100', 'base.900')(colorMode),
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -1,25 +0,0 @@
|
||||
import {
|
||||
Checkbox,
|
||||
CheckboxProps,
|
||||
FormControl,
|
||||
FormControlProps,
|
||||
FormLabel,
|
||||
} from '@chakra-ui/react';
|
||||
import { memo, ReactNode } from 'react';
|
||||
|
||||
type IAIFullCheckboxProps = CheckboxProps & {
|
||||
label: string | ReactNode;
|
||||
formControlProps?: FormControlProps;
|
||||
};
|
||||
|
||||
const IAIFullCheckbox = (props: IAIFullCheckboxProps) => {
|
||||
const { label, formControlProps, ...rest } = props;
|
||||
return (
|
||||
<FormControl {...formControlProps}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
<Checkbox colorScheme="accent" {...rest} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAIFullCheckbox);
|
@ -1,6 +1,7 @@
|
||||
import { useColorMode } from '@chakra-ui/react';
|
||||
import { TextInput, TextInputProps } from '@mantine/core';
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { useCallback } from 'react';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
type IAIMantineTextInputProps = TextInputProps;
|
||||
@ -20,26 +21,37 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
|
||||
} = useChakraThemeTokens();
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
return (
|
||||
<TextInput
|
||||
styles={() => ({
|
||||
input: {
|
||||
color: mode(base900, base100)(colorMode),
|
||||
backgroundColor: mode(base50, base900)(colorMode),
|
||||
borderColor: mode(base200, base800)(colorMode),
|
||||
borderWidth: 2,
|
||||
outline: 'none',
|
||||
':focus': {
|
||||
borderColor: mode(accent300, accent500)(colorMode),
|
||||
},
|
||||
const stylesFunc = useCallback(
|
||||
() => ({
|
||||
input: {
|
||||
color: mode(base900, base100)(colorMode),
|
||||
backgroundColor: mode(base50, base900)(colorMode),
|
||||
borderColor: mode(base200, base800)(colorMode),
|
||||
borderWidth: 2,
|
||||
outline: 'none',
|
||||
':focus': {
|
||||
borderColor: mode(accent300, accent500)(colorMode),
|
||||
},
|
||||
label: {
|
||||
color: mode(base700, base300)(colorMode),
|
||||
fontWeight: 'normal',
|
||||
marginBottom: 4,
|
||||
},
|
||||
})}
|
||||
{...rest}
|
||||
/>
|
||||
},
|
||||
label: {
|
||||
color: mode(base700, base300)(colorMode),
|
||||
fontWeight: 'normal' as const,
|
||||
marginBottom: 4,
|
||||
},
|
||||
}),
|
||||
[
|
||||
accent300,
|
||||
accent500,
|
||||
base100,
|
||||
base200,
|
||||
base300,
|
||||
base50,
|
||||
base700,
|
||||
base800,
|
||||
base900,
|
||||
colorMode,
|
||||
]
|
||||
);
|
||||
|
||||
return <TextInput styles={stylesFunc} {...rest} />;
|
||||
}
|
||||
|
@ -98,28 +98,34 @@ const IAINumberInput = forwardRef((props: Props, ref) => {
|
||||
}
|
||||
}, [value, valueAsString]);
|
||||
|
||||
const handleOnChange = (v: string) => {
|
||||
setValueAsString(v);
|
||||
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
|
||||
if (!v.match(numberStringRegex)) {
|
||||
// Cast the value to number. Floor it if it should be an integer.
|
||||
onChange(isInteger ? Math.floor(Number(v)) : Number(v));
|
||||
}
|
||||
};
|
||||
const handleOnChange = useCallback(
|
||||
(v: string) => {
|
||||
setValueAsString(v);
|
||||
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
|
||||
if (!v.match(numberStringRegex)) {
|
||||
// Cast the value to number. Floor it if it should be an integer.
|
||||
onChange(isInteger ? Math.floor(Number(v)) : Number(v));
|
||||
}
|
||||
},
|
||||
[isInteger, onChange]
|
||||
);
|
||||
|
||||
/**
|
||||
* Clicking the steppers allows the value to go outside bounds; we need to
|
||||
* clamp it on blur and floor it if needed.
|
||||
*/
|
||||
const handleBlur = (e: FocusEvent<HTMLInputElement>) => {
|
||||
const clamped = clamp(
|
||||
isInteger ? Math.floor(Number(e.target.value)) : Number(e.target.value),
|
||||
min,
|
||||
max
|
||||
);
|
||||
setValueAsString(String(clamped));
|
||||
onChange(clamped);
|
||||
};
|
||||
const handleBlur = useCallback(
|
||||
(e: FocusEvent<HTMLInputElement>) => {
|
||||
const clamped = clamp(
|
||||
isInteger ? Math.floor(Number(e.target.value)) : Number(e.target.value),
|
||||
min,
|
||||
max
|
||||
);
|
||||
setValueAsString(String(clamped));
|
||||
onChange(clamped);
|
||||
},
|
||||
[isInteger, max, min, onChange]
|
||||
);
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
|
@ -6,7 +6,7 @@ import {
|
||||
Tooltip,
|
||||
TooltipProps,
|
||||
} from '@chakra-ui/react';
|
||||
import { memo, MouseEvent } from 'react';
|
||||
import { memo, MouseEvent, useCallback } from 'react';
|
||||
import IAIOption from './IAIOption';
|
||||
|
||||
type IAISelectProps = SelectProps & {
|
||||
@ -33,15 +33,16 @@ const IAISelect = (props: IAISelectProps) => {
|
||||
spaceEvenly,
|
||||
...rest
|
||||
} = props;
|
||||
const handleClick = useCallback((e: MouseEvent<HTMLDivElement>) => {
|
||||
e.stopPropagation();
|
||||
e.nativeEvent.stopImmediatePropagation();
|
||||
e.nativeEvent.stopPropagation();
|
||||
e.nativeEvent.cancelBubble = true;
|
||||
}, []);
|
||||
return (
|
||||
<FormControl
|
||||
isDisabled={isDisabled}
|
||||
onClick={(e: MouseEvent<HTMLDivElement>) => {
|
||||
e.stopPropagation();
|
||||
e.nativeEvent.stopImmediatePropagation();
|
||||
e.nativeEvent.stopPropagation();
|
||||
e.nativeEvent.cancelBubble = true;
|
||||
}}
|
||||
onClick={handleClick}
|
||||
sx={
|
||||
horizontal
|
||||
? {
|
||||
|
@ -186,6 +186,13 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleMouseEnter = useCallback(() => setShowTooltip(true), []);
|
||||
const handleMouseLeave = useCallback(() => setShowTooltip(false), []);
|
||||
const handleStepperClick = useCallback(
|
||||
() => onChange(Number(localInputValue)),
|
||||
[localInputValue, onChange]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl
|
||||
ref={ref}
|
||||
@ -219,8 +226,8 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
|
||||
max={max}
|
||||
step={step}
|
||||
onChange={handleSliderChange}
|
||||
onMouseEnter={() => setShowTooltip(true)}
|
||||
onMouseLeave={() => setShowTooltip(false)}
|
||||
onMouseEnter={handleMouseEnter}
|
||||
onMouseLeave={handleMouseLeave}
|
||||
focusThumbOnChange={false}
|
||||
isDisabled={isDisabled}
|
||||
{...rest}
|
||||
@ -332,12 +339,8 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
|
||||
{...sliderNumberInputFieldProps}
|
||||
/>
|
||||
<NumberInputStepper {...sliderNumberInputStepperProps}>
|
||||
<NumberIncrementStepper
|
||||
onClick={() => onChange(Number(localInputValue))}
|
||||
/>
|
||||
<NumberDecrementStepper
|
||||
onClick={() => onChange(Number(localInputValue))}
|
||||
/>
|
||||
<NumberIncrementStepper onClick={handleStepperClick} />
|
||||
<NumberDecrementStepper onClick={handleStepperClick} />
|
||||
</NumberInputStepper>
|
||||
</NumberInput>
|
||||
)}
|
||||
|
@ -146,16 +146,15 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
};
|
||||
}, [inputRef]);
|
||||
|
||||
const handleKeyDown = useCallback((e: KeyboardEvent) => {
|
||||
// Bail out if user hits spacebar - do not open the uploader
|
||||
if (e.key === ' ') {
|
||||
return;
|
||||
}
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Box
|
||||
{...getRootProps({ style: {} })}
|
||||
onKeyDown={(e: KeyboardEvent) => {
|
||||
// Bail out if user hits spacebar - do not open the uploader
|
||||
if (e.key === ' ') {
|
||||
return;
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Box {...getRootProps({ style: {} })} onKeyDown={handleKeyDown}>
|
||||
<input {...getInputProps()} />
|
||||
{children}
|
||||
<AnimatePresence>
|
||||
|
@ -1,23 +0,0 @@
|
||||
import { Flex, Icon } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { FaImage } from 'react-icons/fa';
|
||||
|
||||
const SelectImagePlaceholder = () => {
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
// bg: 'base.800',
|
||||
borderRadius: 'base',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
aspectRatio: '1/1',
|
||||
}}
|
||||
>
|
||||
<Icon color="base.400" boxSize={32} as={FaImage}></Icon>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SelectImagePlaceholder);
|
@ -1,24 +0,0 @@
|
||||
import { useBreakpoint } from '@chakra-ui/react';
|
||||
|
||||
export default function useResolution():
|
||||
| 'mobile'
|
||||
| 'tablet'
|
||||
| 'desktop'
|
||||
| 'unknown' {
|
||||
const breakpointValue = useBreakpoint();
|
||||
|
||||
const mobileResolutions = ['base', 'sm'];
|
||||
const tabletResolutions = ['md', 'lg'];
|
||||
const desktopResolutions = ['xl', '2xl'];
|
||||
|
||||
if (mobileResolutions.includes(breakpointValue)) {
|
||||
return 'mobile';
|
||||
}
|
||||
if (tabletResolutions.includes(breakpointValue)) {
|
||||
return 'tablet';
|
||||
}
|
||||
if (desktopResolutions.includes(breakpointValue)) {
|
||||
return 'desktop';
|
||||
}
|
||||
return 'unknown';
|
||||
}
|
@ -1,7 +0,0 @@
|
||||
import dateFormat from 'dateformat';
|
||||
|
||||
/**
|
||||
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
|
||||
*/
|
||||
export const getTimestamp = () =>
|
||||
dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);
|
@ -1,71 +0,0 @@
|
||||
// TODO: Restore variations
|
||||
// Support code from v2.3 in here.
|
||||
|
||||
// export const stringToSeedWeights = (
|
||||
// string: string
|
||||
// ): InvokeAI.SeedWeights | boolean => {
|
||||
// const stringPairs = string.split(',');
|
||||
// const arrPairs = stringPairs.map((p) => p.split(':'));
|
||||
// const pairs = arrPairs.map((p: Array<string>): InvokeAI.SeedWeightPair => {
|
||||
// return { seed: Number(p[0]), weight: Number(p[1]) };
|
||||
// });
|
||||
|
||||
// if (!validateSeedWeights(pairs)) {
|
||||
// return false;
|
||||
// }
|
||||
|
||||
// return pairs;
|
||||
// };
|
||||
|
||||
// export const validateSeedWeights = (
|
||||
// seedWeights: InvokeAI.SeedWeights | string
|
||||
// ): boolean => {
|
||||
// return typeof seedWeights === 'string'
|
||||
// ? Boolean(stringToSeedWeights(seedWeights))
|
||||
// : Boolean(
|
||||
// seedWeights.length &&
|
||||
// !seedWeights.some((pair: InvokeAI.SeedWeightPair) => {
|
||||
// const { seed, weight } = pair;
|
||||
// const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
|
||||
// const isWeightValid =
|
||||
// !isNaN(parseInt(weight.toString(), 10)) &&
|
||||
// weight >= 0 &&
|
||||
// weight <= 1;
|
||||
// return !(isSeedValid && isWeightValid);
|
||||
// })
|
||||
// );
|
||||
// };
|
||||
|
||||
// export const seedWeightsToString = (
|
||||
// seedWeights: InvokeAI.SeedWeights
|
||||
// ): string => {
|
||||
// return seedWeights.reduce((acc, pair, i, arr) => {
|
||||
// const { seed, weight } = pair;
|
||||
// acc += `${seed}:${weight}`;
|
||||
// if (i !== arr.length - 1) {
|
||||
// acc += ',';
|
||||
// }
|
||||
// return acc;
|
||||
// }, '');
|
||||
// };
|
||||
|
||||
// export const seedWeightsToArray = (
|
||||
// seedWeights: InvokeAI.SeedWeights
|
||||
// ): Array<Array<number>> => {
|
||||
// return seedWeights.map((pair: InvokeAI.SeedWeightPair) => [
|
||||
// pair.seed,
|
||||
// pair.weight,
|
||||
// ]);
|
||||
// };
|
||||
|
||||
// export const stringToSeedWeightsArray = (
|
||||
// string: string
|
||||
// ): Array<Array<number>> => {
|
||||
// const stringPairs = string.split(',');
|
||||
// const arrPairs = stringPairs.map((p) => p.split(':'));
|
||||
// return arrPairs.map(
|
||||
// (p: Array<string>): Array<number> => [parseInt(p[0], 10), parseFloat(p[1])]
|
||||
// );
|
||||
// };
|
||||
|
||||
export default {};
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user