mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore: ruff check - fix flake8-comprensions
This commit is contained in:
parent
43f2398e14
commit
3a136420d5
@ -28,7 +28,7 @@ class FastAPIEventService(EventServiceBase):
|
|||||||
self.__queue.put(None)
|
self.__queue.put(None)
|
||||||
|
|
||||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||||
self.__queue.put(dict(event_name=event_name, payload=payload))
|
self.__queue.put({"event_name": event_name, "payload": payload})
|
||||||
|
|
||||||
async def __dispatch_from_queue(self, stop_event: threading.Event):
|
async def __dispatch_from_queue(self, stop_event: threading.Event):
|
||||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||||
|
@ -55,7 +55,7 @@ async def list_models(
|
|||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
"""Gets a list of models"""
|
"""Gets a list of models"""
|
||||||
if base_models and len(base_models) > 0:
|
if base_models and len(base_models) > 0:
|
||||||
models_raw = list()
|
models_raw = []
|
||||||
for base_model in base_models:
|
for base_model in base_models:
|
||||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||||
else:
|
else:
|
||||||
|
@ -130,7 +130,7 @@ def custom_openapi() -> dict[str, Any]:
|
|||||||
# Add all outputs
|
# Add all outputs
|
||||||
all_invocations = BaseInvocation.get_invocations()
|
all_invocations = BaseInvocation.get_invocations()
|
||||||
output_types = set()
|
output_types = set()
|
||||||
output_type_titles = dict()
|
output_type_titles = {}
|
||||||
for invoker in all_invocations:
|
for invoker in all_invocations:
|
||||||
output_type = signature(invoker.invoke).return_annotation
|
output_type = signature(invoker.invoke).return_annotation
|
||||||
output_types.add(output_type)
|
output_types.add(output_type)
|
||||||
@ -171,12 +171,12 @@ def custom_openapi() -> dict[str, Any]:
|
|||||||
# print(f"Config with name {name} already defined")
|
# print(f"Config with name {name} already defined")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
openapi_schema["components"]["schemas"][name] = dict(
|
openapi_schema["components"]["schemas"][name] = {
|
||||||
title=name,
|
"title": name,
|
||||||
description="An enumeration.",
|
"description": "An enumeration.",
|
||||||
type="string",
|
"type": "string",
|
||||||
enum=list(v.value for v in model_config_format_enum),
|
"enum": [v.value for v in model_config_format_enum],
|
||||||
)
|
}
|
||||||
|
|
||||||
app.openapi_schema = openapi_schema
|
app.openapi_schema = openapi_schema
|
||||||
return app.openapi_schema
|
return app.openapi_schema
|
||||||
|
@ -25,4 +25,4 @@ spec.loader.exec_module(module)
|
|||||||
|
|
||||||
# add core nodes to __all__
|
# add core nodes to __all__
|
||||||
python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
|
python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
|
||||||
__all__ = list(f.stem for f in python_files) # type: ignore
|
__all__ = [f.stem for f in python_files] # type: ignore
|
||||||
|
@ -236,35 +236,35 @@ def InputField(
|
|||||||
Ignored for non-collection fields.
|
Ignored for non-collection fields.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
json_schema_extra_: dict[str, Any] = dict(
|
json_schema_extra_: dict[str, Any] = {
|
||||||
input=input,
|
"input": input,
|
||||||
ui_type=ui_type,
|
"ui_type": ui_type,
|
||||||
ui_component=ui_component,
|
"ui_component": ui_component,
|
||||||
ui_hidden=ui_hidden,
|
"ui_hidden": ui_hidden,
|
||||||
ui_order=ui_order,
|
"ui_order": ui_order,
|
||||||
item_default=item_default,
|
"item_default": item_default,
|
||||||
ui_choice_labels=ui_choice_labels,
|
"ui_choice_labels": ui_choice_labels,
|
||||||
_field_kind="input",
|
"_field_kind": "input",
|
||||||
)
|
}
|
||||||
|
|
||||||
field_args = dict(
|
field_args = {
|
||||||
default=default,
|
"default": default,
|
||||||
default_factory=default_factory,
|
"default_factory": default_factory,
|
||||||
title=title,
|
"title": title,
|
||||||
description=description,
|
"description": description,
|
||||||
pattern=pattern,
|
"pattern": pattern,
|
||||||
strict=strict,
|
"strict": strict,
|
||||||
gt=gt,
|
"gt": gt,
|
||||||
ge=ge,
|
"ge": ge,
|
||||||
lt=lt,
|
"lt": lt,
|
||||||
le=le,
|
"le": le,
|
||||||
multiple_of=multiple_of,
|
"multiple_of": multiple_of,
|
||||||
allow_inf_nan=allow_inf_nan,
|
"allow_inf_nan": allow_inf_nan,
|
||||||
max_digits=max_digits,
|
"max_digits": max_digits,
|
||||||
decimal_places=decimal_places,
|
"decimal_places": decimal_places,
|
||||||
min_length=min_length,
|
"min_length": min_length,
|
||||||
max_length=max_length,
|
"max_length": max_length,
|
||||||
)
|
}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Invocation definitions have their fields typed correctly for their `invoke()` functions.
|
Invocation definitions have their fields typed correctly for their `invoke()` functions.
|
||||||
@ -299,24 +299,24 @@ def InputField(
|
|||||||
|
|
||||||
# because we are manually making fields optional, we need to store the original required bool for reference later
|
# because we are manually making fields optional, we need to store the original required bool for reference later
|
||||||
if default is PydanticUndefined and default_factory is PydanticUndefined:
|
if default is PydanticUndefined and default_factory is PydanticUndefined:
|
||||||
json_schema_extra_.update(dict(orig_required=True))
|
json_schema_extra_.update({"orig_required": True})
|
||||||
else:
|
else:
|
||||||
json_schema_extra_.update(dict(orig_required=False))
|
json_schema_extra_.update({"orig_required": False})
|
||||||
|
|
||||||
# make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
# make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
||||||
if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
|
if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
|
||||||
default_ = None if default is PydanticUndefined else default
|
default_ = None if default is PydanticUndefined else default
|
||||||
provided_args.update(dict(default=default_))
|
provided_args.update({"default": default_})
|
||||||
if default is not PydanticUndefined:
|
if default is not PydanticUndefined:
|
||||||
# before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
|
# before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
|
||||||
json_schema_extra_.update(dict(default=default))
|
json_schema_extra_.update({"default": default})
|
||||||
json_schema_extra_.update(dict(orig_default=default))
|
json_schema_extra_.update({"orig_default": default})
|
||||||
elif default is not PydanticUndefined and default_factory is PydanticUndefined:
|
elif default is not PydanticUndefined and default_factory is PydanticUndefined:
|
||||||
default_ = default
|
default_ = default
|
||||||
provided_args.update(dict(default=default_))
|
provided_args.update({"default": default_})
|
||||||
json_schema_extra_.update(dict(orig_default=default_))
|
json_schema_extra_.update({"orig_default": default_})
|
||||||
elif default_factory is not PydanticUndefined:
|
elif default_factory is not PydanticUndefined:
|
||||||
provided_args.update(dict(default_factory=default_factory))
|
provided_args.update({"default_factory": default_factory})
|
||||||
# TODO: cannot serialize default_factory...
|
# TODO: cannot serialize default_factory...
|
||||||
# json_schema_extra_.update(dict(orig_default_factory=default_factory))
|
# json_schema_extra_.update(dict(orig_default_factory=default_factory))
|
||||||
|
|
||||||
@ -383,12 +383,12 @@ def OutputField(
|
|||||||
decimal_places=decimal_places,
|
decimal_places=decimal_places,
|
||||||
min_length=min_length,
|
min_length=min_length,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
json_schema_extra=dict(
|
json_schema_extra={
|
||||||
ui_type=ui_type,
|
"ui_type": ui_type,
|
||||||
ui_hidden=ui_hidden,
|
"ui_hidden": ui_hidden,
|
||||||
ui_order=ui_order,
|
"ui_order": ui_order,
|
||||||
_field_kind="output",
|
"_field_kind": "output",
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -460,14 +460,14 @@ class BaseInvocationOutput(BaseModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_output_types(cls) -> Iterable[str]:
|
def get_output_types(cls) -> Iterable[str]:
|
||||||
return map(lambda i: get_type(i), BaseInvocationOutput.get_outputs())
|
return (get_type(i) for i in BaseInvocationOutput.get_outputs())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
# Because we use a pydantic Literal field with default value for the invocation type,
|
# Because we use a pydantic Literal field with default value for the invocation type,
|
||||||
# it will be typed as optional in the OpenAPI schema. Make it required manually.
|
# it will be typed as optional in the OpenAPI schema. Make it required manually.
|
||||||
if "required" not in schema or not isinstance(schema["required"], list):
|
if "required" not in schema or not isinstance(schema["required"], list):
|
||||||
schema["required"] = list()
|
schema["required"] = []
|
||||||
schema["required"].extend(["type"])
|
schema["required"].extend(["type"])
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
@ -527,16 +527,11 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
|
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
|
||||||
# Get the type strings out of the literals and into a dictionary
|
# Get the type strings out of the literals and into a dictionary
|
||||||
return dict(
|
return {get_type(i): i for i in BaseInvocation.get_invocations()}
|
||||||
map(
|
|
||||||
lambda i: (get_type(i), i),
|
|
||||||
BaseInvocation.get_invocations(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_invocation_types(cls) -> Iterable[str]:
|
def get_invocation_types(cls) -> Iterable[str]:
|
||||||
return map(lambda i: get_type(i), BaseInvocation.get_invocations())
|
return (get_type(i) for i in BaseInvocation.get_invocations())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_output_type(cls) -> BaseInvocationOutput:
|
def get_output_type(cls) -> BaseInvocationOutput:
|
||||||
@ -555,7 +550,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
if uiconfig and hasattr(uiconfig, "version"):
|
if uiconfig and hasattr(uiconfig, "version"):
|
||||||
schema["version"] = uiconfig.version
|
schema["version"] = uiconfig.version
|
||||||
if "required" not in schema or not isinstance(schema["required"], list):
|
if "required" not in schema or not isinstance(schema["required"], list):
|
||||||
schema["required"] = list()
|
schema["required"] = []
|
||||||
schema["required"].extend(["type", "id"])
|
schema["required"].extend(["type", "id"])
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -609,15 +604,15 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
id: str = Field(
|
id: str = Field(
|
||||||
default_factory=uuid_string,
|
default_factory=uuid_string,
|
||||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
||||||
json_schema_extra=dict(_field_kind="internal"),
|
json_schema_extra={"_field_kind": "internal"},
|
||||||
)
|
)
|
||||||
is_intermediate: bool = Field(
|
is_intermediate: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether or not this is an intermediate invocation.",
|
description="Whether or not this is an intermediate invocation.",
|
||||||
json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"),
|
json_schema_extra={"ui_type": UIType.IsIntermediate, "_field_kind": "internal"},
|
||||||
)
|
)
|
||||||
use_cache: bool = Field(
|
use_cache: bool = Field(
|
||||||
default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal")
|
default=True, description="Whether or not to use the cache", json_schema_extra={"_field_kind": "internal"}
|
||||||
)
|
)
|
||||||
|
|
||||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||||
@ -651,7 +646,7 @@ class _Model(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
# Get all pydantic model attrs, methods, etc
|
# Get all pydantic model attrs, methods, etc
|
||||||
RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model())))
|
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
|
||||||
|
|
||||||
|
|
||||||
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
||||||
@ -729,7 +724,7 @@ def invocation(
|
|||||||
# Add OpenAPI schema extras
|
# Add OpenAPI schema extras
|
||||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||||
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
cls.UIConfig = type(uiconf_name, (UIConfigBase,), {})
|
||||||
if title is not None:
|
if title is not None:
|
||||||
cls.UIConfig.title = title
|
cls.UIConfig.title = title
|
||||||
if tags is not None:
|
if tags is not None:
|
||||||
@ -756,7 +751,7 @@ def invocation(
|
|||||||
|
|
||||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||||
invocation_type_field = Field(
|
invocation_type_field = Field(
|
||||||
title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal")
|
title="type", default=invocation_type, json_schema_extra={"_field_kind": "internal"}
|
||||||
)
|
)
|
||||||
|
|
||||||
docstring = cls.__doc__
|
docstring = cls.__doc__
|
||||||
@ -802,7 +797,7 @@ def invocation_output(
|
|||||||
# Add the output type to the model.
|
# Add the output type to the model.
|
||||||
|
|
||||||
output_type_annotation = Literal[output_type] # type: ignore
|
output_type_annotation = Literal[output_type] # type: ignore
|
||||||
output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal"))
|
output_type_field = Field(title="type", default=output_type, json_schema_extra={"_field_kind": "internal"})
|
||||||
|
|
||||||
docstring = cls.__doc__
|
docstring = cls.__doc__
|
||||||
cls = create_model(
|
cls = create_model(
|
||||||
@ -834,7 +829,7 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField)
|
|||||||
|
|
||||||
class WithWorkflow(BaseModel):
|
class WithWorkflow(BaseModel):
|
||||||
workflow: Optional[WorkflowField] = Field(
|
workflow: Optional[WorkflowField] = Field(
|
||||||
default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal")
|
default=None, description=FieldDescriptions.workflow, json_schema_extra={"_field_kind": "internal"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -852,5 +847,5 @@ MetadataFieldValidator = TypeAdapter(MetadataField)
|
|||||||
|
|
||||||
class WithMetadata(BaseModel):
|
class WithMetadata(BaseModel):
|
||||||
metadata: Optional[MetadataField] = Field(
|
metadata: Optional[MetadataField] = Field(
|
||||||
default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal")
|
default=None, description=FieldDescriptions.metadata, json_schema_extra={"_field_kind": "internal"}
|
||||||
)
|
)
|
||||||
|
@ -131,7 +131,7 @@ def prepare_faces_list(
|
|||||||
deduped_faces: list[FaceResultData] = []
|
deduped_faces: list[FaceResultData] = []
|
||||||
|
|
||||||
if len(face_result_list) == 0:
|
if len(face_result_list) == 0:
|
||||||
return list()
|
return []
|
||||||
|
|
||||||
for candidate in face_result_list:
|
for candidate in face_result_list:
|
||||||
should_add = True
|
should_add = True
|
||||||
|
@ -77,7 +77,7 @@ if choose_torch_device() == torch.device("mps"):
|
|||||||
|
|
||||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("scheduler_output")
|
@invocation_output("scheduler_output")
|
||||||
|
@ -145,17 +145,17 @@ INTEGER_OPERATIONS = Literal[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
INTEGER_OPERATIONS_LABELS = dict(
|
INTEGER_OPERATIONS_LABELS = {
|
||||||
ADD="Add A+B",
|
"ADD": "Add A+B",
|
||||||
SUB="Subtract A-B",
|
"SUB": "Subtract A-B",
|
||||||
MUL="Multiply A*B",
|
"MUL": "Multiply A*B",
|
||||||
DIV="Divide A/B",
|
"DIV": "Divide A/B",
|
||||||
EXP="Exponentiate A^B",
|
"EXP": "Exponentiate A^B",
|
||||||
MOD="Modulus A%B",
|
"MOD": "Modulus A%B",
|
||||||
ABS="Absolute Value of A",
|
"ABS": "Absolute Value of A",
|
||||||
MIN="Minimum(A,B)",
|
"MIN": "Minimum(A,B)",
|
||||||
MAX="Maximum(A,B)",
|
"MAX": "Maximum(A,B)",
|
||||||
)
|
}
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -231,17 +231,17 @@ FLOAT_OPERATIONS = Literal[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
FLOAT_OPERATIONS_LABELS = dict(
|
FLOAT_OPERATIONS_LABELS = {
|
||||||
ADD="Add A+B",
|
"ADD": "Add A+B",
|
||||||
SUB="Subtract A-B",
|
"SUB": "Subtract A-B",
|
||||||
MUL="Multiply A*B",
|
"MUL": "Multiply A*B",
|
||||||
DIV="Divide A/B",
|
"DIV": "Divide A/B",
|
||||||
EXP="Exponentiate A^B",
|
"EXP": "Exponentiate A^B",
|
||||||
ABS="Absolute Value of A",
|
"ABS": "Absolute Value of A",
|
||||||
SQRT="Square Root of A",
|
"SQRT": "Square Root of A",
|
||||||
MIN="Minimum(A,B)",
|
"MIN": "Minimum(A,B)",
|
||||||
MAX="Maximum(A,B)",
|
"MAX": "Maximum(A,B)",
|
||||||
)
|
}
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
|
@ -54,7 +54,7 @@ ORT_TO_NP_TYPE = {
|
|||||||
"tensor(double)": np.float64,
|
"tensor(double)": np.float64,
|
||||||
}
|
}
|
||||||
|
|
||||||
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())]
|
||||||
|
|
||||||
|
|
||||||
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
|
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
|
||||||
@ -252,7 +252,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
scheduler.set_timesteps(self.steps)
|
scheduler.set_timesteps(self.steps)
|
||||||
latents = latents * np.float64(scheduler.init_noise_sigma)
|
latents = latents * np.float64(scheduler.init_noise_sigma)
|
||||||
|
|
||||||
extra_step_kwargs = dict()
|
extra_step_kwargs = {}
|
||||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
extra_step_kwargs.update(
|
extra_step_kwargs.update(
|
||||||
eta=0.0,
|
eta=0.0,
|
||||||
|
@ -100,7 +100,7 @@ EASING_FUNCTIONS_MAP = {
|
|||||||
"BounceInOut": BounceEaseInOut,
|
"BounceInOut": BounceEaseInOut,
|
||||||
}
|
}
|
||||||
|
|
||||||
EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
|
||||||
|
|
||||||
|
|
||||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||||
@ -161,7 +161,7 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
easing_class = EASING_FUNCTIONS_MAP[self.easing]
|
easing_class = EASING_FUNCTIONS_MAP[self.easing]
|
||||||
if log_diagnostics:
|
if log_diagnostics:
|
||||||
context.services.logger.debug("easing class: " + str(easing_class))
|
context.services.logger.debug("easing class: " + str(easing_class))
|
||||||
easing_list = list()
|
easing_list = []
|
||||||
if self.mirror: # "expected" mirroring
|
if self.mirror: # "expected" mirroring
|
||||||
# if number of steps is even, squeeze duration down to (number_of_steps)/2
|
# if number of steps is even, squeeze duration down to (number_of_steps)/2
|
||||||
# and create reverse copy of list to append
|
# and create reverse copy of list to append
|
||||||
@ -178,7 +178,7 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
end=self.end_value,
|
end=self.end_value,
|
||||||
duration=base_easing_duration - 1,
|
duration=base_easing_duration - 1,
|
||||||
)
|
)
|
||||||
base_easing_vals = list()
|
base_easing_vals = []
|
||||||
for step_index in range(base_easing_duration):
|
for step_index in range(base_easing_duration):
|
||||||
easing_val = easing_function.ease(step_index)
|
easing_val = easing_function.ease(step_index)
|
||||||
base_easing_vals.append(easing_val)
|
base_easing_vals.append(easing_val)
|
||||||
|
@ -139,7 +139,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
|||||||
(board_id,),
|
(board_id,),
|
||||||
)
|
)
|
||||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
images = [deserialize_image_record(dict(r)) for r in result]
|
||||||
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
@ -167,7 +167,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
|||||||
(board_id,),
|
(board_id,),
|
||||||
)
|
)
|
||||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
image_names = list(map(lambda r: r[0], result))
|
image_names = [r[0] for r in result]
|
||||||
return image_names
|
return image_names
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
|
@ -199,7 +199,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
|
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||||
|
|
||||||
# Get the total number of boards
|
# Get the total number of boards
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -236,7 +236,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
|
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||||
|
|
||||||
return boards
|
return boards
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
"""
|
"""
|
||||||
cls = self.__class__
|
cls = self.__class__
|
||||||
type = get_args(get_type_hints(cls)["type"])[0]
|
type = get_args(get_type_hints(cls)["type"])[0]
|
||||||
field_dict = dict({type: dict()})
|
field_dict = {type: {}}
|
||||||
for name, field in self.model_fields.items():
|
for name, field in self.model_fields.items():
|
||||||
if name in cls._excluded_from_yaml():
|
if name in cls._excluded_from_yaml():
|
||||||
continue
|
continue
|
||||||
@ -64,7 +64,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
)
|
)
|
||||||
value = getattr(self, name)
|
value = getattr(self, name)
|
||||||
if category not in field_dict[type]:
|
if category not in field_dict[type]:
|
||||||
field_dict[type][category] = dict()
|
field_dict[type][category] = {}
|
||||||
# keep paths as strings to make it easier to read
|
# keep paths as strings to make it easier to read
|
||||||
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
||||||
conf = OmegaConf.create(field_dict)
|
conf = OmegaConf.create(field_dict)
|
||||||
@ -89,7 +89,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
# create an upcase version of the environment in
|
# create an upcase version of the environment in
|
||||||
# order to achieve case-insensitive environment
|
# order to achieve case-insensitive environment
|
||||||
# variables (the way Windows does)
|
# variables (the way Windows does)
|
||||||
upcase_environ = dict()
|
upcase_environ = {}
|
||||||
for key, value in os.environ.items():
|
for key, value in os.environ.items():
|
||||||
upcase_environ[key.upper()] = value
|
upcase_environ[key.upper()] = value
|
||||||
|
|
||||||
|
@ -188,18 +188,18 @@ DEFAULT_MAX_VRAM = 0.5
|
|||||||
|
|
||||||
|
|
||||||
class Categories(object):
|
class Categories(object):
|
||||||
WebServer = dict(category="Web Server")
|
WebServer = {"category": "Web Server"}
|
||||||
Features = dict(category="Features")
|
Features = {"category": "Features"}
|
||||||
Paths = dict(category="Paths")
|
Paths = {"category": "Paths"}
|
||||||
Logging = dict(category="Logging")
|
Logging = {"category": "Logging"}
|
||||||
Development = dict(category="Development")
|
Development = {"category": "Development"}
|
||||||
Other = dict(category="Other")
|
Other = {"category": "Other"}
|
||||||
ModelCache = dict(category="Model Cache")
|
ModelCache = {"category": "Model Cache"}
|
||||||
Device = dict(category="Device")
|
Device = {"category": "Device"}
|
||||||
Generation = dict(category="Generation")
|
Generation = {"category": "Generation"}
|
||||||
Queue = dict(category="Queue")
|
Queue = {"category": "Queue"}
|
||||||
Nodes = dict(category="Nodes")
|
Nodes = {"category": "Nodes"}
|
||||||
MemoryPerformance = dict(category="Memory/Performance")
|
MemoryPerformance = {"category": "Memory/Performance"}
|
||||||
|
|
||||||
|
|
||||||
class InvokeAIAppConfig(InvokeAISettings):
|
class InvokeAIAppConfig(InvokeAISettings):
|
||||||
@ -482,7 +482,7 @@ def _find_root() -> Path:
|
|||||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||||
if os.environ.get("INVOKEAI_ROOT"):
|
if os.environ.get("INVOKEAI_ROOT"):
|
||||||
root = Path(os.environ["INVOKEAI_ROOT"])
|
root = Path(os.environ["INVOKEAI_ROOT"])
|
||||||
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
|
elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]):
|
||||||
root = (venv.parent).resolve()
|
root = (venv.parent).resolve()
|
||||||
else:
|
else:
|
||||||
root = Path("~/invokeai").expanduser().resolve()
|
root = Path("~/invokeai").expanduser().resolve()
|
||||||
|
@ -27,7 +27,7 @@ class EventServiceBase:
|
|||||||
payload["timestamp"] = get_timestamp()
|
payload["timestamp"] = get_timestamp()
|
||||||
self.dispatch(
|
self.dispatch(
|
||||||
event_name=EventServiceBase.queue_event,
|
event_name=EventServiceBase.queue_event,
|
||||||
payload=dict(event=event_name, data=payload),
|
payload={"event": event_name, "data": payload},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Define events here for every event in the system.
|
# Define events here for every event in the system.
|
||||||
@ -48,18 +48,18 @@ class EventServiceBase:
|
|||||||
"""Emitted when there is generation progress"""
|
"""Emitted when there is generation progress"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="generator_progress",
|
event_name="generator_progress",
|
||||||
payload=dict(
|
payload={
|
||||||
queue_id=queue_id,
|
"queue_id": queue_id,
|
||||||
queue_item_id=queue_item_id,
|
"queue_item_id": queue_item_id,
|
||||||
queue_batch_id=queue_batch_id,
|
"queue_batch_id": queue_batch_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
"graph_execution_state_id": graph_execution_state_id,
|
||||||
node_id=node.get("id"),
|
"node_id": node.get("id"),
|
||||||
source_node_id=source_node_id,
|
"source_node_id": source_node_id,
|
||||||
progress_image=progress_image.model_dump() if progress_image is not None else None,
|
"progress_image": progress_image.model_dump() if progress_image is not None else None,
|
||||||
step=step,
|
"step": step,
|
||||||
order=order,
|
"order": order,
|
||||||
total_steps=total_steps,
|
"total_steps": total_steps,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_invocation_complete(
|
def emit_invocation_complete(
|
||||||
@ -75,15 +75,15 @@ class EventServiceBase:
|
|||||||
"""Emitted when an invocation has completed"""
|
"""Emitted when an invocation has completed"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="invocation_complete",
|
event_name="invocation_complete",
|
||||||
payload=dict(
|
payload={
|
||||||
queue_id=queue_id,
|
"queue_id": queue_id,
|
||||||
queue_item_id=queue_item_id,
|
"queue_item_id": queue_item_id,
|
||||||
queue_batch_id=queue_batch_id,
|
"queue_batch_id": queue_batch_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
"graph_execution_state_id": graph_execution_state_id,
|
||||||
node=node,
|
"node": node,
|
||||||
source_node_id=source_node_id,
|
"source_node_id": source_node_id,
|
||||||
result=result,
|
"result": result,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_invocation_error(
|
def emit_invocation_error(
|
||||||
@ -100,16 +100,16 @@ class EventServiceBase:
|
|||||||
"""Emitted when an invocation has completed"""
|
"""Emitted when an invocation has completed"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="invocation_error",
|
event_name="invocation_error",
|
||||||
payload=dict(
|
payload={
|
||||||
queue_id=queue_id,
|
"queue_id": queue_id,
|
||||||
queue_item_id=queue_item_id,
|
"queue_item_id": queue_item_id,
|
||||||
queue_batch_id=queue_batch_id,
|
"queue_batch_id": queue_batch_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
"graph_execution_state_id": graph_execution_state_id,
|
||||||
node=node,
|
"node": node,
|
||||||
source_node_id=source_node_id,
|
"source_node_id": source_node_id,
|
||||||
error_type=error_type,
|
"error_type": error_type,
|
||||||
error=error,
|
"error": error,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_invocation_started(
|
def emit_invocation_started(
|
||||||
@ -124,14 +124,14 @@ class EventServiceBase:
|
|||||||
"""Emitted when an invocation has started"""
|
"""Emitted when an invocation has started"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="invocation_started",
|
event_name="invocation_started",
|
||||||
payload=dict(
|
payload={
|
||||||
queue_id=queue_id,
|
"queue_id": queue_id,
|
||||||
queue_item_id=queue_item_id,
|
"queue_item_id": queue_item_id,
|
||||||
queue_batch_id=queue_batch_id,
|
"queue_batch_id": queue_batch_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
"graph_execution_state_id": graph_execution_state_id,
|
||||||
node=node,
|
"node": node,
|
||||||
source_node_id=source_node_id,
|
"source_node_id": source_node_id,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_graph_execution_complete(
|
def emit_graph_execution_complete(
|
||||||
@ -140,12 +140,12 @@ class EventServiceBase:
|
|||||||
"""Emitted when a session has completed all invocations"""
|
"""Emitted when a session has completed all invocations"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="graph_execution_state_complete",
|
event_name="graph_execution_state_complete",
|
||||||
payload=dict(
|
payload={
|
||||||
queue_id=queue_id,
|
"queue_id": queue_id,
|
||||||
queue_item_id=queue_item_id,
|
"queue_item_id": queue_item_id,
|
||||||
queue_batch_id=queue_batch_id,
|
"queue_batch_id": queue_batch_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
"graph_execution_state_id": graph_execution_state_id,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_model_load_started(
|
def emit_model_load_started(
|
||||||
@ -162,16 +162,16 @@ class EventServiceBase:
|
|||||||
"""Emitted when a model is requested"""
|
"""Emitted when a model is requested"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="model_load_started",
|
event_name="model_load_started",
|
||||||
payload=dict(
|
payload={
|
||||||
queue_id=queue_id,
|
"queue_id": queue_id,
|
||||||
queue_item_id=queue_item_id,
|
"queue_item_id": queue_item_id,
|
||||||
queue_batch_id=queue_batch_id,
|
"queue_batch_id": queue_batch_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
"graph_execution_state_id": graph_execution_state_id,
|
||||||
model_name=model_name,
|
"model_name": model_name,
|
||||||
base_model=base_model,
|
"base_model": base_model,
|
||||||
model_type=model_type,
|
"model_type": model_type,
|
||||||
submodel=submodel,
|
"submodel": submodel,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_model_load_completed(
|
def emit_model_load_completed(
|
||||||
@ -189,19 +189,19 @@ class EventServiceBase:
|
|||||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="model_load_completed",
|
event_name="model_load_completed",
|
||||||
payload=dict(
|
payload={
|
||||||
queue_id=queue_id,
|
"queue_id": queue_id,
|
||||||
queue_item_id=queue_item_id,
|
"queue_item_id": queue_item_id,
|
||||||
queue_batch_id=queue_batch_id,
|
"queue_batch_id": queue_batch_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
"graph_execution_state_id": graph_execution_state_id,
|
||||||
model_name=model_name,
|
"model_name": model_name,
|
||||||
base_model=base_model,
|
"base_model": base_model,
|
||||||
model_type=model_type,
|
"model_type": model_type,
|
||||||
submodel=submodel,
|
"submodel": submodel,
|
||||||
hash=model_info.hash,
|
"hash": model_info.hash,
|
||||||
location=str(model_info.location),
|
"location": str(model_info.location),
|
||||||
precision=str(model_info.precision),
|
"precision": str(model_info.precision),
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_session_retrieval_error(
|
def emit_session_retrieval_error(
|
||||||
@ -216,14 +216,14 @@ class EventServiceBase:
|
|||||||
"""Emitted when session retrieval fails"""
|
"""Emitted when session retrieval fails"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="session_retrieval_error",
|
event_name="session_retrieval_error",
|
||||||
payload=dict(
|
payload={
|
||||||
queue_id=queue_id,
|
"queue_id": queue_id,
|
||||||
queue_item_id=queue_item_id,
|
"queue_item_id": queue_item_id,
|
||||||
queue_batch_id=queue_batch_id,
|
"queue_batch_id": queue_batch_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
"graph_execution_state_id": graph_execution_state_id,
|
||||||
error_type=error_type,
|
"error_type": error_type,
|
||||||
error=error,
|
"error": error,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_invocation_retrieval_error(
|
def emit_invocation_retrieval_error(
|
||||||
@ -239,15 +239,15 @@ class EventServiceBase:
|
|||||||
"""Emitted when invocation retrieval fails"""
|
"""Emitted when invocation retrieval fails"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="invocation_retrieval_error",
|
event_name="invocation_retrieval_error",
|
||||||
payload=dict(
|
payload={
|
||||||
queue_id=queue_id,
|
"queue_id": queue_id,
|
||||||
queue_item_id=queue_item_id,
|
"queue_item_id": queue_item_id,
|
||||||
queue_batch_id=queue_batch_id,
|
"queue_batch_id": queue_batch_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
"graph_execution_state_id": graph_execution_state_id,
|
||||||
node_id=node_id,
|
"node_id": node_id,
|
||||||
error_type=error_type,
|
"error_type": error_type,
|
||||||
error=error,
|
"error": error,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_session_canceled(
|
def emit_session_canceled(
|
||||||
@ -260,12 +260,12 @@ class EventServiceBase:
|
|||||||
"""Emitted when a session is canceled"""
|
"""Emitted when a session is canceled"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="session_canceled",
|
event_name="session_canceled",
|
||||||
payload=dict(
|
payload={
|
||||||
queue_id=queue_id,
|
"queue_id": queue_id,
|
||||||
queue_item_id=queue_item_id,
|
"queue_item_id": queue_item_id,
|
||||||
queue_batch_id=queue_batch_id,
|
"queue_batch_id": queue_batch_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
"graph_execution_state_id": graph_execution_state_id,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_queue_item_status_changed(
|
def emit_queue_item_status_changed(
|
||||||
@ -277,39 +277,39 @@ class EventServiceBase:
|
|||||||
"""Emitted when a queue item's status changes"""
|
"""Emitted when a queue item's status changes"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="queue_item_status_changed",
|
event_name="queue_item_status_changed",
|
||||||
payload=dict(
|
payload={
|
||||||
queue_id=queue_status.queue_id,
|
"queue_id": queue_status.queue_id,
|
||||||
queue_item=dict(
|
"queue_item": {
|
||||||
queue_id=session_queue_item.queue_id,
|
"queue_id": session_queue_item.queue_id,
|
||||||
item_id=session_queue_item.item_id,
|
"item_id": session_queue_item.item_id,
|
||||||
status=session_queue_item.status,
|
"status": session_queue_item.status,
|
||||||
batch_id=session_queue_item.batch_id,
|
"batch_id": session_queue_item.batch_id,
|
||||||
session_id=session_queue_item.session_id,
|
"session_id": session_queue_item.session_id,
|
||||||
error=session_queue_item.error,
|
"error": session_queue_item.error,
|
||||||
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
||||||
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
||||||
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||||
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||||
),
|
},
|
||||||
batch_status=batch_status.model_dump(),
|
"batch_status": batch_status.model_dump(),
|
||||||
queue_status=queue_status.model_dump(),
|
"queue_status": queue_status.model_dump(),
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
|
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
|
||||||
"""Emitted when a batch is enqueued"""
|
"""Emitted when a batch is enqueued"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="batch_enqueued",
|
event_name="batch_enqueued",
|
||||||
payload=dict(
|
payload={
|
||||||
queue_id=enqueue_result.queue_id,
|
"queue_id": enqueue_result.queue_id,
|
||||||
batch_id=enqueue_result.batch.batch_id,
|
"batch_id": enqueue_result.batch.batch_id,
|
||||||
enqueued=enqueue_result.enqueued,
|
"enqueued": enqueue_result.enqueued,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_queue_cleared(self, queue_id: str) -> None:
|
def emit_queue_cleared(self, queue_id: str) -> None:
|
||||||
"""Emitted when the queue is cleared"""
|
"""Emitted when the queue is cleared"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="queue_cleared",
|
event_name="queue_cleared",
|
||||||
payload=dict(queue_id=queue_id),
|
payload={"queue_id": queue_id},
|
||||||
)
|
)
|
||||||
|
@ -25,7 +25,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
__invoker: Invoker
|
__invoker: Invoker
|
||||||
|
|
||||||
def __init__(self, output_folder: Union[str, Path]):
|
def __init__(self, output_folder: Union[str, Path]):
|
||||||
self.__cache = dict()
|
self.__cache = {}
|
||||||
self.__cache_ids = Queue()
|
self.__cache_ids = Queue()
|
||||||
self.__max_cache_size = 10 # TODO: get this from config
|
self.__max_cache_size = 10 # TODO: get this from config
|
||||||
|
|
||||||
|
@ -90,10 +90,7 @@ class ImageRecordDeleteException(Exception):
|
|||||||
|
|
||||||
|
|
||||||
IMAGE_DTO_COLS = ", ".join(
|
IMAGE_DTO_COLS = ", ".join(
|
||||||
list(
|
["images." + c for c in [
|
||||||
map(
|
|
||||||
lambda c: "images." + c,
|
|
||||||
[
|
|
||||||
"image_name",
|
"image_name",
|
||||||
"image_origin",
|
"image_origin",
|
||||||
"image_category",
|
"image_category",
|
||||||
@ -106,9 +103,7 @@ IMAGE_DTO_COLS = ", ".join(
|
|||||||
"updated_at",
|
"updated_at",
|
||||||
"deleted_at",
|
"deleted_at",
|
||||||
"starred",
|
"starred",
|
||||||
],
|
]]
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -263,7 +263,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
|
|
||||||
if categories is not None:
|
if categories is not None:
|
||||||
# Convert the enum values to unique list of strings
|
# Convert the enum values to unique list of strings
|
||||||
category_strings = list(map(lambda c: c.value, set(categories)))
|
category_strings = [c.value for c in set(categories)]
|
||||||
# Create the correct length of placeholders
|
# Create the correct length of placeholders
|
||||||
placeholders = ",".join("?" * len(category_strings))
|
placeholders = ",".join("?" * len(category_strings))
|
||||||
|
|
||||||
@ -307,7 +307,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
# Build the list of images, deserializing each row
|
# Build the list of images, deserializing each row
|
||||||
self._cursor.execute(images_query, images_params)
|
self._cursor.execute(images_query, images_params)
|
||||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
images = [deserialize_image_record(dict(r)) for r in result]
|
||||||
|
|
||||||
# Set up and execute the count query, without pagination
|
# Set up and execute the count query, without pagination
|
||||||
count_query += query_conditions + ";"
|
count_query += query_conditions + ";"
|
||||||
@ -386,7 +386,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
image_names = list(map(lambda r: r[0], result))
|
image_names = [r[0] for r in result]
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
DELETE FROM images
|
DELETE FROM images
|
||||||
|
@ -21,8 +21,8 @@ class ImageServiceABC(ABC):
|
|||||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._on_changed_callbacks = list()
|
self._on_changed_callbacks = []
|
||||||
self._on_deleted_callbacks = list()
|
self._on_deleted_callbacks = []
|
||||||
|
|
||||||
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
|
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
|
||||||
"""Register a callback for when an image is changed"""
|
"""Register a callback for when an image is changed"""
|
||||||
|
@ -217,18 +217,13 @@ class ImageService(ImageServiceABC):
|
|||||||
board_id,
|
board_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_dtos = list(
|
image_dtos = [image_record_to_dto(
|
||||||
map(
|
|
||||||
lambda r: image_record_to_dto(
|
|
||||||
image_record=r,
|
image_record=r,
|
||||||
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
|
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
|
||||||
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
|
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
|
||||||
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
||||||
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
|
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
|
||||||
),
|
) for r in results.items]
|
||||||
results.items,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return OffsetPaginatedResults[ImageDTO](
|
return OffsetPaginatedResults[ImageDTO](
|
||||||
items=image_dtos,
|
items=image_dtos,
|
||||||
|
@ -26,7 +26,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
self.__invoker_thread = Thread(
|
self.__invoker_thread = Thread(
|
||||||
name="invoker_processor",
|
name="invoker_processor",
|
||||||
target=self.__process,
|
target=self.__process,
|
||||||
kwargs=dict(stop_event=self.__stop_event),
|
kwargs={"stop_event": self.__stop_event},
|
||||||
)
|
)
|
||||||
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
|
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
|
||||||
self.__invoker_thread.start()
|
self.__invoker_thread.start()
|
||||||
|
@ -14,7 +14,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.__queue = Queue()
|
self.__queue = Queue()
|
||||||
self.__cancellations = dict()
|
self.__cancellations = {}
|
||||||
|
|
||||||
def get(self) -> InvocationQueueItem:
|
def get(self) -> InvocationQueueItem:
|
||||||
item = self.__queue.get()
|
item = self.__queue.get()
|
||||||
|
@ -142,7 +142,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
cache_stats = self._cache_stats[graph_id]
|
cache_stats = self._cache_stats[graph_id]
|
||||||
hwm = cache_stats.high_watermark / GIG
|
hwm = cache_stats.high_watermark / GIG
|
||||||
tot = cache_stats.cache_size / GIG
|
tot = cache_stats.cache_size / GIG
|
||||||
loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
|
loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GIG
|
||||||
|
|
||||||
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
|
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
|
||||||
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
|
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
|
||||||
|
@ -15,8 +15,8 @@ class ItemStorageABC(ABC, Generic[T]):
|
|||||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._on_changed_callbacks = list()
|
self._on_changed_callbacks = []
|
||||||
self._on_deleted_callbacks = list()
|
self._on_deleted_callbacks = []
|
||||||
|
|
||||||
"""Base item storage class"""
|
"""Base item storage class"""
|
||||||
|
|
||||||
|
@ -112,7 +112,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
)
|
)
|
||||||
result = self._cursor.fetchall()
|
result = self._cursor.fetchall()
|
||||||
|
|
||||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
items = [self._parse_item(r[0]) for r in result]
|
||||||
|
|
||||||
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
||||||
count = self._cursor.fetchone()[0]
|
count = self._cursor.fetchone()[0]
|
||||||
@ -132,7 +132,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
)
|
)
|
||||||
result = self._cursor.fetchall()
|
result = self._cursor.fetchall()
|
||||||
|
|
||||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
items = [self._parse_item(r[0]) for r in result]
|
||||||
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
|
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
|
||||||
|
@ -13,8 +13,8 @@ class LatentsStorageBase(ABC):
|
|||||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._on_changed_callbacks = list()
|
self._on_changed_callbacks = []
|
||||||
self._on_deleted_callbacks = list()
|
self._on_deleted_callbacks = []
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, name: str) -> torch.Tensor:
|
def get(self, name: str) -> torch.Tensor:
|
||||||
|
@ -19,7 +19,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
|||||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.__underlying_storage = underlying_storage
|
self.__underlying_storage = underlying_storage
|
||||||
self.__cache = dict()
|
self.__cache = {}
|
||||||
self.__cache_ids = Queue()
|
self.__cache_ids = Queue()
|
||||||
self.__max_cache_size = max_cache_size
|
self.__max_cache_size = max_cache_size
|
||||||
|
|
||||||
|
@ -33,9 +33,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
self.__thread = Thread(
|
self.__thread = Thread(
|
||||||
name="session_processor",
|
name="session_processor",
|
||||||
target=self.__process,
|
target=self.__process,
|
||||||
kwargs=dict(
|
kwargs={
|
||||||
stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event
|
"stop_event": self.__stop_event, "poll_now_event": self.__poll_now_event, "resume_event": self.__resume_event
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
self.__thread.start()
|
self.__thread.start()
|
||||||
|
|
||||||
|
@ -129,12 +129,12 @@ class Batch(BaseModel):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
json_schema_extra=dict(
|
json_schema_extra={
|
||||||
required=[
|
"required": [
|
||||||
"graph",
|
"graph",
|
||||||
"runs",
|
"runs",
|
||||||
]
|
]
|
||||||
)
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -191,8 +191,8 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
|||||||
return SessionQueueItemDTO(**queue_item_dict)
|
return SessionQueueItemDTO(**queue_item_dict)
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
json_schema_extra=dict(
|
json_schema_extra={
|
||||||
required=[
|
"required": [
|
||||||
"item_id",
|
"item_id",
|
||||||
"status",
|
"status",
|
||||||
"batch_id",
|
"batch_id",
|
||||||
@ -203,7 +203,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
|||||||
"created_at",
|
"created_at",
|
||||||
"updated_at",
|
"updated_at",
|
||||||
]
|
]
|
||||||
)
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -222,8 +222,8 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
|
|||||||
return SessionQueueItem(**queue_item_dict)
|
return SessionQueueItem(**queue_item_dict)
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
json_schema_extra=dict(
|
json_schema_extra={
|
||||||
required=[
|
"required": [
|
||||||
"item_id",
|
"item_id",
|
||||||
"status",
|
"status",
|
||||||
"batch_id",
|
"batch_id",
|
||||||
@ -235,7 +235,7 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
|
|||||||
"created_at",
|
"created_at",
|
||||||
"updated_at",
|
"updated_at",
|
||||||
]
|
]
|
||||||
)
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[Li
|
|||||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||||
|
|
||||||
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
||||||
graphs: list[LibraryGraph] = list()
|
graphs: list[LibraryGraph] = []
|
||||||
|
|
||||||
text_to_image = graph_library.get(default_text_to_image_graph_id)
|
text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||||
|
|
||||||
|
@ -352,7 +352,7 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
# Validate that all node ids are unique
|
# Validate that all node ids are unique
|
||||||
node_ids = [n.id for n in self.nodes.values()]
|
node_ids = [n.id for n in self.nodes.values()]
|
||||||
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
|
duplicate_node_ids = {node_id for node_id in node_ids if node_ids.count(node_id) >= 2}
|
||||||
if duplicate_node_ids:
|
if duplicate_node_ids:
|
||||||
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
|
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
|
||||||
|
|
||||||
@ -616,7 +616,7 @@ class Graph(BaseModel):
|
|||||||
self, node_path: str, prefix: Optional[str] = None
|
self, node_path: str, prefix: Optional[str] = None
|
||||||
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
||||||
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
||||||
edges = list()
|
edges = []
|
||||||
|
|
||||||
# Return any input edges that appear in this graph
|
# Return any input edges that appear in this graph
|
||||||
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
|
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
|
||||||
@ -658,7 +658,7 @@ class Graph(BaseModel):
|
|||||||
self, node_path: str, prefix: Optional[str] = None
|
self, node_path: str, prefix: Optional[str] = None
|
||||||
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
||||||
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
||||||
edges = list()
|
edges = []
|
||||||
|
|
||||||
# Return any input edges that appear in this graph
|
# Return any input edges that appear in this graph
|
||||||
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
|
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
|
||||||
@ -680,8 +680,8 @@ class Graph(BaseModel):
|
|||||||
new_input: Optional[EdgeConnection] = None,
|
new_input: Optional[EdgeConnection] = None,
|
||||||
new_output: Optional[EdgeConnection] = None,
|
new_output: Optional[EdgeConnection] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
|
inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
|
||||||
outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
|
outputs = [e.destination for e in self._get_output_edges(node_path, "item")]
|
||||||
|
|
||||||
if new_input is not None:
|
if new_input is not None:
|
||||||
inputs.append(new_input)
|
inputs.append(new_input)
|
||||||
@ -694,7 +694,7 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||||
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
|
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
|
||||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||||
|
|
||||||
# Input type must be a list
|
# Input type must be a list
|
||||||
if get_origin(input_field) != list:
|
if get_origin(input_field) != list:
|
||||||
@ -713,8 +713,8 @@ class Graph(BaseModel):
|
|||||||
new_input: Optional[EdgeConnection] = None,
|
new_input: Optional[EdgeConnection] = None,
|
||||||
new_output: Optional[EdgeConnection] = None,
|
new_output: Optional[EdgeConnection] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
|
inputs = [e.source for e in self._get_input_edges(node_path, "item")]
|
||||||
outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
|
outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]
|
||||||
|
|
||||||
if new_input is not None:
|
if new_input is not None:
|
||||||
inputs.append(new_input)
|
inputs.append(new_input)
|
||||||
@ -722,18 +722,16 @@ class Graph(BaseModel):
|
|||||||
outputs.append(new_output)
|
outputs.append(new_output)
|
||||||
|
|
||||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||||
input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
|
input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
|
||||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||||
|
|
||||||
# Validate that all inputs are derived from or match a single type
|
# Validate that all inputs are derived from or match a single type
|
||||||
input_field_types = set(
|
input_field_types = {
|
||||||
[
|
t
|
||||||
t
|
|
||||||
for input_field in input_fields
|
for input_field in input_fields
|
||||||
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
|
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
|
||||||
if t != NoneType
|
if t != NoneType
|
||||||
]
|
} # Get unique types
|
||||||
) # Get unique types
|
|
||||||
type_tree = nx.DiGraph()
|
type_tree = nx.DiGraph()
|
||||||
type_tree.add_nodes_from(input_field_types)
|
type_tree.add_nodes_from(input_field_types)
|
||||||
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
|
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
|
||||||
@ -761,15 +759,15 @@ class Graph(BaseModel):
|
|||||||
"""Returns a NetworkX DiGraph representing the layout of this graph"""
|
"""Returns a NetworkX DiGraph representing the layout of this graph"""
|
||||||
# TODO: Cache this?
|
# TODO: Cache this?
|
||||||
g = nx.DiGraph()
|
g = nx.DiGraph()
|
||||||
g.add_nodes_from([n for n in self.nodes.keys()])
|
g.add_nodes_from(list(self.nodes.keys()))
|
||||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
|
||||||
return g
|
return g
|
||||||
|
|
||||||
def nx_graph_with_data(self) -> nx.DiGraph:
|
def nx_graph_with_data(self) -> nx.DiGraph:
|
||||||
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
|
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
|
||||||
g = nx.DiGraph()
|
g = nx.DiGraph()
|
||||||
g.add_nodes_from([n for n in self.nodes.items()])
|
g.add_nodes_from(list(self.nodes.items()))
|
||||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
|
||||||
return g
|
return g
|
||||||
|
|
||||||
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
|
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
|
||||||
@ -791,7 +789,7 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
# TODO: figure out if iteration nodes need to be expanded
|
# TODO: figure out if iteration nodes need to be expanded
|
||||||
|
|
||||||
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
|
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
|
||||||
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
|
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
|
||||||
return g
|
return g
|
||||||
|
|
||||||
@ -843,8 +841,8 @@ class GraphExecutionState(BaseModel):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
json_schema_extra=dict(
|
json_schema_extra={
|
||||||
required=[
|
"required": [
|
||||||
"id",
|
"id",
|
||||||
"graph",
|
"graph",
|
||||||
"execution_graph",
|
"execution_graph",
|
||||||
@ -855,7 +853,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
"prepared_source_mapping",
|
"prepared_source_mapping",
|
||||||
"source_prepared_mapping",
|
"source_prepared_mapping",
|
||||||
]
|
]
|
||||||
)
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def next(self) -> Optional[BaseInvocation]:
|
def next(self) -> Optional[BaseInvocation]:
|
||||||
@ -895,7 +893,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
source_node = self.prepared_source_mapping[node_id]
|
source_node = self.prepared_source_mapping[node_id]
|
||||||
prepared_nodes = self.source_prepared_mapping[source_node]
|
prepared_nodes = self.source_prepared_mapping[source_node]
|
||||||
|
|
||||||
if all([n in self.executed for n in prepared_nodes]):
|
if all(n in self.executed for n in prepared_nodes):
|
||||||
self.executed.add(source_node)
|
self.executed.add(source_node)
|
||||||
self.executed_history.append(source_node)
|
self.executed_history.append(source_node)
|
||||||
|
|
||||||
@ -930,7 +928,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
|
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
|
||||||
self_iteration_count = len(input_collection)
|
self_iteration_count = len(input_collection)
|
||||||
|
|
||||||
new_nodes: list[str] = list()
|
new_nodes: list[str] = []
|
||||||
if self_iteration_count == 0:
|
if self_iteration_count == 0:
|
||||||
# TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
|
# TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
|
||||||
return new_nodes
|
return new_nodes
|
||||||
@ -940,7 +938,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
|
|
||||||
# Create new edges for this iteration
|
# Create new edges for this iteration
|
||||||
# For collect nodes, this may contain multiple inputs to the same field
|
# For collect nodes, this may contain multiple inputs to the same field
|
||||||
new_edges: list[Edge] = list()
|
new_edges: list[Edge] = []
|
||||||
for edge in input_edges:
|
for edge in input_edges:
|
||||||
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
|
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
|
||||||
new_edge = Edge(
|
new_edge = Edge(
|
||||||
@ -1034,7 +1032,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
|
|
||||||
# Create execution nodes
|
# Create execution nodes
|
||||||
next_node = self.graph.get_node(next_node_id)
|
next_node = self.graph.get_node(next_node_id)
|
||||||
new_node_ids = list()
|
new_node_ids = []
|
||||||
if isinstance(next_node, CollectInvocation):
|
if isinstance(next_node, CollectInvocation):
|
||||||
# Collapse all iterator input mappings and create a single execution node for the collect invocation
|
# Collapse all iterator input mappings and create a single execution node for the collect invocation
|
||||||
all_iteration_mappings = list(
|
all_iteration_mappings = list(
|
||||||
@ -1201,7 +1199,7 @@ class LibraryGraph(BaseModel):
|
|||||||
|
|
||||||
@field_validator("exposed_inputs", "exposed_outputs")
|
@field_validator("exposed_inputs", "exposed_outputs")
|
||||||
def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
|
def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
|
||||||
if len(v) != len(set(i.alias for i in v)):
|
if len(v) != len({i.alias for i in v}):
|
||||||
raise ValueError("Duplicate exposed alias")
|
raise ValueError("Duplicate exposed alias")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@ -88,7 +88,7 @@ class PromptFormatter:
|
|||||||
t2i = self.t2i
|
t2i = self.t2i
|
||||||
opt = self.opt
|
opt = self.opt
|
||||||
|
|
||||||
switches = list()
|
switches = []
|
||||||
switches.append(f'"{opt.prompt}"')
|
switches.append(f'"{opt.prompt}"')
|
||||||
switches.append(f"-s{opt.steps or t2i.steps}")
|
switches.append(f"-s{opt.steps or t2i.steps}")
|
||||||
switches.append(f"-W{opt.width or t2i.width}")
|
switches.append(f"-W{opt.width or t2i.width}")
|
||||||
|
@ -40,7 +40,7 @@ class InitImageResizer:
|
|||||||
(rw, rh) = (int(scale * im.width), int(scale * im.height))
|
(rw, rh) = (int(scale * im.width), int(scale * im.height))
|
||||||
|
|
||||||
# round everything to multiples of 64
|
# round everything to multiples of 64
|
||||||
width, height, rw, rh = map(lambda x: x - x % 64, (width, height, rw, rh))
|
width, height, rw, rh = (x - x % 64 for x in (width, height, rw, rh))
|
||||||
|
|
||||||
# no resize necessary, but return a copy
|
# no resize necessary, but return a copy
|
||||||
if im.width == width and im.height == height:
|
if im.width == width and im.height == height:
|
||||||
|
@ -197,7 +197,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
|||||||
|
|
||||||
def download_conversion_models():
|
def download_conversion_models():
|
||||||
target_dir = config.models_path / "core/convert"
|
target_dir = config.models_path / "core/convert"
|
||||||
kwargs = dict() # for future use
|
kwargs = {} # for future use
|
||||||
try:
|
try:
|
||||||
logger.info("Downloading core tokenizers and text encoders")
|
logger.info("Downloading core tokenizers and text encoders")
|
||||||
|
|
||||||
@ -252,26 +252,26 @@ def download_conversion_models():
|
|||||||
def download_realesrgan():
|
def download_realesrgan():
|
||||||
logger.info("Installing ESRGAN Upscaling models...")
|
logger.info("Installing ESRGAN Upscaling models...")
|
||||||
URLs = [
|
URLs = [
|
||||||
dict(
|
{
|
||||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||||
dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||||
description="RealESRGAN_x4plus.pth",
|
"description": "RealESRGAN_x4plus.pth",
|
||||||
),
|
},
|
||||||
dict(
|
{
|
||||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||||
dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||||
description="RealESRGAN_x4plus_anime_6B.pth",
|
"description": "RealESRGAN_x4plus_anime_6B.pth",
|
||||||
),
|
},
|
||||||
dict(
|
{
|
||||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
"dest": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
description="ESRGAN_SRx4_DF2KOST_official.pth",
|
"description": "ESRGAN_SRx4_DF2KOST_official.pth",
|
||||||
),
|
},
|
||||||
dict(
|
{
|
||||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||||
dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
"dest": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||||
description="RealESRGAN_x2plus.pth",
|
"description": "RealESRGAN_x2plus.pth",
|
||||||
),
|
},
|
||||||
]
|
]
|
||||||
for model in URLs:
|
for model in URLs:
|
||||||
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
|
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
|
||||||
@ -680,7 +680,7 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
|||||||
if program_opts.default_only
|
if program_opts.default_only
|
||||||
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
|
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
|
||||||
if program_opts.yes_to_all
|
if program_opts.yes_to_all
|
||||||
else list(),
|
else [],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -182,10 +182,10 @@ class MigrateTo3(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
dest_directory = self.dest_models
|
dest_directory = self.dest_models
|
||||||
kwargs = dict(
|
kwargs = {
|
||||||
cache_dir=self.root_directory / "models/hub",
|
"cache_dir": self.root_directory / "models/hub",
|
||||||
# local_files_only = True
|
# local_files_only = True
|
||||||
)
|
}
|
||||||
try:
|
try:
|
||||||
logger.info("Migrating core tokenizers and text encoders")
|
logger.info("Migrating core tokenizers and text encoders")
|
||||||
target_dir = dest_directory / "core" / "convert"
|
target_dir = dest_directory / "core" / "convert"
|
||||||
@ -316,11 +316,11 @@ class MigrateTo3(object):
|
|||||||
dest_dir = self.dest_models
|
dest_dir = self.dest_models
|
||||||
|
|
||||||
cache = self.root_directory / "models/hub"
|
cache = self.root_directory / "models/hub"
|
||||||
kwargs = dict(
|
kwargs = {
|
||||||
cache_dir=cache,
|
"cache_dir": cache,
|
||||||
safety_checker=None,
|
"safety_checker": None,
|
||||||
# local_files_only = True,
|
# local_files_only = True,
|
||||||
)
|
}
|
||||||
|
|
||||||
owner, repo_name = repo_id.split("/")
|
owner, repo_name = repo_id.split("/")
|
||||||
model_name = model_name or repo_name
|
model_name = model_name or repo_name
|
||||||
|
@ -120,7 +120,7 @@ class ModelInstall(object):
|
|||||||
be treated uniformly. It also sorts the models alphabetically
|
be treated uniformly. It also sorts the models alphabetically
|
||||||
by their name, to improve the display somewhat.
|
by their name, to improve the display somewhat.
|
||||||
"""
|
"""
|
||||||
model_dict = dict()
|
model_dict = {}
|
||||||
|
|
||||||
# first populate with the entries in INITIAL_MODELS.yaml
|
# first populate with the entries in INITIAL_MODELS.yaml
|
||||||
for key, value in self.datasets.items():
|
for key, value in self.datasets.items():
|
||||||
@ -134,7 +134,7 @@ class ModelInstall(object):
|
|||||||
model_dict[key] = model_info
|
model_dict[key] = model_info
|
||||||
|
|
||||||
# supplement with entries in models.yaml
|
# supplement with entries in models.yaml
|
||||||
installed_models = [x for x in self.mgr.list_models()]
|
installed_models = list(self.mgr.list_models())
|
||||||
|
|
||||||
for md in installed_models:
|
for md in installed_models:
|
||||||
base = md["base_model"]
|
base = md["base_model"]
|
||||||
@ -184,7 +184,7 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
def recommended_models(self) -> Set[str]:
|
def recommended_models(self) -> Set[str]:
|
||||||
starters = self.starter_models(all_models=True)
|
starters = self.starter_models(all_models=True)
|
||||||
return set([x for x in starters if self.datasets[x].get("recommended", False)])
|
return {x for x in starters if self.datasets[x].get("recommended", False)}
|
||||||
|
|
||||||
def default_model(self) -> str:
|
def default_model(self) -> str:
|
||||||
starters = self.starter_models()
|
starters = self.starter_models()
|
||||||
@ -234,7 +234,7 @@ class ModelInstall(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if not models_installed:
|
if not models_installed:
|
||||||
models_installed = dict()
|
models_installed = {}
|
||||||
|
|
||||||
model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
|
model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
|
||||||
|
|
||||||
@ -252,8 +252,7 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
# folders style or similar
|
# folders style or similar
|
||||||
elif path.is_dir() and any(
|
elif path.is_dir() and any(
|
||||||
[
|
(path / x).exists()
|
||||||
(path / x).exists()
|
|
||||||
for x in {
|
for x in {
|
||||||
"config.json",
|
"config.json",
|
||||||
"model_index.json",
|
"model_index.json",
|
||||||
@ -261,7 +260,6 @@ class ModelInstall(object):
|
|||||||
"pytorch_lora_weights.bin",
|
"pytorch_lora_weights.bin",
|
||||||
"pytorch_lora_weights.safetensors",
|
"pytorch_lora_weights.safetensors",
|
||||||
}
|
}
|
||||||
]
|
|
||||||
):
|
):
|
||||||
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
|
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
|
||||||
|
|
||||||
@ -433,17 +431,17 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
rel_path = self.relative_to_root(path, self.config.models_path)
|
rel_path = self.relative_to_root(path, self.config.models_path)
|
||||||
|
|
||||||
attributes = dict(
|
attributes = {
|
||||||
path=str(rel_path),
|
"path": str(rel_path),
|
||||||
description=str(description),
|
"description": str(description),
|
||||||
model_format=info.format,
|
"model_format": info.format,
|
||||||
)
|
}
|
||||||
legacy_conf = None
|
legacy_conf = None
|
||||||
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
|
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
|
||||||
attributes.update(
|
attributes.update(
|
||||||
dict(
|
{
|
||||||
variant=info.variant_type,
|
"variant": info.variant_type,
|
||||||
)
|
}
|
||||||
)
|
)
|
||||||
if info.format == "checkpoint":
|
if info.format == "checkpoint":
|
||||||
try:
|
try:
|
||||||
@ -474,7 +472,7 @@ class ModelInstall(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if legacy_conf:
|
if legacy_conf:
|
||||||
attributes.update(dict(config=str(legacy_conf)))
|
attributes.update({"config": str(legacy_conf)})
|
||||||
return attributes
|
return attributes
|
||||||
|
|
||||||
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
|
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
|
||||||
@ -519,7 +517,7 @@ class ModelInstall(object):
|
|||||||
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
|
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
|
||||||
_, name = repo_id.split("/")
|
_, name = repo_id.split("/")
|
||||||
location = staging / name
|
location = staging / name
|
||||||
paths = list()
|
paths = []
|
||||||
for filename in files:
|
for filename in files:
|
||||||
filePath = Path(filename)
|
filePath = Path(filename)
|
||||||
p = hf_download_with_resume(
|
p = hf_download_with_resume(
|
||||||
|
@ -104,7 +104,7 @@ class ModelPatcher:
|
|||||||
loras: List[Tuple[LoRAModel, float]],
|
loras: List[Tuple[LoRAModel, float]],
|
||||||
prefix: str,
|
prefix: str,
|
||||||
):
|
):
|
||||||
original_weights = dict()
|
original_weights = {}
|
||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for lora, lora_weight in loras:
|
for lora, lora_weight in loras:
|
||||||
@ -324,7 +324,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
tokenizer: CLIPTokenizer
|
tokenizer: CLIPTokenizer
|
||||||
|
|
||||||
def __init__(self, tokenizer: CLIPTokenizer):
|
def __init__(self, tokenizer: CLIPTokenizer):
|
||||||
self.pad_tokens = dict()
|
self.pad_tokens = {}
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
||||||
@ -385,10 +385,10 @@ class ONNXModelPatcher:
|
|||||||
if not isinstance(model, IAIOnnxRuntimeModel):
|
if not isinstance(model, IAIOnnxRuntimeModel):
|
||||||
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
||||||
|
|
||||||
orig_weights = dict()
|
orig_weights = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
blended_loras = dict()
|
blended_loras = {}
|
||||||
|
|
||||||
for lora, lora_weight in loras:
|
for lora, lora_weight in loras:
|
||||||
for layer_key, layer in lora.layers.items():
|
for layer_key, layer in lora.layers.items():
|
||||||
@ -404,7 +404,7 @@ class ONNXModelPatcher:
|
|||||||
else:
|
else:
|
||||||
blended_loras[layer_key] = layer_weight
|
blended_loras[layer_key] = layer_weight
|
||||||
|
|
||||||
node_names = dict()
|
node_names = {}
|
||||||
for node in model.nodes.values():
|
for node in model.nodes.values():
|
||||||
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
|
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
|
||||||
|
|
||||||
|
@ -132,7 +132,7 @@ class ModelCache(object):
|
|||||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||||
behaviour.
|
behaviour.
|
||||||
"""
|
"""
|
||||||
self.model_infos: Dict[str, ModelBase] = dict()
|
self.model_infos: Dict[str, ModelBase] = {}
|
||||||
# allow lazy offloading only when vram cache enabled
|
# allow lazy offloading only when vram cache enabled
|
||||||
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||||
self.precision: torch.dtype = precision
|
self.precision: torch.dtype = precision
|
||||||
@ -147,8 +147,8 @@ class ModelCache(object):
|
|||||||
# used for stats collection
|
# used for stats collection
|
||||||
self.stats = None
|
self.stats = None
|
||||||
|
|
||||||
self._cached_models = dict()
|
self._cached_models = {}
|
||||||
self._cache_stack = list()
|
self._cache_stack = []
|
||||||
|
|
||||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||||
if self._log_memory_usage:
|
if self._log_memory_usage:
|
||||||
|
@ -363,7 +363,7 @@ class ModelManager(object):
|
|||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.models = dict()
|
self.models = {}
|
||||||
for model_key, model_config in config.items():
|
for model_key, model_config in config.items():
|
||||||
if model_key.startswith("_"):
|
if model_key.startswith("_"):
|
||||||
continue
|
continue
|
||||||
@ -374,7 +374,7 @@ class ModelManager(object):
|
|||||||
self.models[model_key] = model_class.create_config(**model_config)
|
self.models[model_key] = model_class.create_config(**model_config)
|
||||||
|
|
||||||
# check config version number and update on disk/RAM if necessary
|
# check config version number and update on disk/RAM if necessary
|
||||||
self.cache_keys = dict()
|
self.cache_keys = {}
|
||||||
|
|
||||||
# add controlnet, lora and textual_inversion models from disk
|
# add controlnet, lora and textual_inversion models from disk
|
||||||
self.scan_models_directory()
|
self.scan_models_directory()
|
||||||
@ -902,7 +902,7 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
Write current configuration out to the indicated file.
|
Write current configuration out to the indicated file.
|
||||||
"""
|
"""
|
||||||
data_to_save = dict()
|
data_to_save = {}
|
||||||
data_to_save["__metadata__"] = self.config_meta.model_dump()
|
data_to_save["__metadata__"] = self.config_meta.model_dump()
|
||||||
|
|
||||||
for model_key, model_config in self.models.items():
|
for model_key, model_config in self.models.items():
|
||||||
@ -1034,7 +1034,7 @@ class ModelManager(object):
|
|||||||
self.ignore = ignore
|
self.ignore = ignore
|
||||||
|
|
||||||
def on_search_started(self):
|
def on_search_started(self):
|
||||||
self.new_models_found = dict()
|
self.new_models_found = {}
|
||||||
|
|
||||||
def on_model_found(self, model: Path):
|
def on_model_found(self, model: Path):
|
||||||
if model not in self.ignore:
|
if model not in self.ignore:
|
||||||
@ -1106,7 +1106,7 @@ class ModelManager(object):
|
|||||||
# avoid circular import here
|
# avoid circular import here
|
||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
|
|
||||||
successfully_installed = dict()
|
successfully_installed = {}
|
||||||
|
|
||||||
installer = ModelInstall(
|
installer = ModelInstall(
|
||||||
config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self
|
config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self
|
||||||
|
@ -92,7 +92,7 @@ class ModelMerger(object):
|
|||||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||||
"""
|
"""
|
||||||
model_paths = list()
|
model_paths = []
|
||||||
config = self.manager.app_config
|
config = self.manager.app_config
|
||||||
base_model = BaseModelType(base_model)
|
base_model = BaseModelType(base_model)
|
||||||
vae = None
|
vae = None
|
||||||
@ -124,13 +124,13 @@ class ModelMerger(object):
|
|||||||
dump_path = (dump_path / merged_model_name).as_posix()
|
dump_path = (dump_path / merged_model_name).as_posix()
|
||||||
|
|
||||||
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||||
attributes = dict(
|
attributes = {
|
||||||
path=dump_path,
|
"path": dump_path,
|
||||||
description=f"Merge of models {', '.join(model_names)}",
|
"description": f"Merge of models {', '.join(model_names)}",
|
||||||
model_format="diffusers",
|
"model_format": "diffusers",
|
||||||
variant=ModelVariantType.Normal.value,
|
"variant": ModelVariantType.Normal.value,
|
||||||
vae=vae,
|
"vae": vae,
|
||||||
)
|
}
|
||||||
return self.manager.add_model(
|
return self.manager.add_model(
|
||||||
merged_model_name,
|
merged_model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
|
@ -59,7 +59,7 @@ class ModelSearch(ABC):
|
|||||||
for root, dirs, files in os.walk(path, followlinks=True):
|
for root, dirs, files in os.walk(path, followlinks=True):
|
||||||
if str(Path(root).name).startswith("."):
|
if str(Path(root).name).startswith("."):
|
||||||
self._pruned_paths.add(root)
|
self._pruned_paths.add(root)
|
||||||
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self._items_scanned += len(dirs) + len(files)
|
self._items_scanned += len(dirs) + len(files)
|
||||||
@ -69,8 +69,7 @@ class ModelSearch(ABC):
|
|||||||
self._scanned_dirs.add(path)
|
self._scanned_dirs.add(path)
|
||||||
continue
|
continue
|
||||||
if any(
|
if any(
|
||||||
[
|
(path / x).exists()
|
||||||
(path / x).exists()
|
|
||||||
for x in {
|
for x in {
|
||||||
"config.json",
|
"config.json",
|
||||||
"model_index.json",
|
"model_index.json",
|
||||||
@ -78,7 +77,6 @@ class ModelSearch(ABC):
|
|||||||
"pytorch_lora_weights.bin",
|
"pytorch_lora_weights.bin",
|
||||||
"image_encoder.txt",
|
"image_encoder.txt",
|
||||||
}
|
}
|
||||||
]
|
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
self.on_model_found(path)
|
self.on_model_found(path)
|
||||||
|
@ -97,8 +97,8 @@ MODEL_CLASSES = {
|
|||||||
# },
|
# },
|
||||||
}
|
}
|
||||||
|
|
||||||
MODEL_CONFIGS = list()
|
MODEL_CONFIGS = []
|
||||||
OPENAPI_MODEL_CONFIGS = list()
|
OPENAPI_MODEL_CONFIGS = []
|
||||||
|
|
||||||
|
|
||||||
class OpenAPIModelInfoBase(BaseModel):
|
class OpenAPIModelInfoBase(BaseModel):
|
||||||
@ -133,7 +133,7 @@ for base_model, models in MODEL_CLASSES.items():
|
|||||||
|
|
||||||
|
|
||||||
def get_model_config_enums():
|
def get_model_config_enums():
|
||||||
enums = list()
|
enums = []
|
||||||
|
|
||||||
for model_config in MODEL_CONFIGS:
|
for model_config in MODEL_CONFIGS:
|
||||||
if hasattr(inspect, "get_annotations"):
|
if hasattr(inspect, "get_annotations"):
|
||||||
|
@ -164,7 +164,7 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
return cls.__configs
|
return cls.__configs
|
||||||
|
|
||||||
configs = dict()
|
configs = {}
|
||||||
for name in dir(cls):
|
for name in dir(cls):
|
||||||
if name.startswith("__"):
|
if name.startswith("__"):
|
||||||
continue
|
continue
|
||||||
@ -246,8 +246,8 @@ class DiffusersModel(ModelBase):
|
|||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
super().__init__(model_path, base_model, model_type)
|
super().__init__(model_path, base_model, model_type)
|
||||||
|
|
||||||
self.child_types: Dict[str, Type] = dict()
|
self.child_types: Dict[str, Type] = {}
|
||||||
self.child_sizes: Dict[str, int] = dict()
|
self.child_sizes: Dict[str, int] = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config_data = DiffusionPipeline.load_config(self.model_path)
|
config_data = DiffusionPipeline.load_config(self.model_path)
|
||||||
@ -326,8 +326,8 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
|
|||||||
all_files = os.listdir(model_path)
|
all_files = os.listdir(model_path)
|
||||||
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
|
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
|
||||||
|
|
||||||
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
|
fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
|
||||||
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
|
bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
|
||||||
other_files = set(all_files) - fp16_files - bit8_files
|
other_files = set(all_files) - fp16_files - bit8_files
|
||||||
|
|
||||||
if variant is None:
|
if variant is None:
|
||||||
@ -413,7 +413,7 @@ def _calc_onnx_model_by_data(model) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def _fast_safetensors_reader(path: str):
|
def _fast_safetensors_reader(path: str):
|
||||||
checkpoint = dict()
|
checkpoint = {}
|
||||||
device = torch.device("meta")
|
device = torch.device("meta")
|
||||||
with open(path, "rb") as f:
|
with open(path, "rb") as f:
|
||||||
definition_len = int.from_bytes(f.read(8), "little")
|
definition_len = int.from_bytes(f.read(8), "little")
|
||||||
@ -483,7 +483,7 @@ class IAIOnnxRuntimeModel:
|
|||||||
class _tensor_access:
|
class _tensor_access:
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.indexes = dict()
|
self.indexes = {}
|
||||||
for idx, obj in enumerate(self.model.proto.graph.initializer):
|
for idx, obj in enumerate(self.model.proto.graph.initializer):
|
||||||
self.indexes[obj.name] = idx
|
self.indexes[obj.name] = idx
|
||||||
|
|
||||||
@ -524,7 +524,7 @@ class IAIOnnxRuntimeModel:
|
|||||||
|
|
||||||
class _access_helper:
|
class _access_helper:
|
||||||
def __init__(self, raw_proto):
|
def __init__(self, raw_proto):
|
||||||
self.indexes = dict()
|
self.indexes = {}
|
||||||
self.raw_proto = raw_proto
|
self.raw_proto = raw_proto
|
||||||
for idx, obj in enumerate(raw_proto):
|
for idx, obj in enumerate(raw_proto):
|
||||||
self.indexes[obj.name] = idx
|
self.indexes[obj.name] = idx
|
||||||
@ -549,7 +549,7 @@ class IAIOnnxRuntimeModel:
|
|||||||
return self.indexes.keys()
|
return self.indexes.keys()
|
||||||
|
|
||||||
def values(self):
|
def values(self):
|
||||||
return [obj for obj in self.raw_proto]
|
return list(self.raw_proto)
|
||||||
|
|
||||||
def __init__(self, model_path: str, provider: Optional[str]):
|
def __init__(self, model_path: str, provider: Optional[str]):
|
||||||
self.path = model_path
|
self.path = model_path
|
||||||
|
@ -104,7 +104,7 @@ class ControlNetModel(ModelBase):
|
|||||||
return ControlNetModelFormat.Diffusers
|
return ControlNetModelFormat.Diffusers
|
||||||
|
|
||||||
if os.path.isfile(path):
|
if os.path.isfile(path):
|
||||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]):
|
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]):
|
||||||
return ControlNetModelFormat.Checkpoint
|
return ControlNetModelFormat.Checkpoint
|
||||||
|
|
||||||
raise InvalidModelException(f"Not a valid model: {path}")
|
raise InvalidModelException(f"Not a valid model: {path}")
|
||||||
|
@ -73,7 +73,7 @@ class LoRAModel(ModelBase):
|
|||||||
return LoRAModelFormat.Diffusers
|
return LoRAModelFormat.Diffusers
|
||||||
|
|
||||||
if os.path.isfile(path):
|
if os.path.isfile(path):
|
||||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||||
return LoRAModelFormat.LyCORIS
|
return LoRAModelFormat.LyCORIS
|
||||||
|
|
||||||
raise InvalidModelException(f"Not a valid model: {path}")
|
raise InvalidModelException(f"Not a valid model: {path}")
|
||||||
@ -499,7 +499,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
|||||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||||
stability_unet_keys.sort()
|
stability_unet_keys.sort()
|
||||||
|
|
||||||
new_state_dict = dict()
|
new_state_dict = {}
|
||||||
for full_key, value in state_dict.items():
|
for full_key, value in state_dict.items():
|
||||||
if full_key.startswith("lora_unet_"):
|
if full_key.startswith("lora_unet_"):
|
||||||
search_key = full_key.replace("lora_unet_", "")
|
search_key = full_key.replace("lora_unet_", "")
|
||||||
@ -545,7 +545,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
|||||||
|
|
||||||
model = cls(
|
model = cls(
|
||||||
name=file_path.stem, # TODO:
|
name=file_path.stem, # TODO:
|
||||||
layers=dict(),
|
layers={},
|
||||||
)
|
)
|
||||||
|
|
||||||
if file_path.suffix == ".safetensors":
|
if file_path.suffix == ".safetensors":
|
||||||
@ -593,12 +593,12 @@ class LoRAModelRaw: # (torch.nn.Module):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _group_state(state_dict: dict):
|
def _group_state(state_dict: dict):
|
||||||
state_dict_groupped = dict()
|
state_dict_groupped = {}
|
||||||
|
|
||||||
for key, value in state_dict.items():
|
for key, value in state_dict.items():
|
||||||
stem, leaf = key.split(".", 1)
|
stem, leaf = key.split(".", 1)
|
||||||
if stem not in state_dict_groupped:
|
if stem not in state_dict_groupped:
|
||||||
state_dict_groupped[stem] = dict()
|
state_dict_groupped[stem] = {}
|
||||||
state_dict_groupped[stem][leaf] = value
|
state_dict_groupped[stem][leaf] = value
|
||||||
|
|
||||||
return state_dict_groupped
|
return state_dict_groupped
|
||||||
|
@ -110,7 +110,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
return StableDiffusion1ModelFormat.Diffusers
|
return StableDiffusion1ModelFormat.Diffusers
|
||||||
|
|
||||||
if os.path.isfile(model_path):
|
if os.path.isfile(model_path):
|
||||||
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||||
return StableDiffusion1ModelFormat.Checkpoint
|
return StableDiffusion1ModelFormat.Checkpoint
|
||||||
|
|
||||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||||
@ -221,7 +221,7 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
return StableDiffusion2ModelFormat.Diffusers
|
return StableDiffusion2ModelFormat.Diffusers
|
||||||
|
|
||||||
if os.path.isfile(model_path):
|
if os.path.isfile(model_path):
|
||||||
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||||
return StableDiffusion2ModelFormat.Checkpoint
|
return StableDiffusion2ModelFormat.Checkpoint
|
||||||
|
|
||||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||||
|
@ -71,7 +71,7 @@ class TextualInversionModel(ModelBase):
|
|||||||
return None # diffusers-ti
|
return None # diffusers-ti
|
||||||
|
|
||||||
if os.path.isfile(path):
|
if os.path.isfile(path):
|
||||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
|
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
raise InvalidModelException(f"Not a valid model: {path}")
|
raise InvalidModelException(f"Not a valid model: {path}")
|
||||||
|
@ -89,7 +89,7 @@ class VaeModel(ModelBase):
|
|||||||
return VaeModelFormat.Diffusers
|
return VaeModelFormat.Diffusers
|
||||||
|
|
||||||
if os.path.isfile(path):
|
if os.path.isfile(path):
|
||||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||||
return VaeModelFormat.Checkpoint
|
return VaeModelFormat.Checkpoint
|
||||||
|
|
||||||
raise InvalidModelException(f"Not a valid model: {path}")
|
raise InvalidModelException(f"Not a valid model: {path}")
|
||||||
|
@ -16,28 +16,28 @@ from diffusers import (
|
|||||||
UniPCMultistepScheduler,
|
UniPCMultistepScheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
SCHEDULER_MAP = dict(
|
SCHEDULER_MAP = {
|
||||||
ddim=(DDIMScheduler, dict()),
|
"ddim": (DDIMScheduler, {}),
|
||||||
ddpm=(DDPMScheduler, dict()),
|
"ddpm": (DDPMScheduler, {}),
|
||||||
deis=(DEISMultistepScheduler, dict()),
|
"deis": (DEISMultistepScheduler, {}),
|
||||||
lms=(LMSDiscreteScheduler, dict(use_karras_sigmas=False)),
|
"lms": (LMSDiscreteScheduler, {"use_karras_sigmas": False}),
|
||||||
lms_k=(LMSDiscreteScheduler, dict(use_karras_sigmas=True)),
|
"lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
|
||||||
pndm=(PNDMScheduler, dict()),
|
"pndm": (PNDMScheduler, {}),
|
||||||
heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
|
"heun": (HeunDiscreteScheduler, {"use_karras_sigmas": False}),
|
||||||
heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
|
"heun_k": (HeunDiscreteScheduler, {"use_karras_sigmas": True}),
|
||||||
euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
|
"euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
|
||||||
euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
|
"euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
|
||||||
euler_a=(EulerAncestralDiscreteScheduler, dict()),
|
"euler_a": (EulerAncestralDiscreteScheduler, {}),
|
||||||
kdpm_2=(KDPM2DiscreteScheduler, dict()),
|
"kdpm_2": (KDPM2DiscreteScheduler, {}),
|
||||||
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
|
"kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
|
||||||
dpmpp_2s=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=False)),
|
"dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
|
||||||
dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)),
|
"dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
|
||||||
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
|
"dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False}),
|
||||||
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
|
"dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
|
||||||
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type="sde-dpmsolver++")),
|
"dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
|
||||||
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")),
|
"dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
|
||||||
dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)),
|
"dpmpp_sde": (DPMSolverSDEScheduler, {"use_karras_sigmas": False, "noise_sampler_seed": 0}),
|
||||||
dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)),
|
"dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}),
|
||||||
unipc=(UniPCMultistepScheduler, dict(cpu_only=True)),
|
"unipc": (UniPCMultistepScheduler, {"cpu_only": True}),
|
||||||
lcm=(LCMScheduler, dict()),
|
"lcm": (LCMScheduler, {}),
|
||||||
)
|
}
|
||||||
|
@ -615,7 +615,7 @@ def do_textual_inversion_training(
|
|||||||
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
|
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
|
||||||
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
|
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
|
||||||
|
|
||||||
pipeline_args = dict(local_files_only=True)
|
pipeline_args = {"local_files_only": True}
|
||||||
if tokenizer_name:
|
if tokenizer_name:
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
|
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
|
||||||
else:
|
else:
|
||||||
|
@ -225,34 +225,34 @@ def basicConfig(**kwargs):
|
|||||||
|
|
||||||
|
|
||||||
_FACILITY_MAP = (
|
_FACILITY_MAP = (
|
||||||
dict(
|
{
|
||||||
LOG_KERN=syslog.LOG_KERN,
|
"LOG_KERN": syslog.LOG_KERN,
|
||||||
LOG_USER=syslog.LOG_USER,
|
"LOG_USER": syslog.LOG_USER,
|
||||||
LOG_MAIL=syslog.LOG_MAIL,
|
"LOG_MAIL": syslog.LOG_MAIL,
|
||||||
LOG_DAEMON=syslog.LOG_DAEMON,
|
"LOG_DAEMON": syslog.LOG_DAEMON,
|
||||||
LOG_AUTH=syslog.LOG_AUTH,
|
"LOG_AUTH": syslog.LOG_AUTH,
|
||||||
LOG_LPR=syslog.LOG_LPR,
|
"LOG_LPR": syslog.LOG_LPR,
|
||||||
LOG_NEWS=syslog.LOG_NEWS,
|
"LOG_NEWS": syslog.LOG_NEWS,
|
||||||
LOG_UUCP=syslog.LOG_UUCP,
|
"LOG_UUCP": syslog.LOG_UUCP,
|
||||||
LOG_CRON=syslog.LOG_CRON,
|
"LOG_CRON": syslog.LOG_CRON,
|
||||||
LOG_SYSLOG=syslog.LOG_SYSLOG,
|
"LOG_SYSLOG": syslog.LOG_SYSLOG,
|
||||||
LOG_LOCAL0=syslog.LOG_LOCAL0,
|
"LOG_LOCAL0": syslog.LOG_LOCAL0,
|
||||||
LOG_LOCAL1=syslog.LOG_LOCAL1,
|
"LOG_LOCAL1": syslog.LOG_LOCAL1,
|
||||||
LOG_LOCAL2=syslog.LOG_LOCAL2,
|
"LOG_LOCAL2": syslog.LOG_LOCAL2,
|
||||||
LOG_LOCAL3=syslog.LOG_LOCAL3,
|
"LOG_LOCAL3": syslog.LOG_LOCAL3,
|
||||||
LOG_LOCAL4=syslog.LOG_LOCAL4,
|
"LOG_LOCAL4": syslog.LOG_LOCAL4,
|
||||||
LOG_LOCAL5=syslog.LOG_LOCAL5,
|
"LOG_LOCAL5": syslog.LOG_LOCAL5,
|
||||||
LOG_LOCAL6=syslog.LOG_LOCAL6,
|
"LOG_LOCAL6": syslog.LOG_LOCAL6,
|
||||||
LOG_LOCAL7=syslog.LOG_LOCAL7,
|
"LOG_LOCAL7": syslog.LOG_LOCAL7,
|
||||||
)
|
}
|
||||||
if SYSLOG_AVAILABLE
|
if SYSLOG_AVAILABLE
|
||||||
else dict()
|
else {}
|
||||||
)
|
)
|
||||||
|
|
||||||
_SOCK_MAP = dict(
|
_SOCK_MAP = {
|
||||||
SOCK_STREAM=socket.SOCK_STREAM,
|
"SOCK_STREAM": socket.SOCK_STREAM,
|
||||||
SOCK_DGRAM=socket.SOCK_DGRAM,
|
"SOCK_DGRAM": socket.SOCK_DGRAM,
|
||||||
)
|
}
|
||||||
|
|
||||||
|
|
||||||
class InvokeAIFormatter(logging.Formatter):
|
class InvokeAIFormatter(logging.Formatter):
|
||||||
@ -344,7 +344,7 @@ LOG_FORMATTERS = {
|
|||||||
|
|
||||||
|
|
||||||
class InvokeAILogger(object):
|
class InvokeAILogger(object):
|
||||||
loggers = dict()
|
loggers = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_logger(
|
def get_logger(
|
||||||
@ -364,7 +364,7 @@ class InvokeAILogger(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
|
def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
|
||||||
handler_strs = config.log_handlers
|
handler_strs = config.log_handlers
|
||||||
handlers = list()
|
handlers = []
|
||||||
for handler in handler_strs:
|
for handler in handler_strs:
|
||||||
handler_name, *args = handler.split("=", 2)
|
handler_name, *args = handler.split("=", 2)
|
||||||
args = args[0] if len(args) > 0 else None
|
args = args[0] if len(args) > 0 else None
|
||||||
@ -398,7 +398,7 @@ class InvokeAILogger(object):
|
|||||||
raise ValueError("syslog is not available on this system")
|
raise ValueError("syslog is not available on this system")
|
||||||
if not args:
|
if not args:
|
||||||
args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514"
|
args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514"
|
||||||
syslog_args = dict()
|
syslog_args = {}
|
||||||
try:
|
try:
|
||||||
for a in args.split(","):
|
for a in args.split(","):
|
||||||
arg_name, *arg_value = a.split(":", 2)
|
arg_name, *arg_value = a.split(":", 2)
|
||||||
@ -434,7 +434,7 @@ class InvokeAILogger(object):
|
|||||||
path = url.path
|
path = url.path
|
||||||
port = url.port or 80
|
port = url.port or 80
|
||||||
|
|
||||||
syslog_args = dict()
|
syslog_args = {}
|
||||||
for a in arg_list:
|
for a in arg_list:
|
||||||
arg_name, *arg_value = a.split(":", 2)
|
arg_name, *arg_value = a.split(":", 2)
|
||||||
if arg_name == "method":
|
if arg_name == "method":
|
||||||
|
@ -26,7 +26,7 @@ def log_txt_as_img(wh, xc, size=10):
|
|||||||
# wh a tuple of (width, height)
|
# wh a tuple of (width, height)
|
||||||
# xc a list of captions to plot
|
# xc a list of captions to plot
|
||||||
b = len(xc)
|
b = len(xc)
|
||||||
txts = list()
|
txts = []
|
||||||
for bi in range(b):
|
for bi in range(b):
|
||||||
txt = Image.new("RGB", wh, color="white")
|
txt = Image.new("RGB", wh, color="white")
|
||||||
draw = ImageDraw.Draw(txt)
|
draw = ImageDraw.Draw(txt)
|
||||||
@ -90,7 +90,7 @@ def instantiate_from_config(config, **kwargs):
|
|||||||
elif config == "__is_unconditional__":
|
elif config == "__is_unconditional__":
|
||||||
return None
|
return None
|
||||||
raise KeyError("Expected key `target` to instantiate.")
|
raise KeyError("Expected key `target` to instantiate.")
|
||||||
return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
|
return get_obj_from_str(config["target"])(**config.get("params", {}), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def get_obj_from_str(string, reload=False):
|
def get_obj_from_str(string, reload=False):
|
||||||
|
@ -341,19 +341,19 @@ class InvokeAIMetadataParser:
|
|||||||
# this was more elegant as a case statement, but that's not available in python 3.9
|
# this was more elegant as a case statement, but that's not available in python 3.9
|
||||||
if old_scheduler is None:
|
if old_scheduler is None:
|
||||||
return None
|
return None
|
||||||
scheduler_map = dict(
|
scheduler_map = {
|
||||||
ddim="ddim",
|
"ddim": "ddim",
|
||||||
plms="pnmd",
|
"plms": "pnmd",
|
||||||
k_lms="lms",
|
"k_lms": "lms",
|
||||||
k_dpm_2="kdpm_2",
|
"k_dpm_2": "kdpm_2",
|
||||||
k_dpm_2_a="kdpm_2_a",
|
"k_dpm_2_a": "kdpm_2_a",
|
||||||
dpmpp_2="dpmpp_2s",
|
"dpmpp_2": "dpmpp_2s",
|
||||||
k_dpmpp_2="dpmpp_2m",
|
"k_dpmpp_2": "dpmpp_2m",
|
||||||
k_dpmpp_2_a=None, # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
|
"k_dpmpp_2_a": None, # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
|
||||||
k_euler="euler",
|
"k_euler": "euler",
|
||||||
k_euler_a="euler_a",
|
"k_euler_a": "euler_a",
|
||||||
k_heun="heun",
|
"k_heun": "heun",
|
||||||
)
|
}
|
||||||
return scheduler_map.get(old_scheduler)
|
return scheduler_map.get(old_scheduler)
|
||||||
|
|
||||||
def split_prompt(self, raw_prompt: str):
|
def split_prompt(self, raw_prompt: str):
|
||||||
|
@ -210,7 +210,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
############# diffusers tab ##########
|
############# diffusers tab ##########
|
||||||
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
|
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
|
||||||
"""Add widgets responsible for selecting diffusers models"""
|
"""Add widgets responsible for selecting diffusers models"""
|
||||||
widgets = dict()
|
widgets = {}
|
||||||
models = self.all_models
|
models = self.all_models
|
||||||
starters = self.starter_models
|
starters = self.starter_models
|
||||||
starter_model_labels = self.model_labels
|
starter_model_labels = self.model_labels
|
||||||
@ -261,7 +261,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
exclude: set = set(),
|
exclude: set = set(),
|
||||||
) -> dict[str, npyscreen.widget]:
|
) -> dict[str, npyscreen.widget]:
|
||||||
"""Generic code to create model selection widgets"""
|
"""Generic code to create model selection widgets"""
|
||||||
widgets = dict()
|
widgets = {}
|
||||||
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
|
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
|
||||||
model_labels = [self.model_labels[x] for x in model_list]
|
model_labels = [self.model_labels[x] for x in model_list]
|
||||||
|
|
||||||
@ -391,7 +391,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
label_width = max([len(models[x].name) for x in models])
|
label_width = max([len(models[x].name) for x in models])
|
||||||
description_width = window_width - label_width - checkbox_width - spacing_width
|
description_width = window_width - label_width - checkbox_width - spacing_width
|
||||||
|
|
||||||
result = dict()
|
result = {}
|
||||||
for x in models.keys():
|
for x in models.keys():
|
||||||
description = models[x].description
|
description = models[x].description
|
||||||
description = (
|
description = (
|
||||||
@ -433,11 +433,11 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
parent_conn, child_conn = Pipe()
|
parent_conn, child_conn = Pipe()
|
||||||
p = Process(
|
p = Process(
|
||||||
target=process_and_execute,
|
target=process_and_execute,
|
||||||
kwargs=dict(
|
kwargs={
|
||||||
opt=app.program_opts,
|
"opt": app.program_opts,
|
||||||
selections=app.install_selections,
|
"selections": app.install_selections,
|
||||||
conn_out=child_conn,
|
"conn_out": child_conn,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
p.start()
|
p.start()
|
||||||
child_conn.close()
|
child_conn.close()
|
||||||
@ -558,7 +558,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
for section in ui_sections:
|
for section in ui_sections:
|
||||||
if "models_selected" not in section:
|
if "models_selected" not in section:
|
||||||
continue
|
continue
|
||||||
selected = set([section["models"][x] for x in section["models_selected"].value])
|
selected = {section["models"][x] for x in section["models_selected"].value}
|
||||||
models_to_install = [x for x in selected if not self.all_models[x].installed]
|
models_to_install = [x for x in selected if not self.all_models[x].installed]
|
||||||
models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
|
models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
|
||||||
selections.remove_models.extend(models_to_remove)
|
selections.remove_models.extend(models_to_remove)
|
||||||
|
@ -275,14 +275,14 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
interp = self.interpolations[self.merge_method.value[0]]
|
interp = self.interpolations[self.merge_method.value[0]]
|
||||||
|
|
||||||
bases = ["sd-1", "sd-2", "sdxl"]
|
bases = ["sd-1", "sd-2", "sdxl"]
|
||||||
args = dict(
|
args = {
|
||||||
model_names=models,
|
"model_names": models,
|
||||||
base_model=BaseModelType(bases[self.base_select.value[0]]),
|
"base_model": BaseModelType(bases[self.base_select.value[0]]),
|
||||||
alpha=self.alpha.value,
|
"alpha": self.alpha.value,
|
||||||
interp=interp,
|
"interp": interp,
|
||||||
force=self.force.value,
|
"force": self.force.value,
|
||||||
merged_model_name=self.merged_model_name.value,
|
"merged_model_name": self.merged_model_name.value,
|
||||||
)
|
}
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def check_for_overwrite(self) -> bool:
|
def check_for_overwrite(self) -> bool:
|
||||||
@ -297,7 +297,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
def validate_field_values(self) -> bool:
|
def validate_field_values(self) -> bool:
|
||||||
bad_fields = []
|
bad_fields = []
|
||||||
model_names = self.model_names
|
model_names = self.model_names
|
||||||
selected_models = set((model_names[self.model1.value[0]], model_names[self.model2.value[0]]))
|
selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]}
|
||||||
if self.model3.value[0] > 0:
|
if self.model3.value[0] > 0:
|
||||||
selected_models.add(model_names[self.model3.value[0] - 1])
|
selected_models.add(model_names[self.model3.value[0] - 1])
|
||||||
if len(selected_models) < 2:
|
if len(selected_models) < 2:
|
||||||
|
@ -276,13 +276,13 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
|
|
||||||
def get_model_names(self) -> Tuple[List[str], int]:
|
def get_model_names(self) -> Tuple[List[str], int]:
|
||||||
conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
|
conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
|
||||||
model_names = [idx for idx in sorted(list(conf.keys())) if conf[idx].get("format", None) == "diffusers"]
|
model_names = [idx for idx in sorted(conf.keys()) if conf[idx].get("format", None) == "diffusers"]
|
||||||
defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]]
|
defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]]
|
||||||
default = defaults[0] if len(defaults) > 0 else 0
|
default = defaults[0] if len(defaults) > 0 else 0
|
||||||
return (model_names, default)
|
return (model_names, default)
|
||||||
|
|
||||||
def marshall_arguments(self) -> dict:
|
def marshall_arguments(self) -> dict:
|
||||||
args = dict()
|
args = {}
|
||||||
|
|
||||||
# the choices
|
# the choices
|
||||||
args.update(
|
args.update(
|
||||||
|
@ -37,22 +37,22 @@ def main():
|
|||||||
|
|
||||||
if args.all_models or model_type == "diffusers":
|
if args.all_models or model_type == "diffusers":
|
||||||
for d in dirs:
|
for d in dirs:
|
||||||
conf[f"{base}/{model_type}/{d}"] = dict(
|
conf[f"{base}/{model_type}/{d}"] = {
|
||||||
path=os.path.join(root, d),
|
"path": os.path.join(root, d),
|
||||||
description=f"{model_type} model {d}",
|
"description": f"{model_type} model {d}",
|
||||||
format="folder",
|
"format": "folder",
|
||||||
base=base,
|
"base": base,
|
||||||
)
|
}
|
||||||
|
|
||||||
for f in files:
|
for f in files:
|
||||||
basename = Path(f).stem
|
basename = Path(f).stem
|
||||||
format = Path(f).suffix[1:]
|
format = Path(f).suffix[1:]
|
||||||
conf[f"{base}/{model_type}/{basename}"] = dict(
|
conf[f"{base}/{model_type}/{basename}"] = {
|
||||||
path=os.path.join(root, f),
|
"path": os.path.join(root, f),
|
||||||
description=f"{model_type} model {basename}",
|
"description": f"{model_type} model {basename}",
|
||||||
format=format,
|
"format": format,
|
||||||
base=base,
|
"base": base,
|
||||||
)
|
}
|
||||||
|
|
||||||
OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout)
|
OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout)
|
||||||
|
|
||||||
|
@ -149,8 +149,8 @@ def test_graph_state_expands_iterator(mock_services):
|
|||||||
invoke_next(g, mock_services)
|
invoke_next(g, mock_services)
|
||||||
|
|
||||||
prepared_add_nodes = g.source_prepared_mapping["3"]
|
prepared_add_nodes = g.source_prepared_mapping["3"]
|
||||||
results = set([g.results[n].value for n in prepared_add_nodes])
|
results = {g.results[n].value for n in prepared_add_nodes}
|
||||||
expected = set([1, 11, 21])
|
expected = {1, 11, 21}
|
||||||
assert results == expected
|
assert results == expected
|
||||||
|
|
||||||
|
|
||||||
@ -229,7 +229,7 @@ def test_graph_executes_depth_first(mock_services):
|
|||||||
# Because ordering is not guaranteed, we cannot compare results directly.
|
# Because ordering is not guaranteed, we cannot compare results directly.
|
||||||
# Instead, we must count the number of results.
|
# Instead, we must count the number of results.
|
||||||
def get_completed_count(g, id):
|
def get_completed_count(g, id):
|
||||||
ids = [i for i in g.source_prepared_mapping[id]]
|
ids = list(g.source_prepared_mapping[id])
|
||||||
completed_ids = [i for i in g.executed if i in ids]
|
completed_ids = [i for i in g.executed if i in ids]
|
||||||
return len(completed_ids)
|
return len(completed_ids)
|
||||||
|
|
||||||
|
@ -503,8 +503,8 @@ def test_graph_expands_subgraph():
|
|||||||
g.add_edge(create_edge("1.2", "value", "2", "a"))
|
g.add_edge(create_edge("1.2", "value", "2", "a"))
|
||||||
|
|
||||||
dg = g.nx_graph_flat()
|
dg = g.nx_graph_flat()
|
||||||
assert set(dg.nodes) == set(["1.1", "1.2", "2"])
|
assert set(dg.nodes) == {"1.1", "1.2", "2"}
|
||||||
assert set(dg.edges) == set([("1.1", "1.2"), ("1.2", "2")])
|
assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")}
|
||||||
|
|
||||||
|
|
||||||
def test_graph_subgraph_t2i():
|
def test_graph_subgraph_t2i():
|
||||||
@ -532,9 +532,7 @@ def test_graph_subgraph_t2i():
|
|||||||
|
|
||||||
# Validate
|
# Validate
|
||||||
dg = g.nx_graph_flat()
|
dg = g.nx_graph_flat()
|
||||||
assert set(dg.nodes) == set(
|
assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"}
|
||||||
["1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"]
|
|
||||||
)
|
|
||||||
expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
|
expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
|
||||||
expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
|
expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
|
||||||
print(expected_edges)
|
print(expected_edges)
|
||||||
|
@ -130,7 +130,7 @@ class TestEventService(EventServiceBase):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.events = list()
|
self.events = []
|
||||||
|
|
||||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user