chore: ruff check - fix flake8-comprehensions

This commit is contained in:
psychedelicious 2023-11-11 10:44:43 +11:00
parent 43f2398e14
commit 3a136420d5
60 changed files with 489 additions and 512 deletions
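
For reference, the rewrites applied throughout this commit are the standard flake8-comprehensions fixes that ruff automates: dict()/list() calls become literals, and map()/lambda pipelines become comprehensions. The snippet below is an illustrative sketch using made-up names (payload, items, doubled, flags), not code taken from the repository:

# Illustration only: the kinds of rewrites ruff's flake8-comprehensions rules apply.

# Unnecessary dict() call (C408): keyword arguments become a dict literal.
payload = dict(event="progress", step=1)        # before
payload = {"event": "progress", "step": 1}      # after

# Unnecessary generator passed to list() (C400): use a list comprehension.
items = list(str(n) for n in range(3))          # before
items = [str(n) for n in range(3)]              # after

# Unnecessary map()/lambda (C417): use a comprehension instead.
doubled = list(map(lambda n: n * 2, range(3)))  # before
doubled = [n * 2 for n in range(3)]             # after

# Unnecessary list comprehension inside any()/all() (C419): drop the brackets.
flags = any([n > 1 for n in range(3)])          # before
flags = any(n > 1 for n in range(3))            # after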

View File

@@ -28,7 +28,7 @@ class FastAPIEventService(EventServiceBase):
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
-self.__queue.put(dict(event_name=event_name, payload=payload))
+self.__queue.put({"event_name": event_name, "payload": payload})
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""

View File

@@ -55,7 +55,7 @@ async def list_models(
) -> ModelsList:
"""Gets a list of models"""
if base_models and len(base_models) > 0:
-models_raw = list()
+models_raw = []
for base_model in base_models:
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
else:

View File

@@ -130,7 +130,7 @@ def custom_openapi() -> dict[str, Any]:
# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
-output_type_titles = dict()
+output_type_titles = {}
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
@@ -171,12 +171,12 @@ def custom_openapi() -> dict[str, Any]:
# print(f"Config with name {name} already defined")
continue
-openapi_schema["components"]["schemas"][name] = dict(
-title=name,
-description="An enumeration.",
-type="string",
-enum=list(v.value for v in model_config_format_enum),
-)
+openapi_schema["components"]["schemas"][name] = {
+"title": name,
+"description": "An enumeration.",
+"type": "string",
+"enum": [v.value for v in model_config_format_enum],
+}
app.openapi_schema = openapi_schema
return app.openapi_schema

View File

@@ -25,4 +25,4 @@ spec.loader.exec_module(module)
# add core nodes to __all__
python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
-__all__ = list(f.stem for f in python_files) # type: ignore
+__all__ = [f.stem for f in python_files] # type: ignore

View File

@@ -236,35 +236,35 @@ def InputField(
Ignored for non-collection fields.
"""
-json_schema_extra_: dict[str, Any] = dict(
-input=input,
-ui_type=ui_type,
-ui_component=ui_component,
-ui_hidden=ui_hidden,
-ui_order=ui_order,
-item_default=item_default,
-ui_choice_labels=ui_choice_labels,
-_field_kind="input",
-)
+json_schema_extra_: dict[str, Any] = {
+"input": input,
+"ui_type": ui_type,
+"ui_component": ui_component,
+"ui_hidden": ui_hidden,
+"ui_order": ui_order,
+"item_default": item_default,
+"ui_choice_labels": ui_choice_labels,
+"_field_kind": "input",
+}
-field_args = dict(
-default=default,
-default_factory=default_factory,
-title=title,
-description=description,
-pattern=pattern,
-strict=strict,
-gt=gt,
-ge=ge,
-lt=lt,
-le=le,
-multiple_of=multiple_of,
-allow_inf_nan=allow_inf_nan,
-max_digits=max_digits,
-decimal_places=decimal_places,
-min_length=min_length,
-max_length=max_length,
-)
+field_args = {
+"default": default,
+"default_factory": default_factory,
+"title": title,
+"description": description,
+"pattern": pattern,
+"strict": strict,
+"gt": gt,
+"ge": ge,
+"lt": lt,
+"le": le,
+"multiple_of": multiple_of,
+"allow_inf_nan": allow_inf_nan,
+"max_digits": max_digits,
+"decimal_places": decimal_places,
+"min_length": min_length,
+"max_length": max_length,
+}
"""
Invocation definitions have their fields typed correctly for their `invoke()` functions.
@@ -299,24 +299,24 @@ def InputField(
# because we are manually making fields optional, we need to store the original required bool for reference later
if default is PydanticUndefined and default_factory is PydanticUndefined:
-json_schema_extra_.update(dict(orig_required=True))
+json_schema_extra_.update({"orig_required": True})
else:
-json_schema_extra_.update(dict(orig_required=False))
+json_schema_extra_.update({"orig_required": False})
# make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
default_ = None if default is PydanticUndefined else default
-provided_args.update(dict(default=default_))
+provided_args.update({"default": default_})
if default is not PydanticUndefined:
# before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
-json_schema_extra_.update(dict(default=default))
-json_schema_extra_.update(dict(orig_default=default))
+json_schema_extra_.update({"default": default})
+json_schema_extra_.update({"orig_default": default})
elif default is not PydanticUndefined and default_factory is PydanticUndefined:
default_ = default
-provided_args.update(dict(default=default_))
-json_schema_extra_.update(dict(orig_default=default_))
+provided_args.update({"default": default_})
+json_schema_extra_.update({"orig_default": default_})
elif default_factory is not PydanticUndefined:
-provided_args.update(dict(default_factory=default_factory))
+provided_args.update({"default_factory": default_factory})
# TODO: cannot serialize default_factory...
# json_schema_extra_.update(dict(orig_default_factory=default_factory))
@@ -383,12 +383,12 @@ def OutputField(
decimal_places=decimal_places,
min_length=min_length,
max_length=max_length,
-json_schema_extra=dict(
-ui_type=ui_type,
-ui_hidden=ui_hidden,
-ui_order=ui_order,
-_field_kind="output",
-),
+json_schema_extra={
+"ui_type": ui_type,
+"ui_hidden": ui_hidden,
+"ui_order": ui_order,
+"_field_kind": "output",
+},
)
@@ -460,14 +460,14 @@ class BaseInvocationOutput(BaseModel):
@classmethod
def get_output_types(cls) -> Iterable[str]:
-return map(lambda i: get_type(i), BaseInvocationOutput.get_outputs())
+return (get_type(i) for i in BaseInvocationOutput.get_outputs())
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
# Because we use a pydantic Literal field with default value for the invocation type,
# it will be typed as optional in the OpenAPI schema. Make it required manually.
if "required" not in schema or not isinstance(schema["required"], list):
-schema["required"] = list()
+schema["required"] = []
schema["required"].extend(["type"])
model_config = ConfigDict(
@@ -527,16 +527,11 @@ class BaseInvocation(ABC, BaseModel):
@classmethod
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
# Get the type strings out of the literals and into a dictionary
-return dict(
-map(
-lambda i: (get_type(i), i),
-BaseInvocation.get_invocations(),
-)
-)
+return {get_type(i): i for i in BaseInvocation.get_invocations()}
@classmethod
def get_invocation_types(cls) -> Iterable[str]:
-return map(lambda i: get_type(i), BaseInvocation.get_invocations())
+return (get_type(i) for i in BaseInvocation.get_invocations())
@classmethod
def get_output_type(cls) -> BaseInvocationOutput:
@@ -555,7 +550,7 @@ class BaseInvocation(ABC, BaseModel):
if uiconfig and hasattr(uiconfig, "version"):
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
-schema["required"] = list()
+schema["required"] = []
schema["required"].extend(["type", "id"])
@abstractmethod
@@ -609,15 +604,15 @@ class BaseInvocation(ABC, BaseModel):
id: str = Field(
default_factory=uuid_string,
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
-json_schema_extra=dict(_field_kind="internal"),
+json_schema_extra={"_field_kind": "internal"},
)
is_intermediate: bool = Field(
default=False,
description="Whether or not this is an intermediate invocation.",
-json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"),
+json_schema_extra={"ui_type": UIType.IsIntermediate, "_field_kind": "internal"},
)
use_cache: bool = Field(
-default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal")
+default=True, description="Whether or not to use the cache", json_schema_extra={"_field_kind": "internal"}
)
UIConfig: ClassVar[Type[UIConfigBase]]
@@ -651,7 +646,7 @@ class _Model(BaseModel):
# Get all pydantic model attrs, methods, etc
-RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model())))
+RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
@@ -729,7 +724,7 @@ def invocation(
# Add OpenAPI schema extras
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
-cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
+cls.UIConfig = type(uiconf_name, (UIConfigBase,), {})
if title is not None:
cls.UIConfig.title = title
if tags is not None:
@@ -756,7 +751,7 @@ def invocation(
invocation_type_annotation = Literal[invocation_type] # type: ignore
invocation_type_field = Field(
-title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal")
+title="type", default=invocation_type, json_schema_extra={"_field_kind": "internal"}
)
docstring = cls.__doc__
@@ -802,7 +797,7 @@ def invocation_output(
# Add the output type to the model.
output_type_annotation = Literal[output_type] # type: ignore
-output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal"))
+output_type_field = Field(title="type", default=output_type, json_schema_extra={"_field_kind": "internal"})
docstring = cls.__doc__
cls = create_model(
@@ -834,7 +829,7 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField)
class WithWorkflow(BaseModel):
workflow: Optional[WorkflowField] = Field(
-default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal")
+default=None, description=FieldDescriptions.workflow, json_schema_extra={"_field_kind": "internal"}
)
@@ -852,5 +847,5 @@ MetadataFieldValidator = TypeAdapter(MetadataField)
class WithMetadata(BaseModel):
metadata: Optional[MetadataField] = Field(
-default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal")
+default=None, description=FieldDescriptions.metadata, json_schema_extra={"_field_kind": "internal"}
)

View File

@@ -131,7 +131,7 @@ def prepare_faces_list(
deduped_faces: list[FaceResultData] = []
if len(face_result_list) == 0:
-return list()
+return []
for candidate in face_result_list:
should_add = True

View File

@@ -77,7 +77,7 @@ if choose_torch_device() == torch.device("mps"):
DEFAULT_PRECISION = choose_precision(choose_torch_device())
-SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
+SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
@invocation_output("scheduler_output")
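
The SAMPLER_NAME_VALUES change above (and the matching PRECISION_VALUES and EASING_FUNCTION_KEYS changes further down) drops a redundant list() call inside tuple(), since tuple() accepts any iterable directly. A minimal sketch of the equivalence, using a throwaway mapping rather than the real SCHEDULER_MAP:

# Hypothetical stand-in for SCHEDULER_MAP; only the keys matter here.
scheduler_map = {"ddim": None, "euler": None}
assert tuple(list(scheduler_map.keys())) == tuple(scheduler_map.keys())  # the inner list() adds nothing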

View File

@@ -145,17 +145,17 @@ INTEGER_OPERATIONS = Literal[
]
-INTEGER_OPERATIONS_LABELS = dict(
-ADD="Add A+B",
-SUB="Subtract A-B",
-MUL="Multiply A*B",
-DIV="Divide A/B",
-EXP="Exponentiate A^B",
-MOD="Modulus A%B",
-ABS="Absolute Value of A",
-MIN="Minimum(A,B)",
-MAX="Maximum(A,B)",
-)
+INTEGER_OPERATIONS_LABELS = {
+"ADD": "Add A+B",
+"SUB": "Subtract A-B",
+"MUL": "Multiply A*B",
+"DIV": "Divide A/B",
+"EXP": "Exponentiate A^B",
+"MOD": "Modulus A%B",
+"ABS": "Absolute Value of A",
+"MIN": "Minimum(A,B)",
+"MAX": "Maximum(A,B)",
+}
@invocation(
@@ -231,17 +231,17 @@ FLOAT_OPERATIONS = Literal[
]
-FLOAT_OPERATIONS_LABELS = dict(
-ADD="Add A+B",
-SUB="Subtract A-B",
-MUL="Multiply A*B",
-DIV="Divide A/B",
-EXP="Exponentiate A^B",
-ABS="Absolute Value of A",
-SQRT="Square Root of A",
-MIN="Minimum(A,B)",
-MAX="Maximum(A,B)",
-)
+FLOAT_OPERATIONS_LABELS = {
+"ADD": "Add A+B",
+"SUB": "Subtract A-B",
+"MUL": "Multiply A*B",
+"DIV": "Divide A/B",
+"EXP": "Exponentiate A^B",
+"ABS": "Absolute Value of A",
+"SQRT": "Square Root of A",
+"MIN": "Minimum(A,B)",
+"MAX": "Maximum(A,B)",
+}
@invocation(

View File

@@ -54,7 +54,7 @@ ORT_TO_NP_TYPE = {
"tensor(double)": np.float64,
}
-PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
+PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())]
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
@@ -252,7 +252,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
scheduler.set_timesteps(self.steps)
latents = latents * np.float64(scheduler.init_noise_sigma)
-extra_step_kwargs = dict()
+extra_step_kwargs = {}
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
eta=0.0,

View File

@@ -100,7 +100,7 @@ EASING_FUNCTIONS_MAP = {
"BounceInOut": BounceEaseInOut,
}
-EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
+EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
# actually I think for now could just use CollectionOutput (which is list[Any]
@@ -161,7 +161,7 @@ class StepParamEasingInvocation(BaseInvocation):
easing_class = EASING_FUNCTIONS_MAP[self.easing]
if log_diagnostics:
context.services.logger.debug("easing class: " + str(easing_class))
-easing_list = list()
+easing_list = []
if self.mirror: # "expected" mirroring
# if number of steps is even, squeeze duration down to (number_of_steps)/2
# and create reverse copy of list to append
@@ -178,7 +178,7 @@ class StepParamEasingInvocation(BaseInvocation):
end=self.end_value,
duration=base_easing_duration - 1,
)
-base_easing_vals = list()
+base_easing_vals = []
for step_index in range(base_easing_duration):
easing_val = easing_function.ease(step_index)
base_easing_vals.append(easing_val)

View File

@@ -139,7 +139,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
-images = list(map(lambda r: deserialize_image_record(dict(r)), result))
+images = [deserialize_image_record(dict(r)) for r in result]
self._cursor.execute(
"""--sql
@@ -167,7 +167,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
-image_names = list(map(lambda r: r[0], result))
+image_names = [r[0] for r in result]
return image_names
except sqlite3.Error as e:
self._conn.rollback()

View File

@@ -199,7 +199,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
-boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
+boards = [deserialize_board_record(dict(r)) for r in result]
# Get the total number of boards
self._cursor.execute(
@@ -236,7 +236,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
-boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
+boards = [deserialize_board_record(dict(r)) for r in result]
return boards

View File

@@ -55,7 +55,7 @@ class InvokeAISettings(BaseSettings):
"""
cls = self.__class__
type = get_args(get_type_hints(cls)["type"])[0]
-field_dict = dict({type: dict()})
+field_dict = {type: {}}
for name, field in self.model_fields.items():
if name in cls._excluded_from_yaml():
continue
@@ -64,7 +64,7 @@ class InvokeAISettings(BaseSettings):
)
value = getattr(self, name)
if category not in field_dict[type]:
-field_dict[type][category] = dict()
+field_dict[type][category] = {}
# keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
conf = OmegaConf.create(field_dict)
@@ -89,7 +89,7 @@ class InvokeAISettings(BaseSettings):
# create an upcase version of the environment in
# order to achieve case-insensitive environment
# variables (the way Windows does)
-upcase_environ = dict()
+upcase_environ = {}
for key, value in os.environ.items():
upcase_environ[key.upper()] = value

View File

@@ -188,18 +188,18 @@ DEFAULT_MAX_VRAM = 0.5
class Categories(object):
-WebServer = dict(category="Web Server")
-Features = dict(category="Features")
-Paths = dict(category="Paths")
-Logging = dict(category="Logging")
-Development = dict(category="Development")
-Other = dict(category="Other")
-ModelCache = dict(category="Model Cache")
-Device = dict(category="Device")
-Generation = dict(category="Generation")
-Queue = dict(category="Queue")
-Nodes = dict(category="Nodes")
-MemoryPerformance = dict(category="Memory/Performance")
+WebServer = {"category": "Web Server"}
+Features = {"category": "Features"}
+Paths = {"category": "Paths"}
+Logging = {"category": "Logging"}
+Development = {"category": "Development"}
+Other = {"category": "Other"}
+ModelCache = {"category": "Model Cache"}
+Device = {"category": "Device"}
+Generation = {"category": "Generation"}
+Queue = {"category": "Queue"}
+Nodes = {"category": "Nodes"}
+MemoryPerformance = {"category": "Memory/Performance"}
class InvokeAIAppConfig(InvokeAISettings):
@@ -482,7 +482,7 @@ def _find_root() -> Path:
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ["INVOKEAI_ROOT"])
-elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
+elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]):
root = (venv.parent).resolve()
else:
root = Path("~/invokeai").expanduser().resolve()

View File

@@ -27,7 +27,7 @@ class EventServiceBase:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.queue_event,
-payload=dict(event=event_name, data=payload),
+payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
@@ -48,18 +48,18 @@ class EventServiceBase:
"""Emitted when there is generation progress"""
self.__emit_queue_event(
event_name="generator_progress",
-payload=dict(
-queue_id=queue_id,
-queue_item_id=queue_item_id,
-queue_batch_id=queue_batch_id,
-graph_execution_state_id=graph_execution_state_id,
-node_id=node.get("id"),
-source_node_id=source_node_id,
-progress_image=progress_image.model_dump() if progress_image is not None else None,
-step=step,
-order=order,
-total_steps=total_steps,
-),
+payload={
+"queue_id": queue_id,
+"queue_item_id": queue_item_id,
+"queue_batch_id": queue_batch_id,
+"graph_execution_state_id": graph_execution_state_id,
+"node_id": node.get("id"),
+"source_node_id": source_node_id,
+"progress_image": progress_image.model_dump() if progress_image is not None else None,
+"step": step,
+"order": order,
+"total_steps": total_steps,
+},
)
def emit_invocation_complete(
@@ -75,15 +75,15 @@ class EventServiceBase:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_complete",
-payload=dict(
-queue_id=queue_id,
-queue_item_id=queue_item_id,
-queue_batch_id=queue_batch_id,
-graph_execution_state_id=graph_execution_state_id,
-node=node,
-source_node_id=source_node_id,
-result=result,
-),
+payload={
+"queue_id": queue_id,
+"queue_item_id": queue_item_id,
+"queue_batch_id": queue_batch_id,
+"graph_execution_state_id": graph_execution_state_id,
+"node": node,
+"source_node_id": source_node_id,
+"result": result,
+},
)
def emit_invocation_error(
@@ -100,16 +100,16 @@ class EventServiceBase:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_error",
-payload=dict(
-queue_id=queue_id,
-queue_item_id=queue_item_id,
-queue_batch_id=queue_batch_id,
-graph_execution_state_id=graph_execution_state_id,
-node=node,
-source_node_id=source_node_id,
-error_type=error_type,
-error=error,
-),
+payload={
+"queue_id": queue_id,
+"queue_item_id": queue_item_id,
+"queue_batch_id": queue_batch_id,
+"graph_execution_state_id": graph_execution_state_id,
+"node": node,
+"source_node_id": source_node_id,
+"error_type": error_type,
+"error": error,
+},
)
def emit_invocation_started(
@@ -124,14 +124,14 @@ class EventServiceBase:
"""Emitted when an invocation has started"""
self.__emit_queue_event(
event_name="invocation_started",
-payload=dict(
-queue_id=queue_id,
-queue_item_id=queue_item_id,
-queue_batch_id=queue_batch_id,
-graph_execution_state_id=graph_execution_state_id,
-node=node,
-source_node_id=source_node_id,
-),
+payload={
+"queue_id": queue_id,
+"queue_item_id": queue_item_id,
+"queue_batch_id": queue_batch_id,
+"graph_execution_state_id": graph_execution_state_id,
+"node": node,
+"source_node_id": source_node_id,
+},
)
def emit_graph_execution_complete(
@@ -140,12 +140,12 @@ class EventServiceBase:
"""Emitted when a session has completed all invocations"""
self.__emit_queue_event(
event_name="graph_execution_state_complete",
-payload=dict(
-queue_id=queue_id,
-queue_item_id=queue_item_id,
-queue_batch_id=queue_batch_id,
-graph_execution_state_id=graph_execution_state_id,
-),
+payload={
+"queue_id": queue_id,
+"queue_item_id": queue_item_id,
+"queue_batch_id": queue_batch_id,
+"graph_execution_state_id": graph_execution_state_id,
+},
)
def emit_model_load_started(
@@ -162,16 +162,16 @@ class EventServiceBase:
"""Emitted when a model is requested"""
self.__emit_queue_event(
event_name="model_load_started",
-payload=dict(
-queue_id=queue_id,
-queue_item_id=queue_item_id,
-queue_batch_id=queue_batch_id,
-graph_execution_state_id=graph_execution_state_id,
-model_name=model_name,
-base_model=base_model,
-model_type=model_type,
-submodel=submodel,
-),
+payload={
+"queue_id": queue_id,
+"queue_item_id": queue_item_id,
+"queue_batch_id": queue_batch_id,
+"graph_execution_state_id": graph_execution_state_id,
+"model_name": model_name,
+"base_model": base_model,
+"model_type": model_type,
+"submodel": submodel,
+},
)
def emit_model_load_completed(
@@ -189,19 +189,19 @@ class EventServiceBase:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
event_name="model_load_completed",
-payload=dict(
-queue_id=queue_id,
-queue_item_id=queue_item_id,
-queue_batch_id=queue_batch_id,
-graph_execution_state_id=graph_execution_state_id,
-model_name=model_name,
-base_model=base_model,
-model_type=model_type,
-submodel=submodel,
-hash=model_info.hash,
-location=str(model_info.location),
-precision=str(model_info.precision),
-),
+payload={
+"queue_id": queue_id,
+"queue_item_id": queue_item_id,
+"queue_batch_id": queue_batch_id,
+"graph_execution_state_id": graph_execution_state_id,
+"model_name": model_name,
+"base_model": base_model,
+"model_type": model_type,
+"submodel": submodel,
+"hash": model_info.hash,
+"location": str(model_info.location),
+"precision": str(model_info.precision),
+},
)
def emit_session_retrieval_error(
@@ -216,14 +216,14 @@ class EventServiceBase:
"""Emitted when session retrieval fails"""
self.__emit_queue_event(
event_name="session_retrieval_error",
-payload=dict(
-queue_id=queue_id,
-queue_item_id=queue_item_id,
-queue_batch_id=queue_batch_id,
-graph_execution_state_id=graph_execution_state_id,
-error_type=error_type,
-error=error,
-),
+payload={
+"queue_id": queue_id,
+"queue_item_id": queue_item_id,
+"queue_batch_id": queue_batch_id,
+"graph_execution_state_id": graph_execution_state_id,
+"error_type": error_type,
+"error": error,
+},
)
def emit_invocation_retrieval_error(
@@ -239,15 +239,15 @@ class EventServiceBase:
"""Emitted when invocation retrieval fails"""
self.__emit_queue_event(
event_name="invocation_retrieval_error",
-payload=dict(
-queue_id=queue_id,
-queue_item_id=queue_item_id,
-queue_batch_id=queue_batch_id,
-graph_execution_state_id=graph_execution_state_id,
-node_id=node_id,
-error_type=error_type,
-error=error,
-),
+payload={
+"queue_id": queue_id,
+"queue_item_id": queue_item_id,
+"queue_batch_id": queue_batch_id,
+"graph_execution_state_id": graph_execution_state_id,
+"node_id": node_id,
+"error_type": error_type,
+"error": error,
+},
)
def emit_session_canceled(
@@ -260,12 +260,12 @@ class EventServiceBase:
"""Emitted when a session is canceled"""
self.__emit_queue_event(
event_name="session_canceled",
-payload=dict(
-queue_id=queue_id,
-queue_item_id=queue_item_id,
-queue_batch_id=queue_batch_id,
-graph_execution_state_id=graph_execution_state_id,
-),
+payload={
+"queue_id": queue_id,
+"queue_item_id": queue_item_id,
+"queue_batch_id": queue_batch_id,
+"graph_execution_state_id": graph_execution_state_id,
+},
)
def emit_queue_item_status_changed(
@@ -277,39 +277,39 @@ class EventServiceBase:
"""Emitted when a queue item's status changes"""
self.__emit_queue_event(
event_name="queue_item_status_changed",
-payload=dict(
-queue_id=queue_status.queue_id,
-queue_item=dict(
-queue_id=session_queue_item.queue_id,
-item_id=session_queue_item.item_id,
-status=session_queue_item.status,
-batch_id=session_queue_item.batch_id,
-session_id=session_queue_item.session_id,
-error=session_queue_item.error,
-created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
-updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
-started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
-completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
-),
-batch_status=batch_status.model_dump(),
-queue_status=queue_status.model_dump(),
-),
+payload={
+"queue_id": queue_status.queue_id,
+"queue_item": {
+"queue_id": session_queue_item.queue_id,
+"item_id": session_queue_item.item_id,
+"status": session_queue_item.status,
+"batch_id": session_queue_item.batch_id,
+"session_id": session_queue_item.session_id,
+"error": session_queue_item.error,
+"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
+"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
+"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
+"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
+},
+"batch_status": batch_status.model_dump(),
+"queue_status": queue_status.model_dump(),
+},
)
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
"""Emitted when a batch is enqueued"""
self.__emit_queue_event(
event_name="batch_enqueued",
-payload=dict(
-queue_id=enqueue_result.queue_id,
-batch_id=enqueue_result.batch.batch_id,
-enqueued=enqueue_result.enqueued,
-),
+payload={
+"queue_id": enqueue_result.queue_id,
+"batch_id": enqueue_result.batch.batch_id,
+"enqueued": enqueue_result.enqueued,
+},
)
def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when the queue is cleared"""
self.__emit_queue_event(
event_name="queue_cleared",
-payload=dict(queue_id=queue_id),
+payload={"queue_id": queue_id},
)

View File

@@ -25,7 +25,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
__invoker: Invoker
def __init__(self, output_folder: Union[str, Path]):
-self.__cache = dict()
+self.__cache = {}
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config

View File

@@ -90,10 +90,7 @@ class ImageRecordDeleteException(Exception):
IMAGE_DTO_COLS = ", ".join(
-list(
-map(
-lambda c: "images." + c,
-[
+["images." + c for c in [
"image_name",
"image_origin",
"image_category",
@@ -106,9 +103,7 @@ IMAGE_DTO_COLS = ", ".join(
"updated_at",
"deleted_at",
"starred",
-],
-)
-)
+]]
)
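
The IMAGE_DTO_COLS rewrite above swaps a list(map(lambda ...)) pipeline for a comprehension over the literal column list; both build the same comma-separated SQL column string. A short check with a truncated column list (the real constant has more entries):

# Truncated, illustrative column list; not the full set used by the service.
cols = ["image_name", "image_origin", "image_category"]
old_style = ", ".join(list(map(lambda c: "images." + c, cols)))
new_style = ", ".join(["images." + c for c in cols])
assert old_style == new_style == "images.image_name, images.image_origin, images.image_category"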

View File

@@ -263,7 +263,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
if categories is not None:
# Convert the enum values to unique list of strings
-category_strings = list(map(lambda c: c.value, set(categories)))
+category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
@@ -307,7 +307,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
-images = list(map(lambda r: deserialize_image_record(dict(r)), result))
+images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
@@ -386,7 +386,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
-image_names = list(map(lambda r: r[0], result))
+image_names = [r[0] for r in result]
self._cursor.execute(
"""--sql
DELETE FROM images

View File

@@ -21,8 +21,8 @@ class ImageServiceABC(ABC):
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
-self._on_changed_callbacks = list()
-self._on_deleted_callbacks = list()
+self._on_changed_callbacks = []
+self._on_deleted_callbacks = []
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
"""Register a callback for when an image is changed"""

View File

@@ -217,18 +217,13 @@ class ImageService(ImageServiceABC):
board_id,
)
-image_dtos = list(
-map(
-lambda r: image_record_to_dto(
+image_dtos = [image_record_to_dto(
image_record=r,
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
-),
-results.items,
-)
-)
+) for r in results.items]
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,

View File

@@ -26,7 +26,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
self.__invoker_thread = Thread(
name="invoker_processor",
target=self.__process,
-kwargs=dict(stop_event=self.__stop_event),
+kwargs={"stop_event": self.__stop_event},
)
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
self.__invoker_thread.start()

View File

@@ -14,7 +14,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
def __init__(self):
self.__queue = Queue()
-self.__cancellations = dict()
+self.__cancellations = {}
def get(self) -> InvocationQueueItem:
item = self.__queue.get()

View File

@@ -142,7 +142,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
cache_stats = self._cache_stats[graph_id]
hwm = cache_stats.high_watermark / GIG
tot = cache_stats.cache_size / GIG
-loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
+loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GIG
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")

View File

@@ -15,8 +15,8 @@ class ItemStorageABC(ABC, Generic[T]):
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
-self._on_changed_callbacks = list()
-self._on_deleted_callbacks = list()
+self._on_changed_callbacks = []
+self._on_deleted_callbacks = []
"""Base item storage class"""

View File

@@ -112,7 +112,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
)
result = self._cursor.fetchall()
-items = list(map(lambda r: self._parse_item(r[0]), result))
+items = [self._parse_item(r[0]) for r in result]
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0]
@@ -132,7 +132,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
)
result = self._cursor.fetchall()
-items = list(map(lambda r: self._parse_item(r[0]), result))
+items = [self._parse_item(r[0]) for r in result]
self._cursor.execute(
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",

View File

@@ -13,8 +13,8 @@ class LatentsStorageBase(ABC):
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
-self._on_changed_callbacks = list()
-self._on_deleted_callbacks = list()
+self._on_changed_callbacks = []
+self._on_deleted_callbacks = []
@abstractmethod
def get(self, name: str) -> torch.Tensor:

View File

@@ -19,7 +19,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
super().__init__()
self.__underlying_storage = underlying_storage
-self.__cache = dict()
+self.__cache = {}
self.__cache_ids = Queue()
self.__max_cache_size = max_cache_size

View File

@@ -33,9 +33,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
self.__thread = Thread(
name="session_processor",
target=self.__process,
-kwargs=dict(
-stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event
-),
+kwargs={
+"stop_event": self.__stop_event, "poll_now_event": self.__poll_now_event, "resume_event": self.__resume_event
+},
)
self.__thread.start()

View File

@@ -129,12 +129,12 @@ class Batch(BaseModel):
return v
model_config = ConfigDict(
-json_schema_extra=dict(
-required=[
+json_schema_extra={
+"required": [
"graph",
"runs",
]
-)
+}
)
@@ -191,8 +191,8 @@ class SessionQueueItemWithoutGraph(BaseModel):
return SessionQueueItemDTO(**queue_item_dict)
model_config = ConfigDict(
-json_schema_extra=dict(
-required=[
+json_schema_extra={
+"required": [
"item_id",
"status",
"batch_id",
@@ -203,7 +203,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
"created_at",
"updated_at",
]
-)
+}
)
@@ -222,8 +222,8 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
return SessionQueueItem(**queue_item_dict)
model_config = ConfigDict(
-json_schema_extra=dict(
-required=[
+json_schema_extra={
+"required": [
"item_id",
"status",
"batch_id",
@@ -235,7 +235,7 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
"created_at",
"updated_at",
]
-)
+}
)

View File

@@ -78,7 +78,7 @@ def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[Li
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
-graphs: list[LibraryGraph] = list()
+graphs: list[LibraryGraph] = []
text_to_image = graph_library.get(default_text_to_image_graph_id)

View File

@@ -352,7 +352,7 @@ class Graph(BaseModel):
# Validate that all node ids are unique
node_ids = [n.id for n in self.nodes.values()]
-duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
+duplicate_node_ids = {node_id for node_id in node_ids if node_ids.count(node_id) >= 2}
if duplicate_node_ids:
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
@@ -616,7 +616,7 @@ class Graph(BaseModel):
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
-edges = list()
+edges = []
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
@@ -658,7 +658,7 @@ class Graph(BaseModel):
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
-edges = list()
+edges = []
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
@@ -680,8 +680,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
-inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
-outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
+inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
+outputs = [e.destination for e in self._get_output_edges(node_path, "item")]
if new_input is not None:
inputs.append(new_input)
@@ -694,7 +694,7 @@ class Graph(BaseModel):
# Get input and output fields (the fields linked to the iterator's input/output)
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
-output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
+output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Input type must be a list
if get_origin(input_field) != list:
@@ -713,8 +713,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
-inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
-outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
+inputs = [e.source for e in self._get_input_edges(node_path, "item")]
+outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]
if new_input is not None:
inputs.append(new_input)
@@ -722,18 +722,16 @@ class Graph(BaseModel):
outputs.append(new_output)
# Get input and output fields (the fields linked to the iterator's input/output)
-input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
-output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
+input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
+output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Validate that all inputs are derived from or match a single type
-input_field_types = set(
-[
-t
+input_field_types = {
+t
for input_field in input_fields
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
if t != NoneType
-]
-) # Get unique types
+} # Get unique types
type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
@@ -761,15 +759,15 @@ class Graph(BaseModel):
"""Returns a NetworkX DiGraph representing the layout of this graph"""
# TODO: Cache this?
g = nx.DiGraph()
-g.add_nodes_from([n for n in self.nodes.keys()])
-g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
+g.add_nodes_from(list(self.nodes.keys()))
+g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
return g
def nx_graph_with_data(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
g = nx.DiGraph()
-g.add_nodes_from([n for n in self.nodes.items()])
-g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
+g.add_nodes_from(list(self.nodes.items()))
+g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
return g
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
@@ -791,7 +789,7 @@ class Graph(BaseModel):
# TODO: figure out if iteration nodes need to be expanded
-unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
+unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
return g
@@ -843,8 +841,8 @@ class GraphExecutionState(BaseModel):
return v
model_config = ConfigDict(
-json_schema_extra=dict(
-required=[
+json_schema_extra={
+"required": [
"id",
"graph",
"execution_graph",
@@ -855,7 +853,7 @@ class GraphExecutionState(BaseModel):
"prepared_source_mapping",
"source_prepared_mapping",
]
-)
+}
)
def next(self) -> Optional[BaseInvocation]:
@@ -895,7 +893,7 @@ class GraphExecutionState(BaseModel):
source_node = self.prepared_source_mapping[node_id]
prepared_nodes = self.source_prepared_mapping[source_node]
-if all([n in self.executed for n in prepared_nodes]):
+if all(n in self.executed for n in prepared_nodes):
self.executed.add(source_node)
self.executed_history.append(source_node)
@@ -930,7 +928,7 @@ class GraphExecutionState(BaseModel):
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
self_iteration_count = len(input_collection)
-new_nodes: list[str] = list()
+new_nodes: list[str] = []
if self_iteration_count == 0:
# TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
return new_nodes
@@ -940,7 +938,7 @@ class GraphExecutionState(BaseModel):
# Create new edges for this iteration
# For collect nodes, this may contain multiple inputs to the same field
-new_edges: list[Edge] = list()
+new_edges: list[Edge] = []
for edge in input_edges:
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
new_edge = Edge(
@@ -1034,7 +1032,7 @@ class GraphExecutionState(BaseModel):
# Create execution nodes
next_node = self.graph.get_node(next_node_id)
-new_node_ids = list()
+new_node_ids = []
if isinstance(next_node, CollectInvocation):
# Collapse all iterator input mappings and create a single execution node for the collect invocation
all_iteration_mappings = list(
@@ -1201,7 +1199,7 @@ class LibraryGraph(BaseModel):
@field_validator("exposed_inputs", "exposed_outputs")
def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
-if len(v) != len(set(i.alias for i in v)):
+if len(v) != len({i.alias for i in v}):
raise ValueError("Duplicate exposed alias")
return v

View File

@@ -88,7 +88,7 @@ class PromptFormatter:
t2i = self.t2i
opt = self.opt
-switches = list()
+switches = []
switches.append(f'"{opt.prompt}"')
switches.append(f"-s{opt.steps or t2i.steps}")
switches.append(f"-W{opt.width or t2i.width}")

View File

@ -40,7 +40,7 @@ class InitImageResizer:
(rw, rh) = (int(scale * im.width), int(scale * im.height)) (rw, rh) = (int(scale * im.width), int(scale * im.height))
# round everything to multiples of 64 # round everything to multiples of 64
width, height, rw, rh = map(lambda x: x - x % 64, (width, height, rw, rh)) width, height, rw, rh = (x - x % 64 for x in (width, height, rw, rh))
# no resize necessary, but return a copy # no resize necessary, but return a copy
if im.width == width and im.height == height: if im.width == width and im.height == height:
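The resizer hunk above is the unnecessary-map fix (C417): map() with a lambda is replaced by a generator expression, and the rounding behaviour is unchanged. A standalone sketch of the rounding-to-multiples-of-64 step, with made-up dimensions:

    width, height, rw, rh = 513, 769, 1027, 640

    # before: map() plus a lambda, then tuple unpacking
    w1, h1, rw1, rh1 = map(lambda x: x - x % 64, (width, height, rw, rh))

    # after: a generator expression avoids the lambda and reads left to right
    w2, h2, rw2, rh2 = (x - x % 64 for x in (width, height, rw, rh))

    assert (w1, h1, rw1, rh1) == (w2, h2, rw2, rh2) == (512, 768, 1024, 640)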

View File

@ -197,7 +197,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
def download_conversion_models(): def download_conversion_models():
target_dir = config.models_path / "core/convert" target_dir = config.models_path / "core/convert"
kwargs = dict() # for future use kwargs = {} # for future use
try: try:
logger.info("Downloading core tokenizers and text encoders") logger.info("Downloading core tokenizers and text encoders")
@ -252,26 +252,26 @@ def download_conversion_models():
def download_realesrgan(): def download_realesrgan():
logger.info("Installing ESRGAN Upscaling models...") logger.info("Installing ESRGAN Upscaling models...")
URLs = [ URLs = [
dict( {
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth", "dest": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
description="RealESRGAN_x4plus.pth", "description": "RealESRGAN_x4plus.pth",
), },
dict( {
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", "dest": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
description="RealESRGAN_x4plus_anime_6B.pth", "description": "RealESRGAN_x4plus_anime_6B.pth",
), },
dict( {
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", "dest": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
description="ESRGAN_SRx4_DF2KOST_official.pth", "description": "ESRGAN_SRx4_DF2KOST_official.pth",
), },
dict( {
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth", "dest": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
description="RealESRGAN_x2plus.pth", "description": "RealESRGAN_x2plus.pth",
), },
] ]
for model in URLs: for model in URLs:
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"]) download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
@ -680,7 +680,7 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
if program_opts.default_only if program_opts.default_only
else [models[x].path or models[x].repo_id for x in installer.recommended_models()] else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
if program_opts.yes_to_all if program_opts.yes_to_all
else list(), else [],
) )

View File

@ -182,10 +182,10 @@ class MigrateTo3(object):
""" """
dest_directory = self.dest_models dest_directory = self.dest_models
kwargs = dict( kwargs = {
cache_dir=self.root_directory / "models/hub", "cache_dir": self.root_directory / "models/hub",
# local_files_only = True # local_files_only = True
) }
try: try:
logger.info("Migrating core tokenizers and text encoders") logger.info("Migrating core tokenizers and text encoders")
target_dir = dest_directory / "core" / "convert" target_dir = dest_directory / "core" / "convert"
@ -316,11 +316,11 @@ class MigrateTo3(object):
dest_dir = self.dest_models dest_dir = self.dest_models
cache = self.root_directory / "models/hub" cache = self.root_directory / "models/hub"
kwargs = dict( kwargs = {
cache_dir=cache, "cache_dir": cache,
safety_checker=None, "safety_checker": None,
# local_files_only = True, # local_files_only = True,
) }
owner, repo_name = repo_id.split("/") owner, repo_name = repo_id.split("/")
model_name = model_name or repo_name model_name = model_name or repo_name

View File

@ -120,7 +120,7 @@ class ModelInstall(object):
be treated uniformly. It also sorts the models alphabetically be treated uniformly. It also sorts the models alphabetically
by their name, to improve the display somewhat. by their name, to improve the display somewhat.
""" """
model_dict = dict() model_dict = {}
# first populate with the entries in INITIAL_MODELS.yaml # first populate with the entries in INITIAL_MODELS.yaml
for key, value in self.datasets.items(): for key, value in self.datasets.items():
@ -134,7 +134,7 @@ class ModelInstall(object):
model_dict[key] = model_info model_dict[key] = model_info
# supplement with entries in models.yaml # supplement with entries in models.yaml
installed_models = [x for x in self.mgr.list_models()] installed_models = list(self.mgr.list_models())
for md in installed_models: for md in installed_models:
base = md["base_model"] base = md["base_model"]
@ -184,7 +184,7 @@ class ModelInstall(object):
def recommended_models(self) -> Set[str]: def recommended_models(self) -> Set[str]:
starters = self.starter_models(all_models=True) starters = self.starter_models(all_models=True)
return set([x for x in starters if self.datasets[x].get("recommended", False)]) return {x for x in starters if self.datasets[x].get("recommended", False)}
def default_model(self) -> str: def default_model(self) -> str:
starters = self.starter_models() starters = self.starter_models()
@ -234,7 +234,7 @@ class ModelInstall(object):
""" """
if not models_installed: if not models_installed:
models_installed = dict() models_installed = {}
model_path_id_or_url = str(model_path_id_or_url).strip("\"' ") model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
@ -252,8 +252,7 @@ class ModelInstall(object):
# folders style or similar # folders style or similar
elif path.is_dir() and any( elif path.is_dir() and any(
[ (path / x).exists()
(path / x).exists()
for x in { for x in {
"config.json", "config.json",
"model_index.json", "model_index.json",
@ -261,7 +260,6 @@ class ModelInstall(object):
"pytorch_lora_weights.bin", "pytorch_lora_weights.bin",
"pytorch_lora_weights.safetensors", "pytorch_lora_weights.safetensors",
} }
]
): ):
models_installed.update({str(model_path_id_or_url): self._install_path(path)}) models_installed.update({str(model_path_id_or_url): self._install_path(path)})
@ -433,17 +431,17 @@ class ModelInstall(object):
rel_path = self.relative_to_root(path, self.config.models_path) rel_path = self.relative_to_root(path, self.config.models_path)
attributes = dict( attributes = {
path=str(rel_path), "path": str(rel_path),
description=str(description), "description": str(description),
model_format=info.format, "model_format": info.format,
) }
legacy_conf = None legacy_conf = None
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX: if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
attributes.update( attributes.update(
dict( {
variant=info.variant_type, "variant": info.variant_type,
) }
) )
if info.format == "checkpoint": if info.format == "checkpoint":
try: try:
@ -474,7 +472,7 @@ class ModelInstall(object):
) )
if legacy_conf: if legacy_conf:
attributes.update(dict(config=str(legacy_conf))) attributes.update({"config": str(legacy_conf)})
return attributes return attributes
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path: def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
@ -519,7 +517,7 @@ class ModelInstall(object):
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path: def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
_, name = repo_id.split("/") _, name = repo_id.split("/")
location = staging / name location = staging / name
paths = list() paths = []
for filename in files: for filename in files:
filePath = Path(filename) filePath = Path(filename)
p = hf_download_with_resume( p = hf_download_with_resume(

View File

@ -104,7 +104,7 @@ class ModelPatcher:
loras: List[Tuple[LoRAModel, float]], loras: List[Tuple[LoRAModel, float]],
prefix: str, prefix: str,
): ):
original_weights = dict() original_weights = {}
try: try:
with torch.no_grad(): with torch.no_grad():
for lora, lora_weight in loras: for lora, lora_weight in loras:
@ -324,7 +324,7 @@ class TextualInversionManager(BaseTextualInversionManager):
tokenizer: CLIPTokenizer tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer): def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = dict() self.pad_tokens = {}
self.tokenizer = tokenizer self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
@ -385,10 +385,10 @@ class ONNXModelPatcher:
if not isinstance(model, IAIOnnxRuntimeModel): if not isinstance(model, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported") raise Exception("Only IAIOnnxRuntimeModel models supported")
orig_weights = dict() orig_weights = {}
try: try:
blended_loras = dict() blended_loras = {}
for lora, lora_weight in loras: for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items(): for layer_key, layer in lora.layers.items():
@ -404,7 +404,7 @@ class ONNXModelPatcher:
else: else:
blended_loras[layer_key] = layer_weight blended_loras[layer_key] = layer_weight
node_names = dict() node_names = {}
for node in model.nodes.values(): for node in model.nodes.values():
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name

View File

@ -132,7 +132,7 @@ class ModelCache(object):
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour. behaviour.
""" """
self.model_infos: Dict[str, ModelBase] = dict() self.model_infos: Dict[str, ModelBase] = {}
# allow lazy offloading only when vram cache enabled # allow lazy offloading only when vram cache enabled
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0 self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self.precision: torch.dtype = precision self.precision: torch.dtype = precision
@ -147,8 +147,8 @@ class ModelCache(object):
# used for stats collection # used for stats collection
self.stats = None self.stats = None
self._cached_models = dict() self._cached_models = {}
self._cache_stack = list() self._cache_stack = []
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]: def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage: if self._log_memory_usage:

View File

@ -363,7 +363,7 @@ class ModelManager(object):
else: else:
return return
self.models = dict() self.models = {}
for model_key, model_config in config.items(): for model_key, model_config in config.items():
if model_key.startswith("_"): if model_key.startswith("_"):
continue continue
@ -374,7 +374,7 @@ class ModelManager(object):
self.models[model_key] = model_class.create_config(**model_config) self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary # check config version number and update on disk/RAM if necessary
self.cache_keys = dict() self.cache_keys = {}
# add controlnet, lora and textual_inversion models from disk # add controlnet, lora and textual_inversion models from disk
self.scan_models_directory() self.scan_models_directory()
@ -902,7 +902,7 @@ class ModelManager(object):
""" """
Write current configuration out to the indicated file. Write current configuration out to the indicated file.
""" """
data_to_save = dict() data_to_save = {}
data_to_save["__metadata__"] = self.config_meta.model_dump() data_to_save["__metadata__"] = self.config_meta.model_dump()
for model_key, model_config in self.models.items(): for model_key, model_config in self.models.items():
@ -1034,7 +1034,7 @@ class ModelManager(object):
self.ignore = ignore self.ignore = ignore
def on_search_started(self): def on_search_started(self):
self.new_models_found = dict() self.new_models_found = {}
def on_model_found(self, model: Path): def on_model_found(self, model: Path):
if model not in self.ignore: if model not in self.ignore:
@ -1106,7 +1106,7 @@ class ModelManager(object):
# avoid circular import here # avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall from invokeai.backend.install.model_install_backend import ModelInstall
successfully_installed = dict() successfully_installed = {}
installer = ModelInstall( installer = ModelInstall(
config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self

View File

@ -92,7 +92,7 @@ class ModelMerger(object):
**kwargs - the default DiffusionPipeline.get_config_dict kwargs: **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
""" """
model_paths = list() model_paths = []
config = self.manager.app_config config = self.manager.app_config
base_model = BaseModelType(base_model) base_model = BaseModelType(base_model)
vae = None vae = None
@ -124,13 +124,13 @@ class ModelMerger(object):
dump_path = (dump_path / merged_model_name).as_posix() dump_path = (dump_path / merged_model_name).as_posix()
merged_pipe.save_pretrained(dump_path, safe_serialization=True) merged_pipe.save_pretrained(dump_path, safe_serialization=True)
attributes = dict( attributes = {
path=dump_path, "path": dump_path,
description=f"Merge of models {', '.join(model_names)}", "description": f"Merge of models {', '.join(model_names)}",
model_format="diffusers", "model_format": "diffusers",
variant=ModelVariantType.Normal.value, "variant": ModelVariantType.Normal.value,
vae=vae, "vae": vae,
) }
return self.manager.add_model( return self.manager.add_model(
merged_model_name, merged_model_name,
base_model=base_model, base_model=base_model,

View File

@ -59,7 +59,7 @@ class ModelSearch(ABC):
for root, dirs, files in os.walk(path, followlinks=True): for root, dirs, files in os.walk(path, followlinks=True):
if str(Path(root).name).startswith("."): if str(Path(root).name).startswith("."):
self._pruned_paths.add(root) self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
continue continue
self._items_scanned += len(dirs) + len(files) self._items_scanned += len(dirs) + len(files)
@ -69,8 +69,7 @@ class ModelSearch(ABC):
self._scanned_dirs.add(path) self._scanned_dirs.add(path)
continue continue
if any( if any(
[ (path / x).exists()
(path / x).exists()
for x in { for x in {
"config.json", "config.json",
"model_index.json", "model_index.json",
@ -78,7 +77,6 @@ class ModelSearch(ABC):
"pytorch_lora_weights.bin", "pytorch_lora_weights.bin",
"image_encoder.txt", "image_encoder.txt",
} }
]
): ):
try: try:
self.on_model_found(path) self.on_model_found(path)
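Dropping the square brackets inside any(), as the hunk above does, means any() receives a generator and can stop at the first existing file instead of materialising every result first. A small sketch under assumed paths (the folder and marker file names here are hypothetical):

    from pathlib import Path

    path = Path("/tmp/some-model-folder")  # hypothetical model directory
    markers = {"config.json", "model_index.json", "learned_embeds.bin"}

    # before: the list comprehension evaluates every candidate before any() runs
    found_before = any([(path / x).exists() for x in markers])

    # after: the bare generator lets any() short-circuit on the first hit
    found_after = any((path / x).exists() for x in markers)

    assert found_before == found_after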

View File

@ -97,8 +97,8 @@ MODEL_CLASSES = {
# }, # },
} }
MODEL_CONFIGS = list() MODEL_CONFIGS = []
OPENAPI_MODEL_CONFIGS = list() OPENAPI_MODEL_CONFIGS = []
class OpenAPIModelInfoBase(BaseModel): class OpenAPIModelInfoBase(BaseModel):
@ -133,7 +133,7 @@ for base_model, models in MODEL_CLASSES.items():
def get_model_config_enums(): def get_model_config_enums():
enums = list() enums = []
for model_config in MODEL_CONFIGS: for model_config in MODEL_CONFIGS:
if hasattr(inspect, "get_annotations"): if hasattr(inspect, "get_annotations"):

View File

@ -164,7 +164,7 @@ class ModelBase(metaclass=ABCMeta):
with suppress(Exception): with suppress(Exception):
return cls.__configs return cls.__configs
configs = dict() configs = {}
for name in dir(cls): for name in dir(cls):
if name.startswith("__"): if name.startswith("__"):
continue continue
@ -246,8 +246,8 @@ class DiffusersModel(ModelBase):
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
super().__init__(model_path, base_model, model_type) super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = dict() self.child_types: Dict[str, Type] = {}
self.child_sizes: Dict[str, int] = dict() self.child_sizes: Dict[str, int] = {}
try: try:
config_data = DiffusionPipeline.load_config(self.model_path) config_data = DiffusionPipeline.load_config(self.model_path)
@ -326,8 +326,8 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
all_files = os.listdir(model_path) all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))] all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f]) fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f]) bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
other_files = set(all_files) - fp16_files - bit8_files other_files = set(all_files) - fp16_files - bit8_files
if variant is None: if variant is None:
@ -413,7 +413,7 @@ def _calc_onnx_model_by_data(model) -> int:
def _fast_safetensors_reader(path: str): def _fast_safetensors_reader(path: str):
checkpoint = dict() checkpoint = {}
device = torch.device("meta") device = torch.device("meta")
with open(path, "rb") as f: with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), "little") definition_len = int.from_bytes(f.read(8), "little")
@ -483,7 +483,7 @@ class IAIOnnxRuntimeModel:
class _tensor_access: class _tensor_access:
def __init__(self, model): def __init__(self, model):
self.model = model self.model = model
self.indexes = dict() self.indexes = {}
for idx, obj in enumerate(self.model.proto.graph.initializer): for idx, obj in enumerate(self.model.proto.graph.initializer):
self.indexes[obj.name] = idx self.indexes[obj.name] = idx
@ -524,7 +524,7 @@ class IAIOnnxRuntimeModel:
class _access_helper: class _access_helper:
def __init__(self, raw_proto): def __init__(self, raw_proto):
self.indexes = dict() self.indexes = {}
self.raw_proto = raw_proto self.raw_proto = raw_proto
for idx, obj in enumerate(raw_proto): for idx, obj in enumerate(raw_proto):
self.indexes[obj.name] = idx self.indexes[obj.name] = idx
@ -549,7 +549,7 @@ class IAIOnnxRuntimeModel:
return self.indexes.keys() return self.indexes.keys()
def values(self): def values(self):
return [obj for obj in self.raw_proto] return list(self.raw_proto)
def __init__(self, model_path: str, provider: Optional[str]): def __init__(self, model_path: str, provider: Optional[str]):
self.path = model_path self.path = model_path

View File

@ -104,7 +104,7 @@ class ControlNetModel(ModelBase):
return ControlNetModelFormat.Diffusers return ControlNetModelFormat.Diffusers
if os.path.isfile(path): if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]): if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]):
return ControlNetModelFormat.Checkpoint return ControlNetModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {path}") raise InvalidModelException(f"Not a valid model: {path}")

View File

@ -73,7 +73,7 @@ class LoRAModel(ModelBase):
return LoRAModelFormat.Diffusers return LoRAModelFormat.Diffusers
if os.path.isfile(path): if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]): if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return LoRAModelFormat.LyCORIS return LoRAModelFormat.LyCORIS
raise InvalidModelException(f"Not a valid model: {path}") raise InvalidModelException(f"Not a valid model: {path}")
@ -499,7 +499,7 @@ class LoRAModelRaw: # (torch.nn.Module):
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP) stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort() stability_unet_keys.sort()
new_state_dict = dict() new_state_dict = {}
for full_key, value in state_dict.items(): for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"): if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "") search_key = full_key.replace("lora_unet_", "")
@ -545,7 +545,7 @@ class LoRAModelRaw: # (torch.nn.Module):
model = cls( model = cls(
name=file_path.stem, # TODO: name=file_path.stem, # TODO:
layers=dict(), layers={},
) )
if file_path.suffix == ".safetensors": if file_path.suffix == ".safetensors":
@ -593,12 +593,12 @@ class LoRAModelRaw: # (torch.nn.Module):
@staticmethod @staticmethod
def _group_state(state_dict: dict): def _group_state(state_dict: dict):
state_dict_groupped = dict() state_dict_groupped = {}
for key, value in state_dict.items(): for key, value in state_dict.items():
stem, leaf = key.split(".", 1) stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped: if stem not in state_dict_groupped:
state_dict_groupped[stem] = dict() state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value state_dict_groupped[stem][leaf] = value
return state_dict_groupped return state_dict_groupped

View File

@ -110,7 +110,7 @@ class StableDiffusion1Model(DiffusersModel):
return StableDiffusion1ModelFormat.Diffusers return StableDiffusion1ModelFormat.Diffusers
if os.path.isfile(model_path): if os.path.isfile(model_path):
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]): if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return StableDiffusion1ModelFormat.Checkpoint return StableDiffusion1ModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {model_path}") raise InvalidModelException(f"Not a valid model: {model_path}")
@ -221,7 +221,7 @@ class StableDiffusion2Model(DiffusersModel):
return StableDiffusion2ModelFormat.Diffusers return StableDiffusion2ModelFormat.Diffusers
if os.path.isfile(model_path): if os.path.isfile(model_path):
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]): if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return StableDiffusion2ModelFormat.Checkpoint return StableDiffusion2ModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {model_path}") raise InvalidModelException(f"Not a valid model: {model_path}")

View File

@ -71,7 +71,7 @@ class TextualInversionModel(ModelBase):
return None # diffusers-ti return None # diffusers-ti
if os.path.isfile(path): if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]): if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]):
return None return None
raise InvalidModelException(f"Not a valid model: {path}") raise InvalidModelException(f"Not a valid model: {path}")

View File

@ -89,7 +89,7 @@ class VaeModel(ModelBase):
return VaeModelFormat.Diffusers return VaeModelFormat.Diffusers
if os.path.isfile(path): if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]): if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return VaeModelFormat.Checkpoint return VaeModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {path}") raise InvalidModelException(f"Not a valid model: {path}")

View File

@ -16,28 +16,28 @@ from diffusers import (
UniPCMultistepScheduler, UniPCMultistepScheduler,
) )
SCHEDULER_MAP = dict( SCHEDULER_MAP = {
ddim=(DDIMScheduler, dict()), "ddim": (DDIMScheduler, {}),
ddpm=(DDPMScheduler, dict()), "ddpm": (DDPMScheduler, {}),
deis=(DEISMultistepScheduler, dict()), "deis": (DEISMultistepScheduler, {}),
lms=(LMSDiscreteScheduler, dict(use_karras_sigmas=False)), "lms": (LMSDiscreteScheduler, {"use_karras_sigmas": False}),
lms_k=(LMSDiscreteScheduler, dict(use_karras_sigmas=True)), "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
pndm=(PNDMScheduler, dict()), "pndm": (PNDMScheduler, {}),
heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)), "heun": (HeunDiscreteScheduler, {"use_karras_sigmas": False}),
heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)), "heun_k": (HeunDiscreteScheduler, {"use_karras_sigmas": True}),
euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)), "euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)), "euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
euler_a=(EulerAncestralDiscreteScheduler, dict()), "euler_a": (EulerAncestralDiscreteScheduler, {}),
kdpm_2=(KDPM2DiscreteScheduler, dict()), "kdpm_2": (KDPM2DiscreteScheduler, {}),
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()), "kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
dpmpp_2s=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=False)), "dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)), "dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)), "dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False}),
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)), "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type="sde-dpmsolver++")), "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")), "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)), "dpmpp_sde": (DPMSolverSDEScheduler, {"use_karras_sigmas": False, "noise_sampler_seed": 0}),
dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)), "dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}),
unipc=(UniPCMultistepScheduler, dict(cpu_only=True)), "unipc": (UniPCMultistepScheduler, {"cpu_only": True}),
lcm=(LCMScheduler, dict()), "lcm": (LCMScheduler, {}),
) }
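The SCHEDULER_MAP rewrite above is the unnecessary-dict-call fix (C408): dict(key=value, ...) becomes a {"key": value, ...} literal, which skips the dict() call and no longer restricts keys to valid identifiers. A minimal sketch of the same transformation, reusing two of the option names from the map:

    # before: keyword-argument form, keys must be identifiers
    opts_before = dict(use_karras_sigmas=True, noise_sampler_seed=0)

    # after: a literal with explicit string keys
    opts_after = {"use_karras_sigmas": True, "noise_sampler_seed": 0}

    assert opts_before == opts_after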

View File

@ -615,7 +615,7 @@ def do_textual_inversion_training(
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae) vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet) unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
pipeline_args = dict(local_files_only=True) pipeline_args = {"local_files_only": True}
if tokenizer_name: if tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args) tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
else: else:

View File

@ -225,34 +225,34 @@ def basicConfig(**kwargs):
_FACILITY_MAP = ( _FACILITY_MAP = (
dict( {
LOG_KERN=syslog.LOG_KERN, "LOG_KERN": syslog.LOG_KERN,
LOG_USER=syslog.LOG_USER, "LOG_USER": syslog.LOG_USER,
LOG_MAIL=syslog.LOG_MAIL, "LOG_MAIL": syslog.LOG_MAIL,
LOG_DAEMON=syslog.LOG_DAEMON, "LOG_DAEMON": syslog.LOG_DAEMON,
LOG_AUTH=syslog.LOG_AUTH, "LOG_AUTH": syslog.LOG_AUTH,
LOG_LPR=syslog.LOG_LPR, "LOG_LPR": syslog.LOG_LPR,
LOG_NEWS=syslog.LOG_NEWS, "LOG_NEWS": syslog.LOG_NEWS,
LOG_UUCP=syslog.LOG_UUCP, "LOG_UUCP": syslog.LOG_UUCP,
LOG_CRON=syslog.LOG_CRON, "LOG_CRON": syslog.LOG_CRON,
LOG_SYSLOG=syslog.LOG_SYSLOG, "LOG_SYSLOG": syslog.LOG_SYSLOG,
LOG_LOCAL0=syslog.LOG_LOCAL0, "LOG_LOCAL0": syslog.LOG_LOCAL0,
LOG_LOCAL1=syslog.LOG_LOCAL1, "LOG_LOCAL1": syslog.LOG_LOCAL1,
LOG_LOCAL2=syslog.LOG_LOCAL2, "LOG_LOCAL2": syslog.LOG_LOCAL2,
LOG_LOCAL3=syslog.LOG_LOCAL3, "LOG_LOCAL3": syslog.LOG_LOCAL3,
LOG_LOCAL4=syslog.LOG_LOCAL4, "LOG_LOCAL4": syslog.LOG_LOCAL4,
LOG_LOCAL5=syslog.LOG_LOCAL5, "LOG_LOCAL5": syslog.LOG_LOCAL5,
LOG_LOCAL6=syslog.LOG_LOCAL6, "LOG_LOCAL6": syslog.LOG_LOCAL6,
LOG_LOCAL7=syslog.LOG_LOCAL7, "LOG_LOCAL7": syslog.LOG_LOCAL7,
) }
if SYSLOG_AVAILABLE if SYSLOG_AVAILABLE
else dict() else {}
) )
_SOCK_MAP = dict( _SOCK_MAP = {
SOCK_STREAM=socket.SOCK_STREAM, "SOCK_STREAM": socket.SOCK_STREAM,
SOCK_DGRAM=socket.SOCK_DGRAM, "SOCK_DGRAM": socket.SOCK_DGRAM,
) }
class InvokeAIFormatter(logging.Formatter): class InvokeAIFormatter(logging.Formatter):
@ -344,7 +344,7 @@ LOG_FORMATTERS = {
class InvokeAILogger(object): class InvokeAILogger(object):
loggers = dict() loggers = {}
@classmethod @classmethod
def get_logger( def get_logger(
@ -364,7 +364,7 @@ class InvokeAILogger(object):
@classmethod @classmethod
def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]: def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
handler_strs = config.log_handlers handler_strs = config.log_handlers
handlers = list() handlers = []
for handler in handler_strs: for handler in handler_strs:
handler_name, *args = handler.split("=", 2) handler_name, *args = handler.split("=", 2)
args = args[0] if len(args) > 0 else None args = args[0] if len(args) > 0 else None
@ -398,7 +398,7 @@ class InvokeAILogger(object):
raise ValueError("syslog is not available on this system") raise ValueError("syslog is not available on this system")
if not args: if not args:
args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514" args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514"
syslog_args = dict() syslog_args = {}
try: try:
for a in args.split(","): for a in args.split(","):
arg_name, *arg_value = a.split(":", 2) arg_name, *arg_value = a.split(":", 2)
@ -434,7 +434,7 @@ class InvokeAILogger(object):
path = url.path path = url.path
port = url.port or 80 port = url.port or 80
syslog_args = dict() syslog_args = {}
for a in arg_list: for a in arg_list:
arg_name, *arg_value = a.split(":", 2) arg_name, *arg_value = a.split(":", 2)
if arg_name == "method": if arg_name == "method":

View File

@ -26,7 +26,7 @@ def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height) # wh a tuple of (width, height)
# xc a list of captions to plot # xc a list of captions to plot
b = len(xc) b = len(xc)
txts = list() txts = []
for bi in range(b): for bi in range(b):
txt = Image.new("RGB", wh, color="white") txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt) draw = ImageDraw.Draw(txt)
@ -90,7 +90,7 @@ def instantiate_from_config(config, **kwargs):
elif config == "__is_unconditional__": elif config == "__is_unconditional__":
return None return None
raise KeyError("Expected key `target` to instantiate.") raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs) return get_obj_from_str(config["target"])(**config.get("params", {}), **kwargs)
def get_obj_from_str(string, reload=False): def get_obj_from_str(string, reload=False):

View File

@ -341,19 +341,19 @@ class InvokeAIMetadataParser:
# this was more elegant as a case statement, but that's not available in python 3.9 # this was more elegant as a case statement, but that's not available in python 3.9
if old_scheduler is None: if old_scheduler is None:
return None return None
scheduler_map = dict( scheduler_map = {
ddim="ddim", "ddim": "ddim",
plms="pnmd", "plms": "pnmd",
k_lms="lms", "k_lms": "lms",
k_dpm_2="kdpm_2", "k_dpm_2": "kdpm_2",
k_dpm_2_a="kdpm_2_a", "k_dpm_2_a": "kdpm_2_a",
dpmpp_2="dpmpp_2s", "dpmpp_2": "dpmpp_2s",
k_dpmpp_2="dpmpp_2m", "k_dpmpp_2": "dpmpp_2m",
k_dpmpp_2_a=None, # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session "k_dpmpp_2_a": None, # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
k_euler="euler", "k_euler": "euler",
k_euler_a="euler_a", "k_euler_a": "euler_a",
k_heun="heun", "k_heun": "heun",
) }
return scheduler_map.get(old_scheduler) return scheduler_map.get(old_scheduler)
def split_prompt(self, raw_prompt: str): def split_prompt(self, raw_prompt: str):

View File

@ -210,7 +210,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
############# diffusers tab ########## ############# diffusers tab ##########
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
"""Add widgets responsible for selecting diffusers models""" """Add widgets responsible for selecting diffusers models"""
widgets = dict() widgets = {}
models = self.all_models models = self.all_models
starters = self.starter_models starters = self.starter_models
starter_model_labels = self.model_labels starter_model_labels = self.model_labels
@ -261,7 +261,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
exclude: set = set(), exclude: set = set(),
) -> dict[str, npyscreen.widget]: ) -> dict[str, npyscreen.widget]:
"""Generic code to create model selection widgets""" """Generic code to create model selection widgets"""
widgets = dict() widgets = {}
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude] model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
model_labels = [self.model_labels[x] for x in model_list] model_labels = [self.model_labels[x] for x in model_list]
@ -391,7 +391,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
label_width = max([len(models[x].name) for x in models]) label_width = max([len(models[x].name) for x in models])
description_width = window_width - label_width - checkbox_width - spacing_width description_width = window_width - label_width - checkbox_width - spacing_width
result = dict() result = {}
for x in models.keys(): for x in models.keys():
description = models[x].description description = models[x].description
description = ( description = (
@ -433,11 +433,11 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
parent_conn, child_conn = Pipe() parent_conn, child_conn = Pipe()
p = Process( p = Process(
target=process_and_execute, target=process_and_execute,
kwargs=dict( kwargs={
opt=app.program_opts, "opt": app.program_opts,
selections=app.install_selections, "selections": app.install_selections,
conn_out=child_conn, "conn_out": child_conn,
), },
) )
p.start() p.start()
child_conn.close() child_conn.close()
@ -558,7 +558,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
for section in ui_sections: for section in ui_sections:
if "models_selected" not in section: if "models_selected" not in section:
continue continue
selected = set([section["models"][x] for x in section["models_selected"].value]) selected = {section["models"][x] for x in section["models_selected"].value}
models_to_install = [x for x in selected if not self.all_models[x].installed] models_to_install = [x for x in selected if not self.all_models[x].installed]
models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed] models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
selections.remove_models.extend(models_to_remove) selections.remove_models.extend(models_to_remove)

View File

@ -275,14 +275,14 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
interp = self.interpolations[self.merge_method.value[0]] interp = self.interpolations[self.merge_method.value[0]]
bases = ["sd-1", "sd-2", "sdxl"] bases = ["sd-1", "sd-2", "sdxl"]
args = dict( args = {
model_names=models, "model_names": models,
base_model=BaseModelType(bases[self.base_select.value[0]]), "base_model": BaseModelType(bases[self.base_select.value[0]]),
alpha=self.alpha.value, "alpha": self.alpha.value,
interp=interp, "interp": interp,
force=self.force.value, "force": self.force.value,
merged_model_name=self.merged_model_name.value, "merged_model_name": self.merged_model_name.value,
) }
return args return args
def check_for_overwrite(self) -> bool: def check_for_overwrite(self) -> bool:
@ -297,7 +297,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
def validate_field_values(self) -> bool: def validate_field_values(self) -> bool:
bad_fields = [] bad_fields = []
model_names = self.model_names model_names = self.model_names
selected_models = set((model_names[self.model1.value[0]], model_names[self.model2.value[0]])) selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]}
if self.model3.value[0] > 0: if self.model3.value[0] > 0:
selected_models.add(model_names[self.model3.value[0] - 1]) selected_models.add(model_names[self.model3.value[0] - 1])
if len(selected_models) < 2: if len(selected_models) < 2:

View File

@ -276,13 +276,13 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
def get_model_names(self) -> Tuple[List[str], int]: def get_model_names(self) -> Tuple[List[str], int]:
conf = OmegaConf.load(config.root_dir / "configs/models.yaml") conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
model_names = [idx for idx in sorted(list(conf.keys())) if conf[idx].get("format", None) == "diffusers"] model_names = [idx for idx in sorted(conf.keys()) if conf[idx].get("format", None) == "diffusers"]
defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]] defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]]
default = defaults[0] if len(defaults) > 0 else 0 default = defaults[0] if len(defaults) > 0 else 0
return (model_names, default) return (model_names, default)
def marshall_arguments(self) -> dict: def marshall_arguments(self) -> dict:
args = dict() args = {}
# the choices # the choices
args.update( args.update(

View File

@ -37,22 +37,22 @@ def main():
if args.all_models or model_type == "diffusers": if args.all_models or model_type == "diffusers":
for d in dirs: for d in dirs:
conf[f"{base}/{model_type}/{d}"] = dict( conf[f"{base}/{model_type}/{d}"] = {
path=os.path.join(root, d), "path": os.path.join(root, d),
description=f"{model_type} model {d}", "description": f"{model_type} model {d}",
format="folder", "format": "folder",
base=base, "base": base,
) }
for f in files: for f in files:
basename = Path(f).stem basename = Path(f).stem
format = Path(f).suffix[1:] format = Path(f).suffix[1:]
conf[f"{base}/{model_type}/{basename}"] = dict( conf[f"{base}/{model_type}/{basename}"] = {
path=os.path.join(root, f), "path": os.path.join(root, f),
description=f"{model_type} model {basename}", "description": f"{model_type} model {basename}",
format=format, "format": format,
base=base, "base": base,
) }
OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout) OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout)

View File

@ -149,8 +149,8 @@ def test_graph_state_expands_iterator(mock_services):
invoke_next(g, mock_services) invoke_next(g, mock_services)
prepared_add_nodes = g.source_prepared_mapping["3"] prepared_add_nodes = g.source_prepared_mapping["3"]
results = set([g.results[n].value for n in prepared_add_nodes]) results = {g.results[n].value for n in prepared_add_nodes}
expected = set([1, 11, 21]) expected = {1, 11, 21}
assert results == expected assert results == expected
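The test change above shows the set-literal side of the same cleanup: set([...]) builds a throwaway list just to feed set(), and the bracketed comprehension inside set() is likewise replaced by a {...} form. A minimal sketch with hypothetical per-node results:

    results_by_node = {"3-0": 1, "3-1": 11, "3-2": 21}

    # before: an intermediate list is created and immediately discarded
    expected_before = set([1, 11, 21])
    results_before = set([results_by_node[n] for n in results_by_node])

    # after: set literal and set comprehension build the sets directly
    expected_after = {1, 11, 21}
    results_after = {results_by_node[n] for n in results_by_node}

    assert results_before == results_after == expected_before == expected_after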
@ -229,7 +229,7 @@ def test_graph_executes_depth_first(mock_services):
# Because ordering is not guaranteed, we cannot compare results directly. # Because ordering is not guaranteed, we cannot compare results directly.
# Instead, we must count the number of results. # Instead, we must count the number of results.
def get_completed_count(g, id): def get_completed_count(g, id):
ids = [i for i in g.source_prepared_mapping[id]] ids = list(g.source_prepared_mapping[id])
completed_ids = [i for i in g.executed if i in ids] completed_ids = [i for i in g.executed if i in ids]
return len(completed_ids) return len(completed_ids)
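The helper above illustrates the opposite direction (C416, unnecessary comprehension): an identity comprehension such as [i for i in iterable] only copies elements, so list(iterable) expresses the same copy directly. A sketch with a hypothetical prepared-node mapping:

    source_prepared_mapping = {"1": {"1-0", "1-1", "1-2"}}  # hypothetical

    # before: the comprehension re-copies the set element by element
    ids_before = [i for i in source_prepared_mapping["1"]]

    # after: list() performs the same copy without the boilerplate
    ids_after = list(source_prepared_mapping["1"])

    assert sorted(ids_before) == sorted(ids_after)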

View File

@ -503,8 +503,8 @@ def test_graph_expands_subgraph():
g.add_edge(create_edge("1.2", "value", "2", "a")) g.add_edge(create_edge("1.2", "value", "2", "a"))
dg = g.nx_graph_flat() dg = g.nx_graph_flat()
assert set(dg.nodes) == set(["1.1", "1.2", "2"]) assert set(dg.nodes) == {"1.1", "1.2", "2"}
assert set(dg.edges) == set([("1.1", "1.2"), ("1.2", "2")]) assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")}
def test_graph_subgraph_t2i(): def test_graph_subgraph_t2i():
@ -532,9 +532,7 @@ def test_graph_subgraph_t2i():
# Validate # Validate
dg = g.nx_graph_flat() dg = g.nx_graph_flat()
assert set(dg.nodes) == set( assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"}
["1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"]
)
expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges] expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")]) expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
print(expected_edges) print(expected_edges)

View File

@ -130,7 +130,7 @@ class TestEventService(EventServiceBase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.events = list() self.events = []
def dispatch(self, event_name: str, payload: Any) -> None: def dispatch(self, event_name: str, payload: Any) -> None:
pass pass