resolve conflicts; blackify

This commit is contained in:
Lincoln Stein 2023-11-13 18:12:45 -05:00
commit 38c1436f02
258 changed files with 1794 additions and 4022 deletions

View File

@ -1,20 +0,0 @@
on:
pull_request:
push:
branches:
- main
- development
- 'release-candidate-*'
jobs:
pyflakes:
name: runner / pyflakes
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: pyflakes
uses: reviewdog/action-pyflakes@v1
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
reporter: github-pr-review

View File

@ -18,8 +18,7 @@ jobs:
- name: Install dependencies with pip
run: |
pip install black flake8 Flake8-pyproject isort
pip install ruff
- run: isort --check-only .
- run: black --check .
- run: flake8
- run: ruff check --output-format=github .
- run: ruff format --check .

View File

@ -137,7 +137,7 @@ def dest_path(dest=None) -> Path:
path_completer = PathCompleter(
only_directories=True,
expanduser=True,
get_paths=lambda: [browse_start],
get_paths=lambda: [browse_start], # noqa: B023
# get_paths=lambda: [".."].extend(list(browse_start.iterdir()))
)
@ -149,7 +149,7 @@ def dest_path(dest=None) -> Path:
completer=path_completer,
default=str(browse_start) + os.sep,
vi_mode=True,
complete_while_typing=True
complete_while_typing=True,
# Test that this is not needed on Windows
# complete_style=CompleteStyle.READLINE_LIKE,
)

View File

@ -28,7 +28,7 @@ class FastAPIEventService(EventServiceBase):
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put(dict(event_name=event_name, payload=payload))
self.__queue.put({"event_name": event_name, "payload": payload})
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""

View File

@ -8,12 +8,20 @@ from typing import List, Optional
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, TypeAdapter
from pydantic import BaseModel, ConfigDict
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_records import DuplicateModelException, InvalidModelException, UnknownModelException
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
UnknownModelException,
)
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelType,
)
from ..dependencies import ApiDependencies
@ -28,26 +36,32 @@ class ModelsList(BaseModel):
model_config = ConfigDict(use_enum_values=True)
ModelsListValidator = TypeAdapter(ModelsList)
@model_records_router.get(
"/",
operation_id="list_model_records",
)
async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
base_models: Optional[List[BaseModelType]] = Query(
default=None, description="Base models to include"
),
model_type: Optional[ModelType] = Query(
default=None, description="The type of model to get"
),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
models = list()
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
found_models.extend(
record_store.search_by_attr(
base_model=base_model, model_type=model_type
)
)
else:
models.extend(record_store.search_by_attr(model_type=model_type))
return ModelsList(models=models)
found_models.extend(record_store.search_by_attr(model_type=model_type))
return ModelsList(models=found_models)
@model_records_router.get(
"/i/{key}",
@ -83,13 +97,16 @@ async def get_model_record(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
info: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type")
],
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
try:
model_response = record_store.update_model(key, config=info)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
@ -101,7 +118,10 @@ async def update_model_record(
@model_records_router.delete(
"/i/{key}",
operation_id="del_model_record",
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
responses={
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"},
},
status_code=204,
)
async def del_model_record(
@ -125,13 +145,17 @@ async def del_model_record(
operation_id="add_model_record",
responses={
201: {"description": "The model added successfully"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
409: {
"description": "There is already a model corresponding to this path or repo_id"
},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
)
async def add_model_record(
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
config: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type")
]
) -> AnyModelConfig:
"""
Add a model using the configuration information appropriate for its type.

View File

@ -54,7 +54,7 @@ async def list_models(
) -> ModelsList:
"""Gets a list of models"""
if base_models and len(base_models) > 0:
models_raw = list()
models_raw = []
for base_model in base_models:
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
else:

View File

@ -132,7 +132,7 @@ def custom_openapi() -> dict[str, Any]:
# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
output_type_titles = dict()
output_type_titles = {}
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
@ -173,12 +173,12 @@ def custom_openapi() -> dict[str, Any]:
# print(f"Config with name {name} already defined")
continue
openapi_schema["components"]["schemas"][name] = dict(
title=name,
description="An enumeration.",
type="string",
enum=list(v.value for v in model_config_format_enum),
)
openapi_schema["components"]["schemas"][name] = {
"title": name,
"description": "An enumeration.",
"type": "string",
"enum": [v.value for v in model_config_format_enum],
}
app.openapi_schema = openapi_schema
return app.openapi_schema

View File

@ -25,4 +25,4 @@ spec.loader.exec_module(module)
# add core nodes to __all__
python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
__all__ = list(f.stem for f in python_files) # type: ignore
__all__ = [f.stem for f in python_files] # type: ignore

View File

@ -236,35 +236,35 @@ def InputField(
Ignored for non-collection fields.
"""
json_schema_extra_: dict[str, Any] = dict(
input=input,
ui_type=ui_type,
ui_component=ui_component,
ui_hidden=ui_hidden,
ui_order=ui_order,
item_default=item_default,
ui_choice_labels=ui_choice_labels,
_field_kind="input",
)
json_schema_extra_: dict[str, Any] = {
"input": input,
"ui_type": ui_type,
"ui_component": ui_component,
"ui_hidden": ui_hidden,
"ui_order": ui_order,
"item_default": item_default,
"ui_choice_labels": ui_choice_labels,
"_field_kind": "input",
}
field_args = dict(
default=default,
default_factory=default_factory,
title=title,
description=description,
pattern=pattern,
strict=strict,
gt=gt,
ge=ge,
lt=lt,
le=le,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
min_length=min_length,
max_length=max_length,
)
field_args = {
"default": default,
"default_factory": default_factory,
"title": title,
"description": description,
"pattern": pattern,
"strict": strict,
"gt": gt,
"ge": ge,
"lt": lt,
"le": le,
"multiple_of": multiple_of,
"allow_inf_nan": allow_inf_nan,
"max_digits": max_digits,
"decimal_places": decimal_places,
"min_length": min_length,
"max_length": max_length,
}
"""
Invocation definitions have their fields typed correctly for their `invoke()` functions.
@ -299,24 +299,24 @@ def InputField(
# because we are manually making fields optional, we need to store the original required bool for reference later
if default is PydanticUndefined and default_factory is PydanticUndefined:
json_schema_extra_.update(dict(orig_required=True))
json_schema_extra_.update({"orig_required": True})
else:
json_schema_extra_.update(dict(orig_required=False))
json_schema_extra_.update({"orig_required": False})
# make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
default_ = None if default is PydanticUndefined else default
provided_args.update(dict(default=default_))
provided_args.update({"default": default_})
if default is not PydanticUndefined:
# before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
json_schema_extra_.update(dict(default=default))
json_schema_extra_.update(dict(orig_default=default))
json_schema_extra_.update({"default": default})
json_schema_extra_.update({"orig_default": default})
elif default is not PydanticUndefined and default_factory is PydanticUndefined:
default_ = default
provided_args.update(dict(default=default_))
json_schema_extra_.update(dict(orig_default=default_))
provided_args.update({"default": default_})
json_schema_extra_.update({"orig_default": default_})
elif default_factory is not PydanticUndefined:
provided_args.update(dict(default_factory=default_factory))
provided_args.update({"default_factory": default_factory})
# TODO: cannot serialize default_factory...
# json_schema_extra_.update(dict(orig_default_factory=default_factory))
@ -383,12 +383,12 @@ def OutputField(
decimal_places=decimal_places,
min_length=min_length,
max_length=max_length,
json_schema_extra=dict(
ui_type=ui_type,
ui_hidden=ui_hidden,
ui_order=ui_order,
_field_kind="output",
),
json_schema_extra={
"ui_type": ui_type,
"ui_hidden": ui_hidden,
"ui_order": ui_order,
"_field_kind": "output",
},
)
@ -460,14 +460,14 @@ class BaseInvocationOutput(BaseModel):
@classmethod
def get_output_types(cls) -> Iterable[str]:
return map(lambda i: get_type(i), BaseInvocationOutput.get_outputs())
return (get_type(i) for i in BaseInvocationOutput.get_outputs())
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
# Because we use a pydantic Literal field with default value for the invocation type,
# it will be typed as optional in the OpenAPI schema. Make it required manually.
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = list()
schema["required"] = []
schema["required"].extend(["type"])
model_config = ConfigDict(
@ -527,16 +527,11 @@ class BaseInvocation(ABC, BaseModel):
@classmethod
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
# Get the type strings out of the literals and into a dictionary
return dict(
map(
lambda i: (get_type(i), i),
BaseInvocation.get_invocations(),
)
)
return {get_type(i): i for i in BaseInvocation.get_invocations()}
@classmethod
def get_invocation_types(cls) -> Iterable[str]:
return map(lambda i: get_type(i), BaseInvocation.get_invocations())
return (get_type(i) for i in BaseInvocation.get_invocations())
@classmethod
def get_output_type(cls) -> BaseInvocationOutput:
@ -555,7 +550,7 @@ class BaseInvocation(ABC, BaseModel):
if uiconfig and hasattr(uiconfig, "version"):
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = list()
schema["required"] = []
schema["required"].extend(["type", "id"])
@abstractmethod
@ -609,15 +604,15 @@ class BaseInvocation(ABC, BaseModel):
id: str = Field(
default_factory=uuid_string,
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
json_schema_extra=dict(_field_kind="internal"),
json_schema_extra={"_field_kind": "internal"},
)
is_intermediate: bool = Field(
default=False,
description="Whether or not this is an intermediate invocation.",
json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"),
json_schema_extra={"ui_type": UIType.IsIntermediate, "_field_kind": "internal"},
)
use_cache: bool = Field(
default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal")
default=True, description="Whether or not to use the cache", json_schema_extra={"_field_kind": "internal"}
)
UIConfig: ClassVar[Type[UIConfigBase]]
@ -651,7 +646,7 @@ class _Model(BaseModel):
# Get all pydantic model attrs, methods, etc
RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model())))
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
@ -666,9 +661,7 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
field_kind = (
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
field.json_schema_extra.get("_field_kind", None)
if field.json_schema_extra
else None
field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None
)
# must have a field_kind
@ -729,7 +722,7 @@ def invocation(
# Add OpenAPI schema extras
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
cls.UIConfig = type(uiconf_name, (UIConfigBase,), {})
if title is not None:
cls.UIConfig.title = title
if tags is not None:
@ -756,7 +749,7 @@ def invocation(
invocation_type_annotation = Literal[invocation_type] # type: ignore
invocation_type_field = Field(
title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal")
title="type", default=invocation_type, json_schema_extra={"_field_kind": "internal"}
)
docstring = cls.__doc__
@ -802,7 +795,7 @@ def invocation_output(
# Add the output type to the model.
output_type_annotation = Literal[output_type] # type: ignore
output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal"))
output_type_field = Field(title="type", default=output_type, json_schema_extra={"_field_kind": "internal"})
docstring = cls.__doc__
cls = create_model(
@ -834,7 +827,7 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField)
class WithWorkflow(BaseModel):
workflow: Optional[WorkflowField] = Field(
default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal")
default=None, description=FieldDescriptions.workflow, json_schema_extra={"_field_kind": "internal"}
)
@ -852,5 +845,5 @@ MetadataFieldValidator = TypeAdapter(MetadataField)
class WithMetadata(BaseModel):
metadata: Optional[MetadataField] = Field(
default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal")
default=None, description=FieldDescriptions.metadata, json_schema_extra={"_field_kind": "internal"}
)

View File

@ -131,7 +131,7 @@ def prepare_faces_list(
deduped_faces: list[FaceResultData] = []
if len(face_result_list) == 0:
return list()
return []
for candidate in face_result_list:
should_add = True
@ -210,7 +210,7 @@ def generate_face_box_mask(
# Check if any face is detected.
if results.multi_face_landmarks: # type: ignore # this are via protobuf and not typed
# Search for the face_id in the detected faces.
for face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
for _face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
# Get the bounding box of the face mesh.
x_coordinates = [landmark.x for landmark in face_landmarks.landmark]
y_coordinates = [landmark.y for landmark in face_landmarks.landmark]

View File

@ -77,7 +77,7 @@ if choose_torch_device() == torch.device("mps"):
DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
@invocation_output("scheduler_output")
@ -1105,7 +1105,7 @@ class BlendLatentsInvocation(BaseInvocation):
latents_b = context.services.latents.get(self.latents_b.latents_name)
if latents_a.shape != latents_b.shape:
raise "Latents to blend must be the same size."
raise Exception("Latents to blend must be the same size.")
# TODO:
device = choose_torch_device()

View File

@ -145,17 +145,17 @@ INTEGER_OPERATIONS = Literal[
]
INTEGER_OPERATIONS_LABELS = dict(
ADD="Add A+B",
SUB="Subtract A-B",
MUL="Multiply A*B",
DIV="Divide A/B",
EXP="Exponentiate A^B",
MOD="Modulus A%B",
ABS="Absolute Value of A",
MIN="Minimum(A,B)",
MAX="Maximum(A,B)",
)
INTEGER_OPERATIONS_LABELS = {
"ADD": "Add A+B",
"SUB": "Subtract A-B",
"MUL": "Multiply A*B",
"DIV": "Divide A/B",
"EXP": "Exponentiate A^B",
"MOD": "Modulus A%B",
"ABS": "Absolute Value of A",
"MIN": "Minimum(A,B)",
"MAX": "Maximum(A,B)",
}
@invocation(
@ -231,17 +231,17 @@ FLOAT_OPERATIONS = Literal[
]
FLOAT_OPERATIONS_LABELS = dict(
ADD="Add A+B",
SUB="Subtract A-B",
MUL="Multiply A*B",
DIV="Divide A/B",
EXP="Exponentiate A^B",
ABS="Absolute Value of A",
SQRT="Square Root of A",
MIN="Minimum(A,B)",
MAX="Maximum(A,B)",
)
FLOAT_OPERATIONS_LABELS = {
"ADD": "Add A+B",
"SUB": "Subtract A-B",
"MUL": "Multiply A*B",
"DIV": "Divide A/B",
"EXP": "Exponentiate A^B",
"ABS": "Absolute Value of A",
"SQRT": "Square Root of A",
"MIN": "Minimum(A,B)",
"MAX": "Maximum(A,B)",
}
@invocation(
@ -266,7 +266,7 @@ class FloatMathInvocation(BaseInvocation):
raise ValueError("Cannot divide by zero")
elif info.data["operation"] == "EXP" and info.data["a"] == 0 and v < 0:
raise ValueError("Cannot raise zero to a negative power")
elif info.data["operation"] == "EXP" and type(info.data["a"] ** v) is complex:
elif info.data["operation"] == "EXP" and isinstance(info.data["a"] ** v, complex):
raise ValueError("Root operation resulted in a complex number")
return v

View File

@ -54,7 +54,7 @@ ORT_TO_NP_TYPE = {
"tensor(double)": np.float64,
}
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())]
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
@ -252,7 +252,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
scheduler.set_timesteps(self.steps)
latents = latents * np.float64(scheduler.init_noise_sigma)
extra_step_kwargs = dict()
extra_step_kwargs = {}
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
eta=0.0,

View File

@ -100,7 +100,7 @@ EASING_FUNCTIONS_MAP = {
"BounceInOut": BounceEaseInOut,
}
EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
# actually I think for now could just use CollectionOutput (which is list[Any]
@ -161,7 +161,7 @@ class StepParamEasingInvocation(BaseInvocation):
easing_class = EASING_FUNCTIONS_MAP[self.easing]
if log_diagnostics:
context.services.logger.debug("easing class: " + str(easing_class))
easing_list = list()
easing_list = []
if self.mirror: # "expected" mirroring
# if number of steps is even, squeeze duration down to (number_of_steps)/2
# and create reverse copy of list to append
@ -178,7 +178,7 @@ class StepParamEasingInvocation(BaseInvocation):
end=self.end_value,
duration=base_easing_duration - 1,
)
base_easing_vals = list()
base_easing_vals = []
for step_index in range(base_easing_duration):
easing_val = easing_function.ease(step_index)
base_easing_vals.append(easing_val)

View File

@ -139,7 +139,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
images = [deserialize_image_record(dict(r)) for r in result]
self._cursor.execute(
"""--sql
@ -167,7 +167,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
image_names = list(map(lambda r: r[0], result))
image_names = [r[0] for r in result]
return image_names
except sqlite3.Error as e:
self._conn.rollback()

View File

@ -199,7 +199,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
boards = [deserialize_board_record(dict(r)) for r in result]
# Get the total number of boards
self._cursor.execute(
@ -236,7 +236,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
boards = [deserialize_board_record(dict(r)) for r in result]
return boards

View File

@ -55,7 +55,7 @@ class InvokeAISettings(BaseSettings):
"""
cls = self.__class__
type = get_args(get_type_hints(cls)["type"])[0]
field_dict = dict({type: dict()})
field_dict = {type: {}}
for name, field in self.model_fields.items():
if name in cls._excluded_from_yaml():
continue
@ -64,7 +64,7 @@ class InvokeAISettings(BaseSettings):
)
value = getattr(self, name)
if category not in field_dict[type]:
field_dict[type][category] = dict()
field_dict[type][category] = {}
# keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
conf = OmegaConf.create(field_dict)
@ -89,7 +89,7 @@ class InvokeAISettings(BaseSettings):
# create an upcase version of the environment in
# order to achieve case-insensitive environment
# variables (the way Windows does)
upcase_environ = dict()
upcase_environ = {}
for key, value in os.environ.items():
upcase_environ[key.upper()] = value

View File

@ -188,18 +188,18 @@ DEFAULT_MAX_VRAM = 0.5
class Categories(object):
WebServer = dict(category="Web Server")
Features = dict(category="Features")
Paths = dict(category="Paths")
Logging = dict(category="Logging")
Development = dict(category="Development")
Other = dict(category="Other")
ModelCache = dict(category="Model Cache")
Device = dict(category="Device")
Generation = dict(category="Generation")
Queue = dict(category="Queue")
Nodes = dict(category="Nodes")
MemoryPerformance = dict(category="Memory/Performance")
WebServer = {"category": "Web Server"}
Features = {"category": "Features"}
Paths = {"category": "Paths"}
Logging = {"category": "Logging"}
Development = {"category": "Development"}
Other = {"category": "Other"}
ModelCache = {"category": "Model Cache"}
Device = {"category": "Device"}
Generation = {"category": "Generation"}
Queue = {"category": "Queue"}
Nodes = {"category": "Nodes"}
MemoryPerformance = {"category": "Memory/Performance"}
class InvokeAIAppConfig(InvokeAISettings):
@ -482,7 +482,7 @@ def _find_root() -> Path:
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ["INVOKEAI_ROOT"])
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]):
root = (venv.parent).resolve()
else:
root = Path("~/invokeai").expanduser().resolve()

View File

@ -27,7 +27,7 @@ class EventServiceBase:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.queue_event,
payload=dict(event=event_name, data=payload),
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
@ -48,18 +48,18 @@ class EventServiceBase:
"""Emitted when there is generation progress"""
self.__emit_queue_event(
event_name="generator_progress",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node_id=node.get("id"),
source_node_id=source_node_id,
progress_image=progress_image.model_dump() if progress_image is not None else None,
step=step,
order=order,
total_steps=total_steps,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node.get("id"),
"source_node_id": source_node_id,
"progress_image": progress_image.model_dump() if progress_image is not None else None,
"step": step,
"order": order,
"total_steps": total_steps,
},
)
def emit_invocation_complete(
@ -75,15 +75,15 @@ class EventServiceBase:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_complete",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
result=result,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"result": result,
},
)
def emit_invocation_error(
@ -100,16 +100,16 @@ class EventServiceBase:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
error_type=error_type,
error=error,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"error_type": error_type,
"error": error,
},
)
def emit_invocation_started(
@ -124,14 +124,14 @@ class EventServiceBase:
"""Emitted when an invocation has started"""
self.__emit_queue_event(
event_name="invocation_started",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
},
)
def emit_graph_execution_complete(
@ -140,12 +140,12 @@ class EventServiceBase:
"""Emitted when a session has completed all invocations"""
self.__emit_queue_event(
event_name="graph_execution_state_complete",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_model_load_started(
@ -162,16 +162,16 @@ class EventServiceBase:
"""Emitted when a model is requested"""
self.__emit_queue_event(
event_name="model_load_started",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_name": model_name,
"base_model": base_model,
"model_type": model_type,
"submodel": submodel,
},
)
def emit_model_load_completed(
@ -189,19 +189,19 @@ class EventServiceBase:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
event_name="model_load_completed",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
hash=model_info.hash,
location=str(model_info.location),
precision=str(model_info.precision),
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_name": model_name,
"base_model": base_model,
"model_type": model_type,
"submodel": submodel,
"hash": model_info.hash,
"location": str(model_info.location),
"precision": str(model_info.precision),
},
)
def emit_session_retrieval_error(
@ -216,14 +216,14 @@ class EventServiceBase:
"""Emitted when session retrieval fails"""
self.__emit_queue_event(
event_name="session_retrieval_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
error_type=error_type,
error=error,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"error_type": error_type,
"error": error,
},
)
def emit_invocation_retrieval_error(
@ -239,15 +239,15 @@ class EventServiceBase:
"""Emitted when invocation retrieval fails"""
self.__emit_queue_event(
event_name="invocation_retrieval_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node_id=node_id,
error_type=error_type,
error=error,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"error_type": error_type,
"error": error,
},
)
def emit_session_canceled(
@ -260,12 +260,12 @@ class EventServiceBase:
"""Emitted when a session is canceled"""
self.__emit_queue_event(
event_name="session_canceled",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_queue_item_status_changed(
@ -277,39 +277,39 @@ class EventServiceBase:
"""Emitted when a queue item's status changes"""
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload=dict(
queue_id=queue_status.queue_id,
queue_item=dict(
queue_id=session_queue_item.queue_id,
item_id=session_queue_item.item_id,
status=session_queue_item.status,
batch_id=session_queue_item.batch_id,
session_id=session_queue_item.session_id,
error=session_queue_item.error,
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
),
batch_status=batch_status.model_dump(),
queue_status=queue_status.model_dump(),
),
payload={
"queue_id": queue_status.queue_id,
"queue_item": {
"queue_id": session_queue_item.queue_id,
"item_id": session_queue_item.item_id,
"status": session_queue_item.status,
"batch_id": session_queue_item.batch_id,
"session_id": session_queue_item.session_id,
"error": session_queue_item.error,
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
},
"batch_status": batch_status.model_dump(),
"queue_status": queue_status.model_dump(),
},
)
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
"""Emitted when a batch is enqueued"""
self.__emit_queue_event(
event_name="batch_enqueued",
payload=dict(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
),
payload={
"queue_id": enqueue_result.queue_id,
"batch_id": enqueue_result.batch.batch_id,
"enqueued": enqueue_result.enqueued,
},
)
def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when the queue is cleared"""
self.__emit_queue_event(
event_name="queue_cleared",
payload=dict(queue_id=queue_id),
payload={"queue_id": queue_id},
)

View File

@ -25,7 +25,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
__invoker: Invoker
def __init__(self, output_folder: Union[str, Path]):
self.__cache = dict()
self.__cache = {}
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config

View File

@ -90,25 +90,23 @@ class ImageRecordDeleteException(Exception):
IMAGE_DTO_COLS = ", ".join(
list(
map(
lambda c: "images." + c,
[
"image_name",
"image_origin",
"image_category",
"width",
"height",
"session_id",
"node_id",
"is_intermediate",
"created_at",
"updated_at",
"deleted_at",
"starred",
],
)
)
[
"images." + c
for c in [
"image_name",
"image_origin",
"image_category",
"width",
"height",
"session_id",
"node_id",
"is_intermediate",
"created_at",
"updated_at",
"deleted_at",
"starred",
]
]
)

View File

@ -263,7 +263,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = list(map(lambda c: c.value, set(categories)))
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
@ -307,7 +307,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
@ -386,7 +386,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
image_names = list(map(lambda r: r[0], result))
image_names = [r[0] for r in result]
self._cursor.execute(
"""--sql
DELETE FROM images

View File

@ -21,8 +21,8 @@ class ImageServiceABC(ABC):
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
self._on_changed_callbacks = []
self._on_deleted_callbacks = []
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
"""Register a callback for when an image is changed"""

View File

@ -217,18 +217,16 @@ class ImageService(ImageServiceABC):
board_id,
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
image_record=r,
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
),
results.items,
image_dtos = [
image_record_to_dto(
image_record=r,
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
)
)
for r in results.items
]
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,

View File

@ -1,5 +1,5 @@
from abc import ABC
class InvocationProcessorABC(ABC):
class InvocationProcessorABC(ABC): # noqa: B024
pass

View File

@ -26,7 +26,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
self.__invoker_thread = Thread(
name="invoker_processor",
target=self.__process,
kwargs=dict(stop_event=self.__stop_event),
kwargs={"stop_event": self.__stop_event},
)
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
self.__invoker_thread.start()

View File

@ -14,7 +14,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
def __init__(self):
self.__queue = Queue()
self.__cancellations = dict()
self.__cancellations = {}
def get(self) -> InvocationQueueItem:
item = self.__queue.get()

View File

@ -122,7 +122,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
def log_stats(self):
completed = set()
errored = set()
for graph_id, node_log in self._stats.items():
for graph_id, _node_log in self._stats.items():
try:
current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
except Exception:
@ -142,7 +142,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
cache_stats = self._cache_stats[graph_id]
hwm = cache_stats.high_watermark / GIG
tot = cache_stats.cache_size / GIG
loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GIG
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")

View File

@ -15,8 +15,8 @@ class ItemStorageABC(ABC, Generic[T]):
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
self._on_changed_callbacks = []
self._on_deleted_callbacks = []
"""Base item storage class"""

View File

@ -112,7 +112,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
)
result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result))
items = [self._parse_item(r[0]) for r in result]
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0]
@ -132,7 +132,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
)
result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result))
items = [self._parse_item(r[0]) for r in result]
self._cursor.execute(
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",

View File

@ -13,8 +13,8 @@ class LatentsStorageBase(ABC):
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
self._on_changed_callbacks = []
self._on_deleted_callbacks = []
@abstractmethod
def get(self, name: str) -> torch.Tensor:

View File

@ -19,7 +19,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
super().__init__()
self.__underlying_storage = underlying_storage
self.__cache = dict()
self.__cache = {}
self.__cache_ids = Queue()
self.__max_cache_size = max_cache_size

View File

@ -33,9 +33,11 @@ class DefaultSessionProcessor(SessionProcessorBase):
self.__thread = Thread(
name="session_processor",
target=self.__process,
kwargs=dict(
stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event
),
kwargs={
"stop_event": self.__stop_event,
"poll_now_event": self.__poll_now_event,
"resume_event": self.__resume_event,
},
)
self.__thread.start()

View File

@ -129,12 +129,12 @@ class Batch(BaseModel):
return v
model_config = ConfigDict(
json_schema_extra=dict(
required=[
json_schema_extra={
"required": [
"graph",
"runs",
]
)
}
)
@ -191,8 +191,8 @@ class SessionQueueItemWithoutGraph(BaseModel):
return SessionQueueItemDTO(**queue_item_dict)
model_config = ConfigDict(
json_schema_extra=dict(
required=[
json_schema_extra={
"required": [
"item_id",
"status",
"batch_id",
@ -203,7 +203,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
"created_at",
"updated_at",
]
)
}
)
@ -222,8 +222,8 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
return SessionQueueItem(**queue_item_dict)
model_config = ConfigDict(
json_schema_extra=dict(
required=[
json_schema_extra={
"required": [
"item_id",
"status",
"batch_id",
@ -235,7 +235,7 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
"created_at",
"updated_at",
]
)
}
)
@ -355,7 +355,7 @@ def create_session_nfv_tuples(
for item in batch_datum.items
]
node_field_values_to_zip.append(node_field_values)
data.append(list(zip(*node_field_values_to_zip))) # type: ignore [arg-type]
data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type]
# create generator to yield session,nfv tuples
count = 0
@ -383,7 +383,7 @@ def calc_session_count(batch: Batch) -> int:
for batch_datum in batch_datum_list:
batch_data_items = range(len(batch_datum.items))
to_zip.append(batch_data_items)
data.append(list(zip(*to_zip)))
data.append(list(zip(*to_zip, strict=True)))
data_product = list(product(*data))
return len(data_product) * batch.runs

View File

@ -78,7 +78,7 @@ def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[Li
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
graphs: list[LibraryGraph] = list()
graphs: list[LibraryGraph] = []
text_to_image = graph_library.get(default_text_to_image_graph_id)

View File

@ -352,7 +352,7 @@ class Graph(BaseModel):
# Validate that all node ids are unique
node_ids = [n.id for n in self.nodes.values()]
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
duplicate_node_ids = {node_id for node_id in node_ids if node_ids.count(node_id) >= 2}
if duplicate_node_ids:
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
@ -616,7 +616,7 @@ class Graph(BaseModel):
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
edges = list()
edges = []
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
@ -658,7 +658,7 @@ class Graph(BaseModel):
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
edges = list()
edges = []
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
@ -680,8 +680,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_path, "item")]
if new_input is not None:
inputs.append(new_input)
@ -694,7 +694,7 @@ class Graph(BaseModel):
# Get input and output fields (the fields linked to the iterator's input/output)
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Input type must be a list
if get_origin(input_field) != list:
@ -713,8 +713,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
inputs = [e.source for e in self._get_input_edges(node_path, "item")]
outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]
if new_input is not None:
inputs.append(new_input)
@ -722,18 +722,16 @@ class Graph(BaseModel):
outputs.append(new_output)
# Get input and output fields (the fields linked to the iterator's input/output)
input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Validate that all inputs are derived from or match a single type
input_field_types = set(
[
t
for input_field in input_fields
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
if t != NoneType
]
) # Get unique types
input_field_types = {
t
for input_field in input_fields
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
if t != NoneType
} # Get unique types
type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
@ -761,15 +759,15 @@ class Graph(BaseModel):
"""Returns a NetworkX DiGraph representing the layout of this graph"""
# TODO: Cache this?
g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.keys()])
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
g.add_nodes_from(list(self.nodes.keys()))
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
return g
def nx_graph_with_data(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.items()])
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
g.add_nodes_from(list(self.nodes.items()))
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
return g
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
@ -791,7 +789,7 @@ class Graph(BaseModel):
# TODO: figure out if iteration nodes need to be expanded
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
return g
@ -843,8 +841,8 @@ class GraphExecutionState(BaseModel):
return v
model_config = ConfigDict(
json_schema_extra=dict(
required=[
json_schema_extra={
"required": [
"id",
"graph",
"execution_graph",
@ -855,7 +853,7 @@ class GraphExecutionState(BaseModel):
"prepared_source_mapping",
"source_prepared_mapping",
]
)
}
)
def next(self) -> Optional[BaseInvocation]:
@ -895,7 +893,7 @@ class GraphExecutionState(BaseModel):
source_node = self.prepared_source_mapping[node_id]
prepared_nodes = self.source_prepared_mapping[source_node]
if all([n in self.executed for n in prepared_nodes]):
if all(n in self.executed for n in prepared_nodes):
self.executed.add(source_node)
self.executed_history.append(source_node)
@ -930,7 +928,7 @@ class GraphExecutionState(BaseModel):
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
self_iteration_count = len(input_collection)
new_nodes: list[str] = list()
new_nodes: list[str] = []
if self_iteration_count == 0:
# TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
return new_nodes
@ -940,7 +938,7 @@ class GraphExecutionState(BaseModel):
# Create new edges for this iteration
# For collect nodes, this may contain multiple inputs to the same field
new_edges: list[Edge] = list()
new_edges: list[Edge] = []
for edge in input_edges:
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
new_edge = Edge(
@ -1034,7 +1032,7 @@ class GraphExecutionState(BaseModel):
# Create execution nodes
next_node = self.graph.get_node(next_node_id)
new_node_ids = list()
new_node_ids = []
if isinstance(next_node, CollectInvocation):
# Collapse all iterator input mappings and create a single execution node for the collect invocation
all_iteration_mappings = list(
@ -1055,7 +1053,10 @@ class GraphExecutionState(BaseModel):
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
# TODO: Handle a node mapping to none
eg = self.execution_graph.nx_graph_flat()
prepared_parent_mappings = [[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] # type: ignore
prepared_parent_mappings = [
[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
for it in iterator_node_prepared_combinations
] # type: ignore
# Create execution node for each iteration
for iteration_mappings in prepared_parent_mappings:
@ -1121,7 +1122,7 @@ class GraphExecutionState(BaseModel):
for edge in input_edges
if edge.destination.field == "item"
]
setattr(node, "collection", output_collection)
node.collection = output_collection
else:
for edge in input_edges:
output_value = getattr(self.results[edge.source.node_id], edge.source.field)
@ -1201,7 +1202,7 @@ class LibraryGraph(BaseModel):
@field_validator("exposed_inputs", "exposed_outputs")
def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
if len(v) != len(set(i.alias for i in v)):
if len(v) != len({i.alias for i in v}):
raise ValueError("Duplicate exposed alias")
return v

View File

@ -59,7 +59,7 @@ def thin_one_time(x, kernels):
def lvmin_thin(x, prunings=True):
y = x
for i in range(32):
for _i in range(32):
y, is_done = thin_one_time(y, lvmin_kernels)
if is_done:
break

View File

@ -21,11 +21,11 @@ def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
# sanity check make sure the graph is at least reasonably shaped
if (
type(graph) is not dict
not isinstance(graph, dict)
or "nodes" not in graph
or type(graph["nodes"]) is not dict
or not isinstance(graph["nodes"], dict)
or "edges" not in graph
or type(graph["edges"]) is not list
or not isinstance(graph["edges"], list)
):
# something has gone terribly awry, return an empty dict
return None

View File

@ -88,7 +88,7 @@ class PromptFormatter:
t2i = self.t2i
opt = self.opt
switches = list()
switches = []
switches.append(f'"{opt.prompt}"')
switches.append(f"-s{opt.steps or t2i.steps}")
switches.append(f"-W{opt.width or t2i.width}")

View File

@ -88,7 +88,7 @@ class Txt2Mask(object):
provided image and returns a SegmentedGrayscale object in which the brighter
pixels indicate where the object is inferred to be.
"""
if type(image) is str:
if isinstance(image, str):
image = Image.open(image).convert("RGB")
image = ImageOps.exif_transpose(image)

View File

@ -40,7 +40,7 @@ class InitImageResizer:
(rw, rh) = (int(scale * im.width), int(scale * im.height))
# round everything to multiples of 64
width, height, rw, rh = map(lambda x: x - x % 64, (width, height, rw, rh))
width, height, rw, rh = (x - x % 64 for x in (width, height, rw, rh))
# no resize necessary, but return a copy
if im.width == width and im.height == height:

View File

@ -197,7 +197,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
def download_conversion_models():
target_dir = config.models_path / "core/convert"
kwargs = dict() # for future use
kwargs = {} # for future use
try:
logger.info("Downloading core tokenizers and text encoders")
@ -252,26 +252,26 @@ def download_conversion_models():
def download_realesrgan():
logger.info("Installing ESRGAN Upscaling models...")
URLs = [
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
description="RealESRGAN_x4plus.pth",
),
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
description="RealESRGAN_x4plus_anime_6B.pth",
),
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
description="ESRGAN_SRx4_DF2KOST_official.pth",
),
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
description="RealESRGAN_x2plus.pth",
),
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
"description": "RealESRGAN_x4plus.pth",
},
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
"description": "RealESRGAN_x4plus_anime_6B.pth",
},
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"dest": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"description": "ESRGAN_SRx4_DF2KOST_official.pth",
},
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
"dest": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
"description": "RealESRGAN_x2plus.pth",
},
]
for model in URLs:
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
@ -680,7 +680,7 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
if program_opts.default_only
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
if program_opts.yes_to_all
else list(),
else [],
)

View File

@ -123,8 +123,6 @@ class MigrateTo3(object):
logger.error(str(e))
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
for f in files:
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
# let them be copied as part of a tree copy operation
@ -143,8 +141,6 @@ class MigrateTo3(object):
logger.error(str(e))
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def migrate_support_models(self):
"""
@ -182,10 +178,10 @@ class MigrateTo3(object):
"""
dest_directory = self.dest_models
kwargs = dict(
cache_dir=self.root_directory / "models/hub",
kwargs = {
"cache_dir": self.root_directory / "models/hub",
# local_files_only = True
)
}
try:
logger.info("Migrating core tokenizers and text encoders")
target_dir = dest_directory / "core" / "convert"
@ -316,11 +312,11 @@ class MigrateTo3(object):
dest_dir = self.dest_models
cache = self.root_directory / "models/hub"
kwargs = dict(
cache_dir=cache,
safety_checker=None,
kwargs = {
"cache_dir": cache,
"safety_checker": None,
# local_files_only = True,
)
}
owner, repo_name = repo_id.split("/")
model_name = model_name or repo_name

View File

@ -120,7 +120,7 @@ class ModelInstall(object):
be treated uniformly. It also sorts the models alphabetically
by their name, to improve the display somewhat.
"""
model_dict = dict()
model_dict = {}
# first populate with the entries in INITIAL_MODELS.yaml
for key, value in self.datasets.items():
@ -134,7 +134,7 @@ class ModelInstall(object):
model_dict[key] = model_info
# supplement with entries in models.yaml
installed_models = [x for x in self.mgr.list_models()]
installed_models = list(self.mgr.list_models())
for md in installed_models:
base = md["base_model"]
@ -176,7 +176,7 @@ class ModelInstall(object):
# logic here a little reversed to maintain backward compatibility
def starter_models(self, all_models: bool = False) -> Set[str]:
models = set()
for key, value in self.datasets.items():
for key, _value in self.datasets.items():
name, base, model_type = ModelManager.parse_key(key)
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
models.add(key)
@ -184,7 +184,7 @@ class ModelInstall(object):
def recommended_models(self) -> Set[str]:
starters = self.starter_models(all_models=True)
return set([x for x in starters if self.datasets[x].get("recommended", False)])
return {x for x in starters if self.datasets[x].get("recommended", False)}
def default_model(self) -> str:
starters = self.starter_models()
@ -234,7 +234,7 @@ class ModelInstall(object):
"""
if not models_installed:
models_installed = dict()
models_installed = {}
model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
@ -252,16 +252,14 @@ class ModelInstall(object):
# folders style or similar
elif path.is_dir() and any(
[
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"pytorch_lora_weights.safetensors",
}
]
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"pytorch_lora_weights.safetensors",
}
):
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
@ -433,17 +431,17 @@ class ModelInstall(object):
rel_path = self.relative_to_root(path, self.config.models_path)
attributes = dict(
path=str(rel_path),
description=str(description),
model_format=info.format,
)
attributes = {
"path": str(rel_path),
"description": str(description),
"model_format": info.format,
}
legacy_conf = None
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
attributes.update(
dict(
variant=info.variant_type,
)
{
"variant": info.variant_type,
}
)
if info.format == "checkpoint":
try:
@ -474,7 +472,7 @@ class ModelInstall(object):
)
if legacy_conf:
attributes.update(dict(config=str(legacy_conf)))
attributes.update({"config": str(legacy_conf)})
return attributes
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
@ -519,7 +517,7 @@ class ModelInstall(object):
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
_, name = repo_id.split("/")
location = staging / name
paths = list()
paths = []
for filename in files:
filePath = Path(filename)
p = hf_download_with_resume(

View File

@ -130,7 +130,9 @@ class IPAttnProcessor2_0(torch.nn.Module):
assert ip_adapter_image_prompt_embeds is not None
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
for ipa_embed, ipa_weights, scale in zip(
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match.

View File

@ -56,7 +56,7 @@ class PerceiverAttention(nn.Module):
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
b, L, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
@ -72,7 +72,7 @@ class PerceiverAttention(nn.Module):
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
out = out.permute(0, 2, 1, 3).reshape(b, L, -1)
return self.to_out(out)

View File

@ -269,7 +269,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
for _i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
@ -1223,7 +1223,7 @@ def download_from_original_stable_diffusion_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
@ -1664,7 +1664,7 @@ def download_controlnet_from_original_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)

View File

@ -104,7 +104,7 @@ class ModelPatcher:
loras: List[Tuple[LoRAModel, float]],
prefix: str,
):
original_weights = dict()
original_weights = {}
try:
with torch.no_grad():
for lora, lora_weight in loras:
@ -242,7 +242,7 @@ class ModelPatcher:
):
skipped_layers = []
try:
for i in range(clip_skip):
for _i in range(clip_skip):
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
yield
@ -324,7 +324,7 @@ class TextualInversionManager(BaseTextualInversionManager):
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = dict()
self.pad_tokens = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
@ -385,10 +385,10 @@ class ONNXModelPatcher:
if not isinstance(model, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported")
orig_weights = dict()
orig_weights = {}
try:
blended_loras = dict()
blended_loras = {}
for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():
@ -404,7 +404,7 @@ class ONNXModelPatcher:
else:
blended_loras[layer_key] = layer_weight
node_names = dict()
node_names = {}
for node in model.nodes.values():
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name

View File

@ -66,11 +66,13 @@ class CacheStats(object):
class ModelLocker(object):
"Forward declaration"
pass
class ModelCache(object):
"Forward declaration"
pass
@ -132,7 +134,7 @@ class ModelCache(object):
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
"""
self.model_infos: Dict[str, ModelBase] = dict()
self.model_infos: Dict[str, ModelBase] = {}
# allow lazy offloading only when vram cache enabled
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self.precision: torch.dtype = precision
@ -147,8 +149,8 @@ class ModelCache(object):
# used for stats collection
self.stats = None
self._cached_models = dict()
self._cache_stack = list()
self._cached_models = {}
self._cache_stack = []
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:

View File

@ -26,5 +26,5 @@ def skip_torch_weight_init():
yield None
finally:
for torch_module, saved_function in zip(torch_modules, saved_functions):
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
torch_module.reset_parameters = saved_function

View File

@ -363,7 +363,7 @@ class ModelManager(object):
else:
return
self.models = dict()
self.models = {}
for model_key, model_config in config.items():
if model_key.startswith("_"):
continue
@ -374,7 +374,7 @@ class ModelManager(object):
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
self.cache_keys = dict()
self.cache_keys = {}
# add controlnet, lora and textual_inversion models from disk
self.scan_models_directory()
@ -655,7 +655,7 @@ class ModelManager(object):
"""
# TODO: redo
for model_dict in self.list_models():
for model_name, model_info in model_dict.items():
for _model_name, model_info in model_dict.items():
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
print(line)
@ -902,7 +902,7 @@ class ModelManager(object):
"""
Write current configuration out to the indicated file.
"""
data_to_save = dict()
data_to_save = {}
data_to_save["__metadata__"] = self.config_meta.model_dump()
for model_key, model_config in self.models.items():
@ -1034,7 +1034,7 @@ class ModelManager(object):
self.ignore = ignore
def on_search_started(self):
self.new_models_found = dict()
self.new_models_found = {}
def on_model_found(self, model: Path):
if model not in self.ignore:
@ -1106,7 +1106,7 @@ class ModelManager(object):
# avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall
successfully_installed = dict()
successfully_installed = {}
installer = ModelInstall(
config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self

View File

@ -92,7 +92,7 @@ class ModelMerger(object):
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_paths = list()
model_paths = []
config = self.manager.app_config
base_model = BaseModelType(base_model)
vae = None
@ -124,13 +124,13 @@ class ModelMerger(object):
dump_path = (dump_path / merged_model_name).as_posix()
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
attributes = dict(
path=dump_path,
description=f"Merge of models {', '.join(model_names)}",
model_format="diffusers",
variant=ModelVariantType.Normal.value,
vae=vae,
)
attributes = {
"path": dump_path,
"description": f"Merge of models {', '.join(model_names)}",
"model_format": "diffusers",
"variant": ModelVariantType.Normal.value,
"vae": vae,
}
return self.manager.add_model(
merged_model_name,
base_model=base_model,

View File

@ -237,7 +237,7 @@ class ModelProbe(object):
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise "The model {model_name} is potentially infected by malware. Aborting import."
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
# ##################################################3

View File

@ -59,7 +59,7 @@ class ModelSearch(ABC):
for root, dirs, files in os.walk(path, followlinks=True):
if str(Path(root).name).startswith("."):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
continue
self._items_scanned += len(dirs) + len(files)
@ -69,16 +69,14 @@ class ModelSearch(ABC):
self._scanned_dirs.add(path)
continue
if any(
[
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
}
]
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
}
):
try:
self.on_model_found(path)

View File

@ -97,8 +97,8 @@ MODEL_CLASSES = {
# },
}
MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list()
MODEL_CONFIGS = []
OPENAPI_MODEL_CONFIGS = []
class OpenAPIModelInfoBase(BaseModel):
@ -109,7 +109,7 @@ class OpenAPIModelInfoBase(BaseModel):
model_config = ConfigDict(protected_namespaces=())
for base_model, models in MODEL_CLASSES.items():
for _base_model, models in MODEL_CLASSES.items():
for model_type, model_class in models.items():
model_configs = set(model_class._get_configs().values())
model_configs.discard(None)
@ -133,7 +133,7 @@ for base_model, models in MODEL_CLASSES.items():
def get_model_config_enums():
enums = list()
enums = []
for model_config in MODEL_CONFIGS:
if hasattr(inspect, "get_annotations"):

View File

@ -153,7 +153,7 @@ class ModelBase(metaclass=ABCMeta):
else:
res_type = sys.modules["diffusers"]
res_type = getattr(res_type, "pipelines")
res_type = res_type.pipelines
for subtype in subtypes:
res_type = getattr(res_type, subtype)
@ -164,7 +164,7 @@ class ModelBase(metaclass=ABCMeta):
with suppress(Exception):
return cls.__configs
configs = dict()
configs = {}
for name in dir(cls):
if name.startswith("__"):
continue
@ -246,8 +246,8 @@ class DiffusersModel(ModelBase):
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
self.child_types: Dict[str, Type] = {}
self.child_sizes: Dict[str, int] = {}
try:
config_data = DiffusionPipeline.load_config(self.model_path)
@ -326,8 +326,8 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
other_files = set(all_files) - fp16_files - bit8_files
if variant is None:
@ -413,7 +413,7 @@ def _calc_onnx_model_by_data(model) -> int:
def _fast_safetensors_reader(path: str):
checkpoint = dict()
checkpoint = {}
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), "little")
@ -483,7 +483,7 @@ class IAIOnnxRuntimeModel:
class _tensor_access:
def __init__(self, model):
self.model = model
self.indexes = dict()
self.indexes = {}
for idx, obj in enumerate(self.model.proto.graph.initializer):
self.indexes[obj.name] = idx
@ -524,7 +524,7 @@ class IAIOnnxRuntimeModel:
class _access_helper:
def __init__(self, raw_proto):
self.indexes = dict()
self.indexes = {}
self.raw_proto = raw_proto
for idx, obj in enumerate(raw_proto):
self.indexes[obj.name] = idx
@ -549,7 +549,7 @@ class IAIOnnxRuntimeModel:
return self.indexes.keys()
def values(self):
return [obj for obj in self.raw_proto]
return list(self.raw_proto)
def __init__(self, model_path: str, provider: Optional[str]):
self.path = model_path

View File

@ -104,7 +104,7 @@ class ControlNetModel(ModelBase):
return ControlNetModelFormat.Diffusers
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]):
return ControlNetModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {path}")

View File

@ -73,7 +73,7 @@ class LoRAModel(ModelBase):
return LoRAModelFormat.Diffusers
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return LoRAModelFormat.LyCORIS
raise InvalidModelException(f"Not a valid model: {path}")
@ -462,7 +462,7 @@ class LoRAModelRaw: # (torch.nn.Module):
dtype: Optional[torch.dtype] = None,
):
# TODO: try revert if exception?
for key, layer in self.layers.items():
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
@ -499,7 +499,7 @@ class LoRAModelRaw: # (torch.nn.Module):
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict = dict()
new_state_dict = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
@ -545,7 +545,7 @@ class LoRAModelRaw: # (torch.nn.Module):
model = cls(
name=file_path.stem, # TODO:
layers=dict(),
layers={},
)
if file_path.suffix == ".safetensors":
@ -593,12 +593,12 @@ class LoRAModelRaw: # (torch.nn.Module):
@staticmethod
def _group_state(state_dict: dict):
state_dict_groupped = dict()
state_dict_groupped = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped

View File

@ -110,7 +110,7 @@ class StableDiffusion1Model(DiffusersModel):
return StableDiffusion1ModelFormat.Diffusers
if os.path.isfile(model_path):
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return StableDiffusion1ModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {model_path}")
@ -221,7 +221,7 @@ class StableDiffusion2Model(DiffusersModel):
return StableDiffusion2ModelFormat.Diffusers
if os.path.isfile(model_path):
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return StableDiffusion2ModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {model_path}")

View File

@ -71,7 +71,7 @@ class TextualInversionModel(ModelBase):
return None # diffusers-ti
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]):
return None
raise InvalidModelException(f"Not a valid model: {path}")

View File

@ -89,7 +89,7 @@ class VaeModel(ModelBase):
return VaeModelFormat.Diffusers
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return VaeModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {path}")

View File

@ -114,8 +114,10 @@ class ModelConfigBase(BaseModel):
current_hash: Optional[str] = Field(
description="current fasthash of model contents", default=None
) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(None)
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
description: Optional[str] = Field(default=None)
source: Optional[str] = Field(
description="Model download source (URL or repo_id)", default=None
)
model_config = ConfigDict(
use_enum_values=False,
@ -249,12 +251,19 @@ class T2IConfig(ModelConfigBase):
format: Literal[ModelFormat.Diffusers]
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator="format")
_ONNXConfig = Annotated[
Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")
]
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"),
]
_VaeConfig = Annotated[
Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")
]
_MainModelConfig = Annotated[
Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Union[
_MainModelConfig,

View File

@ -49,7 +49,7 @@ class FastModelHash(object):
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
components: Dict[str, str] = {}
for root, dirs, files in os.walk(model_location):
for root, _dirs, files in os.walk(model_location):
for file in files:
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
@ -61,6 +61,6 @@ class FastModelHash(object):
# hash all the model hashes together, using alphabetic file order
md5 = hashlib.md5()
for path, fast_hash in sorted(components.items()):
for _path, fast_hash in sorted(components.items()):
md5.update(fast_hash.encode("utf-8"))
return md5.hexdigest()

View File

@ -7,9 +7,16 @@ from omegaconf import DictConfig, OmegaConf
from pydantic import TypeAdapter
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceSQL
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordServiceSQL,
)
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.util.logging import InvokeAILogger

View File

@ -193,6 +193,7 @@ class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
after generation completes. Optional.
"""
attention_map_saver: Optional[AttentionMapSaver]

View File

@ -54,13 +54,13 @@ class Context:
self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model):
for name, module in get_cross_attention_modules(model, CrossAttentionType.SELF):
for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
if name in self.self_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
raise AssertionError(f"name {name} cannot appear more than once")
self.self_cross_attention_module_identifiers.append(name)
for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
if name in self.tokens_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
raise AssertionError(f"name {name} cannot appear more than once")
self.tokens_cross_attention_module_identifiers.append(name)
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
@ -170,7 +170,7 @@ class Context:
self.saved_cross_attention_maps = {}
def offload_saved_attention_slices_to_cpu(self):
for key, map_dict in self.saved_cross_attention_maps.items():
for _key, map_dict in self.saved_cross_attention_maps.items():
for offset, slice in map_dict["slices"].items():
map_dict[offset] = slice.to("cpu")
@ -433,7 +433,7 @@ def inject_attention_function(unet, context: Context):
module.identifier = identifier
try:
module.set_attention_slice_wrangler(attention_slice_wrangler)
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
@ -445,7 +445,7 @@ def remove_attention_function(unet):
cross_attention_modules = get_cross_attention_modules(
unet, CrossAttentionType.TOKENS
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for identifier, module in cross_attention_modules:
for _identifier, module in cross_attention_modules:
try:
# clear wrangler callback
module.set_attention_slice_wrangler(None)

View File

@ -56,7 +56,7 @@ class AttentionMapSaver:
merged = None
for key, maps in self.collated_maps.items():
for _key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))

View File

@ -123,7 +123,7 @@ class InvokeAIDiffuserComponent:
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data):
for _i, control_datum in enumerate(control_data):
control_mode = control_datum.control_mode
# soft_injection and cfg_injection are the two ControlNet control_mode booleans
# that are combined at higher level to make control_mode enum
@ -214,7 +214,7 @@ class InvokeAIDiffuserComponent:
# add controlnet outputs together if have multiple controlnets
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples, strict=True)
]
mid_block_res_sample += mid_sample
@ -642,7 +642,9 @@ class InvokeAIDiffuserComponent:
deltas = None
uncond_latents = None
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
weighted_cond_list = (
c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
)
# below is fugly omg
conditionings = [uc] + [c for c, weight in weighted_cond_list]

View File

@ -16,28 +16,28 @@ from diffusers import (
UniPCMultistepScheduler,
)
SCHEDULER_MAP = dict(
ddim=(DDIMScheduler, dict()),
ddpm=(DDPMScheduler, dict()),
deis=(DEISMultistepScheduler, dict()),
lms=(LMSDiscreteScheduler, dict(use_karras_sigmas=False)),
lms_k=(LMSDiscreteScheduler, dict(use_karras_sigmas=True)),
pndm=(PNDMScheduler, dict()),
heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
euler_a=(EulerAncestralDiscreteScheduler, dict()),
kdpm_2=(KDPM2DiscreteScheduler, dict()),
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
dpmpp_2s=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=False)),
dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)),
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type="sde-dpmsolver++")),
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")),
dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)),
dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)),
unipc=(UniPCMultistepScheduler, dict(cpu_only=True)),
lcm=(LCMScheduler, dict()),
)
SCHEDULER_MAP = {
"ddim": (DDIMScheduler, {}),
"ddpm": (DDPMScheduler, {}),
"deis": (DEISMultistepScheduler, {}),
"lms": (LMSDiscreteScheduler, {"use_karras_sigmas": False}),
"lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
"pndm": (PNDMScheduler, {}),
"heun": (HeunDiscreteScheduler, {"use_karras_sigmas": False}),
"heun_k": (HeunDiscreteScheduler, {"use_karras_sigmas": True}),
"euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
"euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
"euler_a": (EulerAncestralDiscreteScheduler, {}),
"kdpm_2": (KDPM2DiscreteScheduler, {}),
"kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
"dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
"dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
"dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False}),
"dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
"dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
"dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
"dpmpp_sde": (DPMSolverSDEScheduler, {"use_karras_sigmas": False, "noise_sampler_seed": 0}),
"dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}),
"unipc": (UniPCMultistepScheduler, {"cpu_only": True}),
"lcm": (LCMScheduler, {}),
}

View File

@ -615,7 +615,7 @@ def do_textual_inversion_training(
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
pipeline_args = dict(local_files_only=True)
pipeline_args = {"local_files_only": True}
if tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
else:

View File

@ -732,7 +732,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
for down_block_res_sample, controlnet_block in zip(
down_block_res_samples, self.controlnet_down_blocks, strict=True
):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
@ -745,7 +747,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
down_block_res_samples = [
sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=False)
]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]

View File

@ -225,34 +225,34 @@ def basicConfig(**kwargs):
_FACILITY_MAP = (
dict(
LOG_KERN=syslog.LOG_KERN,
LOG_USER=syslog.LOG_USER,
LOG_MAIL=syslog.LOG_MAIL,
LOG_DAEMON=syslog.LOG_DAEMON,
LOG_AUTH=syslog.LOG_AUTH,
LOG_LPR=syslog.LOG_LPR,
LOG_NEWS=syslog.LOG_NEWS,
LOG_UUCP=syslog.LOG_UUCP,
LOG_CRON=syslog.LOG_CRON,
LOG_SYSLOG=syslog.LOG_SYSLOG,
LOG_LOCAL0=syslog.LOG_LOCAL0,
LOG_LOCAL1=syslog.LOG_LOCAL1,
LOG_LOCAL2=syslog.LOG_LOCAL2,
LOG_LOCAL3=syslog.LOG_LOCAL3,
LOG_LOCAL4=syslog.LOG_LOCAL4,
LOG_LOCAL5=syslog.LOG_LOCAL5,
LOG_LOCAL6=syslog.LOG_LOCAL6,
LOG_LOCAL7=syslog.LOG_LOCAL7,
)
{
"LOG_KERN": syslog.LOG_KERN,
"LOG_USER": syslog.LOG_USER,
"LOG_MAIL": syslog.LOG_MAIL,
"LOG_DAEMON": syslog.LOG_DAEMON,
"LOG_AUTH": syslog.LOG_AUTH,
"LOG_LPR": syslog.LOG_LPR,
"LOG_NEWS": syslog.LOG_NEWS,
"LOG_UUCP": syslog.LOG_UUCP,
"LOG_CRON": syslog.LOG_CRON,
"LOG_SYSLOG": syslog.LOG_SYSLOG,
"LOG_LOCAL0": syslog.LOG_LOCAL0,
"LOG_LOCAL1": syslog.LOG_LOCAL1,
"LOG_LOCAL2": syslog.LOG_LOCAL2,
"LOG_LOCAL3": syslog.LOG_LOCAL3,
"LOG_LOCAL4": syslog.LOG_LOCAL4,
"LOG_LOCAL5": syslog.LOG_LOCAL5,
"LOG_LOCAL6": syslog.LOG_LOCAL6,
"LOG_LOCAL7": syslog.LOG_LOCAL7,
}
if SYSLOG_AVAILABLE
else dict()
else {}
)
_SOCK_MAP = dict(
SOCK_STREAM=socket.SOCK_STREAM,
SOCK_DGRAM=socket.SOCK_DGRAM,
)
_SOCK_MAP = {
"SOCK_STREAM": socket.SOCK_STREAM,
"SOCK_DGRAM": socket.SOCK_DGRAM,
}
class InvokeAIFormatter(logging.Formatter):
@ -344,7 +344,7 @@ LOG_FORMATTERS = {
class InvokeAILogger(object):
loggers = dict()
loggers = {}
@classmethod
def get_logger(
@ -364,7 +364,7 @@ class InvokeAILogger(object):
@classmethod
def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
handler_strs = config.log_handlers
handlers = list()
handlers = []
for handler in handler_strs:
handler_name, *args = handler.split("=", 2)
args = args[0] if len(args) > 0 else None
@ -398,7 +398,7 @@ class InvokeAILogger(object):
raise ValueError("syslog is not available on this system")
if not args:
args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514"
syslog_args = dict()
syslog_args = {}
try:
for a in args.split(","):
arg_name, *arg_value = a.split(":", 2)
@ -434,7 +434,7 @@ class InvokeAILogger(object):
path = url.path
port = url.port or 80
syslog_args = dict()
syslog_args = {}
for a in arg_list:
arg_name, *arg_value = a.split(":", 2)
if arg_name == "method":

View File

@ -29,7 +29,7 @@ def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
txts = []
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
@ -93,7 +93,7 @@ def instantiate_from_config(config, **kwargs):
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
return get_obj_from_str(config["target"])(**config.get("params", {}), **kwargs)
def get_obj_from_str(string, reload=False):
@ -231,11 +231,12 @@ def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10
angles = 2 * math.pi * rand_val
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)
tile_grads = (
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
.repeat_interleave(d[0], 0)
.repeat_interleave(d[1], 1)
)
def tile_grads(slice1, slice2):
return (
gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
.repeat_interleave(d[0], 0)
.repeat_interleave(d[1], 1)
)
def dot(grad, shift):
return (

View File

@ -341,19 +341,19 @@ class InvokeAIMetadataParser:
# this was more elegant as a case statement, but that's not available in python 3.9
if old_scheduler is None:
return None
scheduler_map = dict(
ddim="ddim",
plms="pnmd",
k_lms="lms",
k_dpm_2="kdpm_2",
k_dpm_2_a="kdpm_2_a",
dpmpp_2="dpmpp_2s",
k_dpmpp_2="dpmpp_2m",
k_dpmpp_2_a=None, # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
k_euler="euler",
k_euler_a="euler_a",
k_heun="heun",
)
scheduler_map = {
"ddim": "ddim",
"plms": "pnmd",
"k_lms": "lms",
"k_dpm_2": "kdpm_2",
"k_dpm_2_a": "kdpm_2_a",
"dpmpp_2": "dpmpp_2s",
"k_dpmpp_2": "dpmpp_2m",
"k_dpmpp_2_a": None, # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
"k_euler": "euler",
"k_euler_a": "euler_a",
"k_heun": "heun",
}
return scheduler_map.get(old_scheduler)
def split_prompt(self, raw_prompt: str):

View File

@ -72,7 +72,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
self.multipage = multipage
self.subprocess = None
super().__init__(parentApp=parentApp, name=name, *args, **keywords)
super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad?
def create(self):
self.keypress_timeout = 10
@ -203,14 +203,14 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
)
# This restores the selected page on return from an installation
for i in range(1, self.current_tab + 1):
for _i in range(1, self.current_tab + 1):
self.tabs.h_cursor_line_down(1)
self._toggle_tables([self.current_tab])
############# diffusers tab ##########
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
"""Add widgets responsible for selecting diffusers models"""
widgets = dict()
widgets = {}
models = self.all_models
starters = self.starter_models
starter_model_labels = self.model_labels
@ -258,10 +258,12 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
model_type: ModelType,
window_width: int = 120,
install_prompt: str = None,
exclude: set = set(),
exclude: set = None,
) -> dict[str, npyscreen.widget]:
"""Generic code to create model selection widgets"""
widgets = dict()
if exclude is None:
exclude = set()
widgets = {}
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
model_labels = [self.model_labels[x] for x in model_list]
@ -366,13 +368,13 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
]
for group in widgets:
for k, v in group.items():
for _k, v in group.items():
try:
v.hidden = True
v.editable = False
except Exception:
pass
for k, v in widgets[selected_tab].items():
for _k, v in widgets[selected_tab].items():
try:
v.hidden = False
if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
@ -391,7 +393,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
label_width = max([len(models[x].name) for x in models])
description_width = window_width - label_width - checkbox_width - spacing_width
result = dict()
result = {}
for x in models.keys():
description = models[x].description
description = (
@ -433,11 +435,11 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
parent_conn, child_conn = Pipe()
p = Process(
target=process_and_execute,
kwargs=dict(
opt=app.program_opts,
selections=app.install_selections,
conn_out=child_conn,
),
kwargs={
"opt": app.program_opts,
"selections": app.install_selections,
"conn_out": child_conn,
},
)
p.start()
child_conn.close()
@ -558,7 +560,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
for section in ui_sections:
if "models_selected" not in section:
continue
selected = set([section["models"][x] for x in section["models_selected"].value])
selected = {section["models"][x] for x in section["models_selected"].value}
models_to_install = [x for x in selected if not self.all_models[x].installed]
models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
selections.remove_models.extend(models_to_remove)

View File

@ -11,6 +11,7 @@ import sys
import textwrap
from curses import BUTTON2_CLICKED, BUTTON3_CLICKED
from shutil import get_terminal_size
from typing import Optional
import npyscreen
import npyscreen.wgmultiline as wgmultiline
@ -243,7 +244,9 @@ class SelectColumnBase:
class MultiSelectColumns(SelectColumnBase, npyscreen.MultiSelect):
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
def __init__(self, screen, columns: int = 1, values: Optional[list] = None, **keywords):
if values is None:
values = []
self.columns = columns
self.value_cnt = len(values)
self.rows = math.ceil(self.value_cnt / self.columns)
@ -267,7 +270,9 @@ class SingleSelectWithChanged(npyscreen.SelectOne):
class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged):
"""Row of radio buttons. Spacebar to select."""
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
def __init__(self, screen, columns: int = 1, values: list = None, **keywords):
if values is None:
values = []
self.columns = columns
self.value_cnt = len(values)
self.rows = math.ceil(self.value_cnt / self.columns)

View File

@ -275,14 +275,14 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
interp = self.interpolations[self.merge_method.value[0]]
bases = ["sd-1", "sd-2", "sdxl"]
args = dict(
model_names=models,
base_model=BaseModelType(bases[self.base_select.value[0]]),
alpha=self.alpha.value,
interp=interp,
force=self.force.value,
merged_model_name=self.merged_model_name.value,
)
args = {
"model_names": models,
"base_model": BaseModelType(bases[self.base_select.value[0]]),
"alpha": self.alpha.value,
"interp": interp,
"force": self.force.value,
"merged_model_name": self.merged_model_name.value,
}
return args
def check_for_overwrite(self) -> bool:
@ -297,7 +297,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
def validate_field_values(self) -> bool:
bad_fields = []
model_names = self.model_names
selected_models = set((model_names[self.model1.value[0]], model_names[self.model2.value[0]]))
selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]}
if self.model3.value[0] > 0:
selected_models.add(model_names[self.model3.value[0] - 1])
if len(selected_models) < 2:

View File

@ -276,13 +276,13 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
def get_model_names(self) -> Tuple[List[str], int]:
conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
model_names = [idx for idx in sorted(list(conf.keys())) if conf[idx].get("format", None) == "diffusers"]
model_names = [idx for idx in sorted(conf.keys()) if conf[idx].get("format", None) == "diffusers"]
defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]]
default = defaults[0] if len(defaults) > 0 else 0
return (model_names, default)
def marshall_arguments(self) -> dict:
args = dict()
args = {}
# the choices
args.update(

View File

@ -24,6 +24,7 @@ module.exports = {
root: true,
rules: {
curly: 'error',
'react/jsx-no-bind': ['error', { allowBind: true }],
'react/jsx-curly-brace-presence': [
'error',
{ props: 'never', children: 'never' },

View File

@ -583,6 +583,7 @@
"strength": "Image to image strength",
"Threshold": "Noise Threshold",
"variations": "Seed-weight pairs",
"vae": "VAE",
"width": "Width",
"workflow": "Workflow"
},

View File

@ -1487,5 +1487,18 @@
"scheduler": "Campionatore",
"recallParameters": "Richiama i parametri",
"noRecallParameters": "Nessun parametro da richiamare trovato"
},
"hrf": {
"enableHrf": "Abilita Correzione Alta Risoluzione",
"upscaleMethod": "Metodo di ampliamento",
"enableHrfTooltip": "Genera con una risoluzione iniziale inferiore, esegue l'ampliamento alla risoluzione di base, quindi esegue Immagine a Immagine.",
"metadata": {
"strength": "Forza della Correzione Alta Risoluzione",
"enabled": "Correzione Alta Risoluzione Abilitata",
"method": "Metodo della Correzione Alta Risoluzione"
},
"hrf": "Correzione Alta Risoluzione",
"hrfStrength": "Forza della Correzione Alta Risoluzione",
"strengthTooltip": "Valori più bassi comportano meno dettagli, il che può ridurre potenziali artefatti."
}
}

View File

@ -19,7 +19,7 @@ import sdxlReducer from 'features/sdxl/store/sdxlSlice';
import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice';
import queueReducer from 'features/queue/store/queueSlice';
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
import modelmanagerReducer from 'features/modelManager/store/modelManagerSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
import dynamicMiddlewares from 'redux-dynamic-middlewares';

View File

@ -8,7 +8,14 @@ import {
forwardRef,
useDisclosure,
} from '@chakra-ui/react';
import { cloneElement, memo, ReactElement, ReactNode, useRef } from 'react';
import {
cloneElement,
memo,
ReactElement,
ReactNode,
useCallback,
useRef,
} from 'react';
import { useTranslation } from 'react-i18next';
import IAIButton from './IAIButton';
@ -38,15 +45,15 @@ const IAIAlertDialog = forwardRef((props: Props, ref) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const cancelRef = useRef<HTMLButtonElement | null>(null);
const handleAccept = () => {
const handleAccept = useCallback(() => {
acceptCallback();
onClose();
};
}, [acceptCallback, onClose]);
const handleCancel = () => {
const handleCancel = useCallback(() => {
cancelCallback && cancelCallback();
onClose();
};
}, [cancelCallback, onClose]);
return (
<>

View File

@ -1,9 +1,12 @@
import { Box, ChakraProps } from '@chakra-ui/react';
import { memo } from 'react';
import { ChakraProps, Flex } from '@chakra-ui/react';
import { memo, useCallback } from 'react';
import { RgbaColorPicker } from 'react-colorful';
import { ColorPickerBaseProps, RgbaColor } from 'react-colorful/dist/types';
import IAINumberInput from './IAINumberInput';
type IAIColorPickerProps = ColorPickerBaseProps<RgbaColor>;
type IAIColorPickerProps = ColorPickerBaseProps<RgbaColor> & {
withNumberInput?: boolean;
};
const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
width: 6,
@ -11,17 +14,84 @@ const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
borderColor: 'base.100',
};
const sx = {
const sx: ChakraProps['sx'] = {
'.react-colorful__hue-pointer': colorPickerStyles,
'.react-colorful__saturation-pointer': colorPickerStyles,
'.react-colorful__alpha-pointer': colorPickerStyles,
gap: 2,
flexDir: 'column',
};
const numberInputWidth: ChakraProps['w'] = '4.2rem';
const IAIColorPicker = (props: IAIColorPickerProps) => {
const { color, onChange, withNumberInput, ...rest } = props;
const handleChangeR = useCallback(
(r: number) => onChange({ ...color, r }),
[color, onChange]
);
const handleChangeG = useCallback(
(g: number) => onChange({ ...color, g }),
[color, onChange]
);
const handleChangeB = useCallback(
(b: number) => onChange({ ...color, b }),
[color, onChange]
);
const handleChangeA = useCallback(
(a: number) => onChange({ ...color, a }),
[color, onChange]
);
return (
<Box sx={sx}>
<RgbaColorPicker {...props} />
</Box>
<Flex sx={sx}>
<RgbaColorPicker
color={color}
onChange={onChange}
style={{ width: '100%' }}
{...rest}
/>
{withNumberInput && (
<Flex>
<IAINumberInput
value={color.r}
onChange={handleChangeR}
min={0}
max={255}
step={1}
label="Red"
w={numberInputWidth}
/>
<IAINumberInput
value={color.g}
onChange={handleChangeG}
min={0}
max={255}
step={1}
label="Green"
w={numberInputWidth}
/>
<IAINumberInput
value={color.b}
onChange={handleChangeB}
min={0}
max={255}
step={1}
label="Blue"
w={numberInputWidth}
/>
<IAINumberInput
value={color.a}
onChange={handleChangeA}
step={0.1}
min={0}
max={1}
label="Alpha"
w={numberInputWidth}
isInteger={false}
/>
</Flex>
)}
</Flex>
);
};

View File

@ -1,43 +0,0 @@
import { Box, Flex, Icon } from '@chakra-ui/react';
import { memo } from 'react';
import { FaExclamation } from 'react-icons/fa';
const IAIErrorLoadingImageFallback = () => {
return (
<Box
sx={{
position: 'relative',
height: 'full',
width: 'full',
'::before': {
content: "''",
display: 'block',
pt: '100%',
},
}}
>
<Flex
sx={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
height: 'full',
width: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
bg: 'base.100',
color: 'base.500',
_dark: {
color: 'base.700',
bg: 'base.850',
},
}}
>
<Icon as={FaExclamation} boxSize={16} opacity={0.7} />
</Flex>
</Box>
);
};
export default memo(IAIErrorLoadingImageFallback);

View File

@ -1,8 +0,0 @@
import { chakra } from '@chakra-ui/react';
/**
* Chakra-enabled <form />
*/
const IAIForm = chakra.form;
export default IAIForm;

View File

@ -1,15 +0,0 @@
import { FormErrorMessage, FormErrorMessageProps } from '@chakra-ui/react';
import { ReactNode } from 'react';
type IAIFormErrorMessageProps = FormErrorMessageProps & {
children: ReactNode | string;
};
export default function IAIFormErrorMessage(props: IAIFormErrorMessageProps) {
const { children, ...rest } = props;
return (
<FormErrorMessage color="error.400" {...rest}>
{children}
</FormErrorMessage>
);
}

View File

@ -1,15 +0,0 @@
import { FormHelperText, FormHelperTextProps } from '@chakra-ui/react';
import { ReactNode } from 'react';
type IAIFormHelperTextProps = FormHelperTextProps & {
children: ReactNode | string;
};
export default function IAIFormHelperText(props: IAIFormHelperTextProps) {
const { children, ...rest } = props;
return (
<FormHelperText margin={0} color="base.400" {...rest}>
{children}
</FormHelperText>
);
}

View File

@ -1,25 +0,0 @@
import { Flex, useColorMode } from '@chakra-ui/react';
import { ReactElement } from 'react';
import { mode } from 'theme/util/mode';
export function IAIFormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
const { colorMode } = useColorMode();
return (
<Flex
sx={{
flexDirection: 'column',
padding: 4,
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: mode('base.100', 'base.900')(colorMode),
}}
>
{children}
</Flex>
);
}

View File

@ -1,25 +0,0 @@
import {
Checkbox,
CheckboxProps,
FormControl,
FormControlProps,
FormLabel,
} from '@chakra-ui/react';
import { memo, ReactNode } from 'react';
type IAIFullCheckboxProps = CheckboxProps & {
label: string | ReactNode;
formControlProps?: FormControlProps;
};
const IAIFullCheckbox = (props: IAIFullCheckboxProps) => {
const { label, formControlProps, ...rest } = props;
return (
<FormControl {...formControlProps}>
<FormLabel>{label}</FormLabel>
<Checkbox colorScheme="accent" {...rest} />
</FormControl>
);
};
export default memo(IAIFullCheckbox);

View File

@ -1,6 +1,7 @@
import { useColorMode } from '@chakra-ui/react';
import { TextInput, TextInputProps } from '@mantine/core';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useCallback } from 'react';
import { mode } from 'theme/util/mode';
type IAIMantineTextInputProps = TextInputProps;
@ -20,26 +21,37 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
return (
<TextInput
styles={() => ({
input: {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base50, base900)(colorMode),
borderColor: mode(base200, base800)(colorMode),
borderWidth: 2,
outline: 'none',
':focus': {
borderColor: mode(accent300, accent500)(colorMode),
},
const stylesFunc = useCallback(
() => ({
input: {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base50, base900)(colorMode),
borderColor: mode(base200, base800)(colorMode),
borderWidth: 2,
outline: 'none',
':focus': {
borderColor: mode(accent300, accent500)(colorMode),
},
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
marginBottom: 4,
},
})}
{...rest}
/>
},
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal' as const,
marginBottom: 4,
},
}),
[
accent300,
accent500,
base100,
base200,
base300,
base50,
base700,
base800,
base900,
colorMode,
]
);
return <TextInput styles={stylesFunc} {...rest} />;
}

View File

@ -98,28 +98,34 @@ const IAINumberInput = forwardRef((props: Props, ref) => {
}
}, [value, valueAsString]);
const handleOnChange = (v: string) => {
setValueAsString(v);
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
if (!v.match(numberStringRegex)) {
// Cast the value to number. Floor it if it should be an integer.
onChange(isInteger ? Math.floor(Number(v)) : Number(v));
}
};
const handleOnChange = useCallback(
(v: string) => {
setValueAsString(v);
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
if (!v.match(numberStringRegex)) {
// Cast the value to number. Floor it if it should be an integer.
onChange(isInteger ? Math.floor(Number(v)) : Number(v));
}
},
[isInteger, onChange]
);
/**
* Clicking the steppers allows the value to go outside bounds; we need to
* clamp it on blur and floor it if needed.
*/
const handleBlur = (e: FocusEvent<HTMLInputElement>) => {
const clamped = clamp(
isInteger ? Math.floor(Number(e.target.value)) : Number(e.target.value),
min,
max
);
setValueAsString(String(clamped));
onChange(clamped);
};
const handleBlur = useCallback(
(e: FocusEvent<HTMLInputElement>) => {
const clamped = clamp(
isInteger ? Math.floor(Number(e.target.value)) : Number(e.target.value),
min,
max
);
setValueAsString(String(clamped));
onChange(clamped);
},
[isInteger, max, min, onChange]
);
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {

View File

@ -6,7 +6,7 @@ import {
Tooltip,
TooltipProps,
} from '@chakra-ui/react';
import { memo, MouseEvent } from 'react';
import { memo, MouseEvent, useCallback } from 'react';
import IAIOption from './IAIOption';
type IAISelectProps = SelectProps & {
@ -33,15 +33,16 @@ const IAISelect = (props: IAISelectProps) => {
spaceEvenly,
...rest
} = props;
const handleClick = useCallback((e: MouseEvent<HTMLDivElement>) => {
e.stopPropagation();
e.nativeEvent.stopImmediatePropagation();
e.nativeEvent.stopPropagation();
e.nativeEvent.cancelBubble = true;
}, []);
return (
<FormControl
isDisabled={isDisabled}
onClick={(e: MouseEvent<HTMLDivElement>) => {
e.stopPropagation();
e.nativeEvent.stopImmediatePropagation();
e.nativeEvent.stopPropagation();
e.nativeEvent.cancelBubble = true;
}}
onClick={handleClick}
sx={
horizontal
? {

View File

@ -186,6 +186,13 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
[dispatch]
);
const handleMouseEnter = useCallback(() => setShowTooltip(true), []);
const handleMouseLeave = useCallback(() => setShowTooltip(false), []);
const handleStepperClick = useCallback(
() => onChange(Number(localInputValue)),
[localInputValue, onChange]
);
return (
<FormControl
ref={ref}
@ -219,8 +226,8 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
max={max}
step={step}
onChange={handleSliderChange}
onMouseEnter={() => setShowTooltip(true)}
onMouseLeave={() => setShowTooltip(false)}
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
focusThumbOnChange={false}
isDisabled={isDisabled}
{...rest}
@ -332,12 +339,8 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
{...sliderNumberInputFieldProps}
/>
<NumberInputStepper {...sliderNumberInputStepperProps}>
<NumberIncrementStepper
onClick={() => onChange(Number(localInputValue))}
/>
<NumberDecrementStepper
onClick={() => onChange(Number(localInputValue))}
/>
<NumberIncrementStepper onClick={handleStepperClick} />
<NumberDecrementStepper onClick={handleStepperClick} />
</NumberInputStepper>
</NumberInput>
)}

View File

@ -146,16 +146,15 @@ const ImageUploader = (props: ImageUploaderProps) => {
};
}, [inputRef]);
const handleKeyDown = useCallback((e: KeyboardEvent) => {
// Bail out if user hits spacebar - do not open the uploader
if (e.key === ' ') {
return;
}
}, []);
return (
<Box
{...getRootProps({ style: {} })}
onKeyDown={(e: KeyboardEvent) => {
// Bail out if user hits spacebar - do not open the uploader
if (e.key === ' ') {
return;
}
}}
>
<Box {...getRootProps({ style: {} })} onKeyDown={handleKeyDown}>
<input {...getInputProps()} />
{children}
<AnimatePresence>

View File

@ -1,23 +0,0 @@
import { Flex, Icon } from '@chakra-ui/react';
import { memo } from 'react';
import { FaImage } from 'react-icons/fa';
const SelectImagePlaceholder = () => {
return (
<Flex
sx={{
w: 'full',
h: 'full',
// bg: 'base.800',
borderRadius: 'base',
alignItems: 'center',
justifyContent: 'center',
aspectRatio: '1/1',
}}
>
<Icon color="base.400" boxSize={32} as={FaImage}></Icon>
</Flex>
);
};
export default memo(SelectImagePlaceholder);

View File

@ -1,24 +0,0 @@
import { useBreakpoint } from '@chakra-ui/react';
export default function useResolution():
| 'mobile'
| 'tablet'
| 'desktop'
| 'unknown' {
const breakpointValue = useBreakpoint();
const mobileResolutions = ['base', 'sm'];
const tabletResolutions = ['md', 'lg'];
const desktopResolutions = ['xl', '2xl'];
if (mobileResolutions.includes(breakpointValue)) {
return 'mobile';
}
if (tabletResolutions.includes(breakpointValue)) {
return 'tablet';
}
if (desktopResolutions.includes(breakpointValue)) {
return 'desktop';
}
return 'unknown';
}

View File

@ -1,7 +0,0 @@
import dateFormat from 'dateformat';
/**
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
*/
export const getTimestamp = () =>
dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);

View File

@ -1,71 +0,0 @@
// TODO: Restore variations
// Support code from v2.3 in here.
// export const stringToSeedWeights = (
// string: string
// ): InvokeAI.SeedWeights | boolean => {
// const stringPairs = string.split(',');
// const arrPairs = stringPairs.map((p) => p.split(':'));
// const pairs = arrPairs.map((p: Array<string>): InvokeAI.SeedWeightPair => {
// return { seed: Number(p[0]), weight: Number(p[1]) };
// });
// if (!validateSeedWeights(pairs)) {
// return false;
// }
// return pairs;
// };
// export const validateSeedWeights = (
// seedWeights: InvokeAI.SeedWeights | string
// ): boolean => {
// return typeof seedWeights === 'string'
// ? Boolean(stringToSeedWeights(seedWeights))
// : Boolean(
// seedWeights.length &&
// !seedWeights.some((pair: InvokeAI.SeedWeightPair) => {
// const { seed, weight } = pair;
// const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
// const isWeightValid =
// !isNaN(parseInt(weight.toString(), 10)) &&
// weight >= 0 &&
// weight <= 1;
// return !(isSeedValid && isWeightValid);
// })
// );
// };
// export const seedWeightsToString = (
// seedWeights: InvokeAI.SeedWeights
// ): string => {
// return seedWeights.reduce((acc, pair, i, arr) => {
// const { seed, weight } = pair;
// acc += `${seed}:${weight}`;
// if (i !== arr.length - 1) {
// acc += ',';
// }
// return acc;
// }, '');
// };
// export const seedWeightsToArray = (
// seedWeights: InvokeAI.SeedWeights
// ): Array<Array<number>> => {
// return seedWeights.map((pair: InvokeAI.SeedWeightPair) => [
// pair.seed,
// pair.weight,
// ]);
// };
// export const stringToSeedWeightsArray = (
// string: string
// ): Array<Array<number>> => {
// const stringPairs = string.split(',');
// const arrPairs = stringPairs.map((p) => p.split(':'));
// return arrPairs.map(
// (p: Array<string>): Array<number> => [parseInt(p[0], 10), parseFloat(p[1])]
// );
// };
export default {};

Some files were not shown because too many files have changed in this diff Show More