mirror of https://github.com/invoke-ai/InvokeAI
commit 38c1436f02: resolve conflicts; blackify
.github/workflows/pyflakes.yml (vendored): 20 lines changed
@@ -1,20 +0,0 @@
-on:
-  pull_request:
-  push:
-    branches:
-      - main
-      - development
-      - 'release-candidate-*'
-
-jobs:
-  pyflakes:
-    name: runner / pyflakes
-    if: github.event.pull_request.draft == false
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v2
-      - name: pyflakes
-        uses: reviewdog/action-pyflakes@v1
-        with:
-          github_token: ${{ secrets.GITHUB_TOKEN }}
-          reporter: github-pr-review
.github/workflows/style-checks.yml (vendored): 7 lines changed
@@ -18,8 +18,7 @@ jobs:
 
       - name: Install dependencies with pip
         run: |
-          pip install black flake8 Flake8-pyproject isort
+          pip install ruff
 
-      - run: isort --check-only .
-      - run: black --check .
-      - run: flake8
+      - run: ruff check --output-format=github .
+      - run: ruff format --check .
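Net effect of this hunk: a single tool replaces the isort/black/flake8 trio. `ruff check` covers linting (and import ordering, when its isort-compatible rules are enabled in the project config), while `ruff format --check` plays the role `black --check` did.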
@@ -137,7 +137,7 @@ def dest_path(dest=None) -> Path:
     path_completer = PathCompleter(
         only_directories=True,
         expanduser=True,
-        get_paths=lambda: [browse_start],
+        get_paths=lambda: [browse_start],  # noqa: B023
         # get_paths=lambda: [".."].extend(list(browse_start.iterdir()))
     )
 
@@ -149,7 +149,7 @@ def dest_path(dest=None) -> Path:
         completer=path_completer,
         default=str(browse_start) + os.sep,
         vi_mode=True,
-        complete_while_typing=True
+        complete_while_typing=True,
         # Test that this is not needed on Windows
         # complete_style=CompleteStyle.READLINE_LIKE,
     )
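The `# noqa: B023` added above suppresses flake8-bugbear's loop-variable-capture warning rather than restructuring the completer, presumably because the lambda is meant to see the current value of `browse_start` each time it runs. A minimal, self-contained sketch of what B023 guards against (names here are illustrative, not from the diff):

# B023 fires when a closure captures a loop variable: every closure sees the
# variable's final value, not the value it had when the closure was created.
funcs = [lambda: i for i in range(3)]
assert [f() for f in funcs] == [2, 2, 2]  # all three see i == 2

# Binding the value as a default argument freezes it per iteration:
funcs_fixed = [lambda i=i: i for i in range(3)]
assert [f() for f in funcs_fixed] == [0, 1, 2]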
@@ -28,7 +28,7 @@ class FastAPIEventService(EventServiceBase):
         self.__queue.put(None)
 
     def dispatch(self, event_name: str, payload: Any) -> None:
-        self.__queue.put(dict(event_name=event_name, payload=payload))
+        self.__queue.put({"event_name": event_name, "payload": payload})
 
     async def __dispatch_from_queue(self, stop_event: threading.Event):
         """Get events on from the queue and dispatch them, from the correct thread"""
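This is the commit's most common mechanical change: `dict(...)` calls with keyword arguments become dict literals (flake8-comprehensions / ruff rule C408). The two forms build identical objects; the literal skips a global name lookup and a function call, and is immune to `dict` being shadowed. A runnable sketch with illustrative values:

# Both expressions produce the same dictionary:
by_call = dict(event_name="ping", payload={"x": 1})       # flagged by C408
by_literal = {"event_name": "ping", "payload": {"x": 1}}  # preferred form
assert by_call == by_literal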
@@ -8,12 +8,20 @@ from typing import List, Optional
 
 from fastapi import Body, Path, Query, Response
 from fastapi.routing import APIRouter
-from pydantic import BaseModel, ConfigDict, TypeAdapter
+from pydantic import BaseModel, ConfigDict
 from starlette.exceptions import HTTPException
 from typing_extensions import Annotated
 
-from invokeai.app.services.model_records import DuplicateModelException, InvalidModelException, UnknownModelException
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
+from invokeai.app.services.model_records import (
+    DuplicateModelException,
+    InvalidModelException,
+    UnknownModelException,
+)
+from invokeai.backend.model_manager.config import (
+    AnyModelConfig,
+    BaseModelType,
+    ModelType,
+)
 
 from ..dependencies import ApiDependencies
 
@@ -28,26 +36,32 @@ class ModelsList(BaseModel):
     model_config = ConfigDict(use_enum_values=True)
 
 
-ModelsListValidator = TypeAdapter(ModelsList)
-
-
 @model_records_router.get(
     "/",
     operation_id="list_model_records",
 )
 async def list_model_records(
-    base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
-    model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
+    base_models: Optional[List[BaseModelType]] = Query(
+        default=None, description="Base models to include"
+    ),
+    model_type: Optional[ModelType] = Query(
+        default=None, description="The type of model to get"
+    ),
 ) -> ModelsList:
     """Get a list of models."""
     record_store = ApiDependencies.invoker.services.model_records
-    models = list()
+    found_models: list[AnyModelConfig] = []
     if base_models:
         for base_model in base_models:
-            models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
+            found_models.extend(
+                record_store.search_by_attr(
+                    base_model=base_model, model_type=model_type
+                )
+            )
     else:
-        models.extend(record_store.search_by_attr(model_type=model_type))
-    return ModelsList(models=models)
+        found_models.extend(record_store.search_by_attr(model_type=model_type))
+    return ModelsList(models=found_models)
 
 
 @model_records_router.get(
     "/i/{key}",
@@ -83,13 +97,16 @@ async def get_model_record(
 )
 async def update_model_record(
     key: Annotated[str, Path(description="Unique key of model")],
-    info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
+    info: Annotated[
+        AnyModelConfig, Body(description="Model config", discriminator="type")
+    ],
 ) -> AnyModelConfig:
     """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
     logger = ApiDependencies.invoker.services.logger
     record_store = ApiDependencies.invoker.services.model_records
     try:
         model_response = record_store.update_model(key, config=info)
+        logger.info(f"Updated model: {key}")
     except UnknownModelException as e:
         raise HTTPException(status_code=404, detail=str(e))
     except ValueError as e:
@@ -101,7 +118,10 @@ async def update_model_record(
 @model_records_router.delete(
     "/i/{key}",
     operation_id="del_model_record",
-    responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
+    responses={
+        204: {"description": "Model deleted successfully"},
+        404: {"description": "Model not found"},
+    },
     status_code=204,
 )
 async def del_model_record(
@@ -125,13 +145,17 @@ async def del_model_record(
     operation_id="add_model_record",
     responses={
         201: {"description": "The model added successfully"},
-        409: {"description": "There is already a model corresponding to this path or repo_id"},
+        409: {
+            "description": "There is already a model corresponding to this path or repo_id"
+        },
         415: {"description": "Unrecognized file/folder format"},
     },
     status_code=201,
 )
 async def add_model_record(
-    config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
+    config: Annotated[
+        AnyModelConfig, Body(description="Model config", discriminator="type")
+    ]
 ) -> AnyModelConfig:
     """
     Add a model using the configuration information appropriate for its type.
@@ -54,7 +54,7 @@ async def list_models(
 ) -> ModelsList:
     """Gets a list of models"""
     if base_models and len(base_models) > 0:
-        models_raw = list()
+        models_raw = []
         for base_model in base_models:
             models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
     else:
@@ -132,7 +132,7 @@ def custom_openapi() -> dict[str, Any]:
     # Add all outputs
     all_invocations = BaseInvocation.get_invocations()
     output_types = set()
-    output_type_titles = dict()
+    output_type_titles = {}
     for invoker in all_invocations:
         output_type = signature(invoker.invoke).return_annotation
         output_types.add(output_type)
@@ -173,12 +173,12 @@ def custom_openapi() -> dict[str, Any]:
             # print(f"Config with name {name} already defined")
             continue
 
-        openapi_schema["components"]["schemas"][name] = dict(
-            title=name,
-            description="An enumeration.",
-            type="string",
-            enum=list(v.value for v in model_config_format_enum),
-        )
+        openapi_schema["components"]["schemas"][name] = {
+            "title": name,
+            "description": "An enumeration.",
+            "type": "string",
+            "enum": [v.value for v in model_config_format_enum],
+        }
 
     app.openapi_schema = openapi_schema
     return app.openapi_schema
@@ -25,4 +25,4 @@ spec.loader.exec_module(module)
 
 # add core nodes to __all__
 python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
-__all__ = list(f.stem for f in python_files)  # type: ignore
+__all__ = [f.stem for f in python_files]  # type: ignore
@@ -236,35 +236,35 @@ def InputField(
         Ignored for non-collection fields.
     """
 
-    json_schema_extra_: dict[str, Any] = dict(
-        input=input,
-        ui_type=ui_type,
-        ui_component=ui_component,
-        ui_hidden=ui_hidden,
-        ui_order=ui_order,
-        item_default=item_default,
-        ui_choice_labels=ui_choice_labels,
-        _field_kind="input",
-    )
+    json_schema_extra_: dict[str, Any] = {
+        "input": input,
+        "ui_type": ui_type,
+        "ui_component": ui_component,
+        "ui_hidden": ui_hidden,
+        "ui_order": ui_order,
+        "item_default": item_default,
+        "ui_choice_labels": ui_choice_labels,
+        "_field_kind": "input",
+    }
 
-    field_args = dict(
-        default=default,
-        default_factory=default_factory,
-        title=title,
-        description=description,
-        pattern=pattern,
-        strict=strict,
-        gt=gt,
-        ge=ge,
-        lt=lt,
-        le=le,
-        multiple_of=multiple_of,
-        allow_inf_nan=allow_inf_nan,
-        max_digits=max_digits,
-        decimal_places=decimal_places,
-        min_length=min_length,
-        max_length=max_length,
-    )
+    field_args = {
+        "default": default,
+        "default_factory": default_factory,
+        "title": title,
+        "description": description,
+        "pattern": pattern,
+        "strict": strict,
+        "gt": gt,
+        "ge": ge,
+        "lt": lt,
+        "le": le,
+        "multiple_of": multiple_of,
+        "allow_inf_nan": allow_inf_nan,
+        "max_digits": max_digits,
+        "decimal_places": decimal_places,
+        "min_length": min_length,
+        "max_length": max_length,
+    }
 
     """
     Invocation definitions have their fields typed correctly for their `invoke()` functions.
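One subtlety the literal form surfaces: `dict()` keyword arguments must be valid Python identifiers and always become string keys, while a literal accepts any hashable key. A small illustration (values are hypothetical):

assert dict(_field_kind="input") == {"_field_kind": "input"}  # identical objects
# but only the literal can spell keys that are not identifiers:
cols = {"images.name": str}  # dict(images.name=str) would be a SyntaxError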
@@ -299,24 +299,24 @@ def InputField(
 
     # because we are manually making fields optional, we need to store the original required bool for reference later
     if default is PydanticUndefined and default_factory is PydanticUndefined:
-        json_schema_extra_.update(dict(orig_required=True))
+        json_schema_extra_.update({"orig_required": True})
    else:
-        json_schema_extra_.update(dict(orig_required=False))
+        json_schema_extra_.update({"orig_required": False})
 
     # make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
     if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
         default_ = None if default is PydanticUndefined else default
-        provided_args.update(dict(default=default_))
+        provided_args.update({"default": default_})
         if default is not PydanticUndefined:
             # before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
-            json_schema_extra_.update(dict(default=default))
-            json_schema_extra_.update(dict(orig_default=default))
+            json_schema_extra_.update({"default": default})
+            json_schema_extra_.update({"orig_default": default})
     elif default is not PydanticUndefined and default_factory is PydanticUndefined:
         default_ = default
-        provided_args.update(dict(default=default_))
-        json_schema_extra_.update(dict(orig_default=default_))
+        provided_args.update({"default": default_})
+        json_schema_extra_.update({"orig_default": default_})
     elif default_factory is not PydanticUndefined:
-        provided_args.update(dict(default_factory=default_factory))
+        provided_args.update({"default_factory": default_factory})
         # TODO: cannot serialize default_factory...
         # json_schema_extra_.update(dict(orig_default_factory=default_factory))
@@ -383,12 +383,12 @@ def OutputField(
         decimal_places=decimal_places,
         min_length=min_length,
         max_length=max_length,
-        json_schema_extra=dict(
-            ui_type=ui_type,
-            ui_hidden=ui_hidden,
-            ui_order=ui_order,
-            _field_kind="output",
-        ),
+        json_schema_extra={
+            "ui_type": ui_type,
+            "ui_hidden": ui_hidden,
+            "ui_order": ui_order,
+            "_field_kind": "output",
+        },
     )
 
 
@@ -460,14 +460,14 @@ class BaseInvocationOutput(BaseModel):
 
     @classmethod
     def get_output_types(cls) -> Iterable[str]:
-        return map(lambda i: get_type(i), BaseInvocationOutput.get_outputs())
+        return (get_type(i) for i in BaseInvocationOutput.get_outputs())
 
     @staticmethod
     def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
         # Because we use a pydantic Literal field with default value for the invocation type,
         # it will be typed as optional in the OpenAPI schema. Make it required manually.
         if "required" not in schema or not isinstance(schema["required"], list):
-            schema["required"] = list()
+            schema["required"] = []
         schema["required"].extend(["type"])
 
     model_config = ConfigDict(
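`map` over a lambda and a generator expression are interchangeable here: both are lazy iterables, and the comprehension form simply drops the lambda indirection, which is what the comprehension-style lint rules target. A runnable sketch with stand-in names:

def get_type(i: str) -> str:  # stand-in for the real helper
    return i.upper()

items = ["latents", "noise"]
via_map = map(lambda i: get_type(i), items)   # before: lambda wrapper per element
via_genexpr = (get_type(i) for i in items)    # after: same laziness, no lambda
assert list(via_map) == list(via_genexpr) == ["LATENTS", "NOISE"]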
@@ -527,16 +527,11 @@ class BaseInvocation(ABC, BaseModel):
     @classmethod
     def get_invocations_map(cls) -> dict[str, BaseInvocation]:
         # Get the type strings out of the literals and into a dictionary
-        return dict(
-            map(
-                lambda i: (get_type(i), i),
-                BaseInvocation.get_invocations(),
-            )
-        )
+        return {get_type(i): i for i in BaseInvocation.get_invocations()}
 
     @classmethod
     def get_invocation_types(cls) -> Iterable[str]:
-        return map(lambda i: get_type(i), BaseInvocation.get_invocations())
+        return (get_type(i) for i in BaseInvocation.get_invocations())
 
     @classmethod
     def get_output_type(cls) -> BaseInvocationOutput:
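The same collapse works when the target is a mapping: building `(key, value)` tuples with a lambda and feeding them to `dict()` is exactly what a dict comprehension expresses directly. Illustrative stand-ins:

def get_type(i: str) -> str:  # stand-in for the real helper
    return i[0]

invocations = ["add", "blur"]
before = dict(map(lambda i: (get_type(i), i), invocations))
after = {get_type(i): i for i in invocations}
assert before == after == {"a": "add", "b": "blur"}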
@@ -555,7 +550,7 @@ class BaseInvocation(ABC, BaseModel):
         if uiconfig and hasattr(uiconfig, "version"):
             schema["version"] = uiconfig.version
         if "required" not in schema or not isinstance(schema["required"], list):
-            schema["required"] = list()
+            schema["required"] = []
         schema["required"].extend(["type", "id"])
 
     @abstractmethod
@@ -609,15 +604,15 @@ class BaseInvocation(ABC, BaseModel):
     id: str = Field(
         default_factory=uuid_string,
         description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
-        json_schema_extra=dict(_field_kind="internal"),
+        json_schema_extra={"_field_kind": "internal"},
     )
     is_intermediate: bool = Field(
         default=False,
         description="Whether or not this is an intermediate invocation.",
-        json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"),
+        json_schema_extra={"ui_type": UIType.IsIntermediate, "_field_kind": "internal"},
     )
     use_cache: bool = Field(
-        default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal")
+        default=True, description="Whether or not to use the cache", json_schema_extra={"_field_kind": "internal"}
     )
 
     UIConfig: ClassVar[Type[UIConfigBase]]
@@ -651,7 +646,7 @@ class _Model(BaseModel):
 
 
 # Get all pydantic model attrs, methods, etc
-RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model())))
+RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
 
 
 def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
@@ -666,9 +661,7 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
 
         field_kind = (
             # _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
-            field.json_schema_extra.get("_field_kind", None)
-            if field.json_schema_extra
-            else None
+            field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None
         )
 
         # must have a field_kind
@@ -729,7 +722,7 @@ def invocation(
         # Add OpenAPI schema extras
         uiconf_name = cls.__qualname__ + ".UIConfig"
         if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
-            cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
+            cls.UIConfig = type(uiconf_name, (UIConfigBase,), {})
         if title is not None:
             cls.UIConfig.title = title
         if tags is not None:
@@ -756,7 +749,7 @@ def invocation(
 
         invocation_type_annotation = Literal[invocation_type]  # type: ignore
         invocation_type_field = Field(
-            title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal")
+            title="type", default=invocation_type, json_schema_extra={"_field_kind": "internal"}
         )
 
         docstring = cls.__doc__
@@ -802,7 +795,7 @@ def invocation_output(
         # Add the output type to the model.
 
         output_type_annotation = Literal[output_type]  # type: ignore
-        output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal"))
+        output_type_field = Field(title="type", default=output_type, json_schema_extra={"_field_kind": "internal"})
 
         docstring = cls.__doc__
         cls = create_model(
@@ -834,7 +827,7 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField)
 
 class WithWorkflow(BaseModel):
     workflow: Optional[WorkflowField] = Field(
-        default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal")
+        default=None, description=FieldDescriptions.workflow, json_schema_extra={"_field_kind": "internal"}
     )
 
 
@@ -852,5 +845,5 @@ MetadataFieldValidator = TypeAdapter(MetadataField)
 
 class WithMetadata(BaseModel):
     metadata: Optional[MetadataField] = Field(
-        default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal")
+        default=None, description=FieldDescriptions.metadata, json_schema_extra={"_field_kind": "internal"}
     )
@@ -131,7 +131,7 @@ def prepare_faces_list(
     deduped_faces: list[FaceResultData] = []
 
     if len(face_result_list) == 0:
-        return list()
+        return []
 
     for candidate in face_result_list:
         should_add = True
@@ -210,7 +210,7 @@ def generate_face_box_mask(
     # Check if any face is detected.
     if results.multi_face_landmarks:  # type: ignore # this are via protobuf and not typed
         # Search for the face_id in the detected faces.
-        for face_id, face_landmarks in enumerate(results.multi_face_landmarks):  # type: ignore #this are via protobuf and not typed
+        for _face_id, face_landmarks in enumerate(results.multi_face_landmarks):  # type: ignore #this are via protobuf and not typed
             # Get the bounding box of the face mesh.
             x_coordinates = [landmark.x for landmark in face_landmarks.landmark]
             y_coordinates = [landmark.y for landmark in face_landmarks.landmark]
@@ -77,7 +77,7 @@ if choose_torch_device() == torch.device("mps"):
 
 DEFAULT_PRECISION = choose_precision(choose_torch_device())
 
-SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
+SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
 
 
 @invocation_output("scheduler_output")
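`tuple()` accepts any iterable, so wrapping the keys view in `list()` first only allocates a throwaway list (ruff C414). The trick itself, splatting dict keys into `Literal[...]` to get a dynamic choice type, is unchanged. A sketch with a hypothetical stand-in mapping:

SCHEDULER_MAP = {"ddim": 0, "euler": 1}  # stand-in, not the real mapping
assert tuple(list(SCHEDULER_MAP.keys())) == tuple(SCHEDULER_MAP.keys()) == ("ddim", "euler")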
@@ -1105,7 +1105,7 @@ class BlendLatentsInvocation(BaseInvocation):
        latents_b = context.services.latents.get(self.latents_b.latents_name)
 
         if latents_a.shape != latents_b.shape:
-            raise "Latents to blend must be the same size."
+            raise Exception("Latents to blend must be the same size.")
 
         # TODO:
         device = choose_torch_device()
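This one is a genuine bug fix rather than formatting: Python 3 can only raise instances (or classes) deriving from BaseException, so `raise "..."` itself fails with a TypeError and the intended message never reaches the caller. Runnable demonstration:

try:
    raise "Latents to blend must be the same size."  # a str, not a BaseException
except TypeError as e:
    print(e)  # "exceptions must derive from BaseException"
# The fixed code raises a real exception instead:
# raise Exception("Latents to blend must be the same size.")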
@@ -145,17 +145,17 @@ INTEGER_OPERATIONS = Literal[
 ]
 
 
-INTEGER_OPERATIONS_LABELS = dict(
-    ADD="Add A+B",
-    SUB="Subtract A-B",
-    MUL="Multiply A*B",
-    DIV="Divide A/B",
-    EXP="Exponentiate A^B",
-    MOD="Modulus A%B",
-    ABS="Absolute Value of A",
-    MIN="Minimum(A,B)",
-    MAX="Maximum(A,B)",
-)
+INTEGER_OPERATIONS_LABELS = {
+    "ADD": "Add A+B",
+    "SUB": "Subtract A-B",
+    "MUL": "Multiply A*B",
+    "DIV": "Divide A/B",
+    "EXP": "Exponentiate A^B",
+    "MOD": "Modulus A%B",
+    "ABS": "Absolute Value of A",
+    "MIN": "Minimum(A,B)",
+    "MAX": "Maximum(A,B)",
+}
 
 
 @invocation(
@@ -231,17 +231,17 @@ FLOAT_OPERATIONS = Literal[
 ]
 
 
-FLOAT_OPERATIONS_LABELS = dict(
-    ADD="Add A+B",
-    SUB="Subtract A-B",
-    MUL="Multiply A*B",
-    DIV="Divide A/B",
-    EXP="Exponentiate A^B",
-    ABS="Absolute Value of A",
-    SQRT="Square Root of A",
-    MIN="Minimum(A,B)",
-    MAX="Maximum(A,B)",
-)
+FLOAT_OPERATIONS_LABELS = {
+    "ADD": "Add A+B",
+    "SUB": "Subtract A-B",
+    "MUL": "Multiply A*B",
+    "DIV": "Divide A/B",
+    "EXP": "Exponentiate A^B",
+    "ABS": "Absolute Value of A",
+    "SQRT": "Square Root of A",
+    "MIN": "Minimum(A,B)",
+    "MAX": "Maximum(A,B)",
+}
 
 
 @invocation(
@@ -266,7 +266,7 @@ class FloatMathInvocation(BaseInvocation):
             raise ValueError("Cannot divide by zero")
         elif info.data["operation"] == "EXP" and info.data["a"] == 0 and v < 0:
             raise ValueError("Cannot raise zero to a negative power")
-        elif info.data["operation"] == "EXP" and type(info.data["a"] ** v) is complex:
+        elif info.data["operation"] == "EXP" and isinstance(info.data["a"] ** v, complex):
             raise ValueError("Root operation resulted in a complex number")
         return v
 
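`isinstance` is the idiomatic type test (linters flag `type(x) is ...` comparisons, pycodestyle E721); it behaves the same here and additionally tolerates subclasses. The complex case this validator guards really does occur for fractional powers of negative numbers:

x = (-8) ** 0.5            # negative base, non-integer exponent -> complex in Python 3
assert isinstance(x, complex)
assert type(x) is complex  # equivalent for an exact complex, but non-idiomatic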
@@ -54,7 +54,7 @@ ORT_TO_NP_TYPE = {
     "tensor(double)": np.float64,
 }
 
-PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
+PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())]
 
 
 @invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
@@ -252,7 +252,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
         scheduler.set_timesteps(self.steps)
         latents = latents * np.float64(scheduler.init_noise_sigma)
 
-        extra_step_kwargs = dict()
+        extra_step_kwargs = {}
         if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
             extra_step_kwargs.update(
                 eta=0.0,
@@ -100,7 +100,7 @@ EASING_FUNCTIONS_MAP = {
     "BounceInOut": BounceEaseInOut,
 }
 
-EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
+EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
 
 
 # actually I think for now could just use CollectionOutput (which is list[Any]
@@ -161,7 +161,7 @@ class StepParamEasingInvocation(BaseInvocation):
         easing_class = EASING_FUNCTIONS_MAP[self.easing]
         if log_diagnostics:
             context.services.logger.debug("easing class: " + str(easing_class))
-        easing_list = list()
+        easing_list = []
         if self.mirror:  # "expected" mirroring
             # if number of steps is even, squeeze duration down to (number_of_steps)/2
             # and create reverse copy of list to append
@@ -178,7 +178,7 @@ class StepParamEasingInvocation(BaseInvocation):
                 end=self.end_value,
                 duration=base_easing_duration - 1,
             )
-            base_easing_vals = list()
+            base_easing_vals = []
             for step_index in range(base_easing_duration):
                 easing_val = easing_function.ease(step_index)
                 base_easing_vals.append(easing_val)
@@ -139,7 +139,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
                 (board_id,),
             )
             result = cast(list[sqlite3.Row], self._cursor.fetchall())
-            images = list(map(lambda r: deserialize_image_record(dict(r)), result))
+            images = [deserialize_image_record(dict(r)) for r in result]
 
             self._cursor.execute(
                 """--sql
@@ -167,7 +167,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
                 (board_id,),
             )
             result = cast(list[sqlite3.Row], self._cursor.fetchall())
-            image_names = list(map(lambda r: r[0], result))
+            image_names = [r[0] for r in result]
             return image_names
         except sqlite3.Error as e:
             self._conn.rollback()
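`list(map(lambda ...))` always materializes a list, so a list comprehension is a drop-in replacement that reads left to right and skips one function call per row. A stand-in sketch of the row-deserialization pattern used throughout these storage classes:

rows = [{"image_name": "a.png"}, {"image_name": "b.png"}]  # stand-ins for sqlite3.Row

def deserialize_image_record(d: dict) -> str:  # stand-in for the real deserializer
    return d["image_name"]

before = list(map(lambda r: deserialize_image_record(dict(r)), rows))
after = [deserialize_image_record(dict(r)) for r in rows]
assert before == after == ["a.png", "b.png"]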
@@ -199,7 +199,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
             )
 
             result = cast(list[sqlite3.Row], self._cursor.fetchall())
-            boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
+            boards = [deserialize_board_record(dict(r)) for r in result]
 
             # Get the total number of boards
             self._cursor.execute(
@@ -236,7 +236,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
             )
 
             result = cast(list[sqlite3.Row], self._cursor.fetchall())
-            boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
+            boards = [deserialize_board_record(dict(r)) for r in result]
 
             return boards
 
@@ -55,7 +55,7 @@ class InvokeAISettings(BaseSettings):
         """
         cls = self.__class__
         type = get_args(get_type_hints(cls)["type"])[0]
-        field_dict = dict({type: dict()})
+        field_dict = {type: {}}
         for name, field in self.model_fields.items():
             if name in cls._excluded_from_yaml():
                 continue
@@ -64,7 +64,7 @@ class InvokeAISettings(BaseSettings):
             )
             value = getattr(self, name)
             if category not in field_dict[type]:
-                field_dict[type][category] = dict()
+                field_dict[type][category] = {}
             # keep paths as strings to make it easier to read
             field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
         conf = OmegaConf.create(field_dict)
@@ -89,7 +89,7 @@ class InvokeAISettings(BaseSettings):
         # create an upcase version of the environment in
         # order to achieve case-insensitive environment
         # variables (the way Windows does)
-        upcase_environ = dict()
+        upcase_environ = {}
        for key, value in os.environ.items():
             upcase_environ[key.upper()] = value
 
@@ -188,18 +188,18 @@ DEFAULT_MAX_VRAM = 0.5
 
 
 class Categories(object):
-    WebServer = dict(category="Web Server")
-    Features = dict(category="Features")
-    Paths = dict(category="Paths")
-    Logging = dict(category="Logging")
-    Development = dict(category="Development")
-    Other = dict(category="Other")
-    ModelCache = dict(category="Model Cache")
-    Device = dict(category="Device")
-    Generation = dict(category="Generation")
-    Queue = dict(category="Queue")
-    Nodes = dict(category="Nodes")
-    MemoryPerformance = dict(category="Memory/Performance")
+    WebServer = {"category": "Web Server"}
+    Features = {"category": "Features"}
+    Paths = {"category": "Paths"}
+    Logging = {"category": "Logging"}
+    Development = {"category": "Development"}
+    Other = {"category": "Other"}
+    ModelCache = {"category": "Model Cache"}
+    Device = {"category": "Device"}
+    Generation = {"category": "Generation"}
+    Queue = {"category": "Queue"}
+    Nodes = {"category": "Nodes"}
+    MemoryPerformance = {"category": "Memory/Performance"}
 
 
 class InvokeAIAppConfig(InvokeAISettings):
@@ -482,7 +482,7 @@ def _find_root() -> Path:
     venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
     if os.environ.get("INVOKEAI_ROOT"):
         root = Path(os.environ["INVOKEAI_ROOT"])
-    elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
+    elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]):
         root = (venv.parent).resolve()
     else:
         root = Path("~/invokeai").expanduser().resolve()
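Dropping the square brackets changes more than style: `any` over a generator stops at the first truthy result, while the list form performs every `.exists()` stat call up front before `any` even runs (flake8-comprehensions C419). A sketch with hypothetical file names:

from pathlib import Path

candidates = ["invokeai.init", "invokeai.yaml"]  # stand-ins for INIT_FILE / LEGACY_INIT_FILE
venv = Path(".")
eager = any([(venv.parent / x).exists() for x in candidates])  # stats every candidate
lazy = any((venv.parent / x).exists() for x in candidates)     # stops at the first hit
assert eager == lazy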
@@ -27,7 +27,7 @@ class EventServiceBase:
         payload["timestamp"] = get_timestamp()
         self.dispatch(
             event_name=EventServiceBase.queue_event,
-            payload=dict(event=event_name, data=payload),
+            payload={"event": event_name, "data": payload},
         )
 
     # Define events here for every event in the system.
@@ -48,18 +48,18 @@ class EventServiceBase:
         """Emitted when there is generation progress"""
         self.__emit_queue_event(
             event_name="generator_progress",
-            payload=dict(
-                queue_id=queue_id,
-                queue_item_id=queue_item_id,
-                queue_batch_id=queue_batch_id,
-                graph_execution_state_id=graph_execution_state_id,
-                node_id=node.get("id"),
-                source_node_id=source_node_id,
-                progress_image=progress_image.model_dump() if progress_image is not None else None,
-                step=step,
-                order=order,
-                total_steps=total_steps,
-            ),
+            payload={
+                "queue_id": queue_id,
+                "queue_item_id": queue_item_id,
+                "queue_batch_id": queue_batch_id,
+                "graph_execution_state_id": graph_execution_state_id,
+                "node_id": node.get("id"),
+                "source_node_id": source_node_id,
+                "progress_image": progress_image.model_dump() if progress_image is not None else None,
+                "step": step,
+                "order": order,
+                "total_steps": total_steps,
+            },
         )
 
     def emit_invocation_complete(
@@ -75,15 +75,15 @@ class EventServiceBase:
         """Emitted when an invocation has completed"""
         self.__emit_queue_event(
             event_name="invocation_complete",
-            payload=dict(
-                queue_id=queue_id,
-                queue_item_id=queue_item_id,
-                queue_batch_id=queue_batch_id,
-                graph_execution_state_id=graph_execution_state_id,
-                node=node,
-                source_node_id=source_node_id,
-                result=result,
-            ),
+            payload={
+                "queue_id": queue_id,
+                "queue_item_id": queue_item_id,
+                "queue_batch_id": queue_batch_id,
+                "graph_execution_state_id": graph_execution_state_id,
+                "node": node,
+                "source_node_id": source_node_id,
+                "result": result,
+            },
         )
 
     def emit_invocation_error(
@@ -100,16 +100,16 @@ class EventServiceBase:
         """Emitted when an invocation has completed"""
         self.__emit_queue_event(
             event_name="invocation_error",
-            payload=dict(
-                queue_id=queue_id,
-                queue_item_id=queue_item_id,
-                queue_batch_id=queue_batch_id,
-                graph_execution_state_id=graph_execution_state_id,
-                node=node,
-                source_node_id=source_node_id,
-                error_type=error_type,
-                error=error,
-            ),
+            payload={
+                "queue_id": queue_id,
+                "queue_item_id": queue_item_id,
+                "queue_batch_id": queue_batch_id,
+                "graph_execution_state_id": graph_execution_state_id,
+                "node": node,
+                "source_node_id": source_node_id,
+                "error_type": error_type,
+                "error": error,
+            },
         )
 
     def emit_invocation_started(
@@ -124,14 +124,14 @@ class EventServiceBase:
         """Emitted when an invocation has started"""
         self.__emit_queue_event(
             event_name="invocation_started",
-            payload=dict(
-                queue_id=queue_id,
-                queue_item_id=queue_item_id,
-                queue_batch_id=queue_batch_id,
-                graph_execution_state_id=graph_execution_state_id,
-                node=node,
-                source_node_id=source_node_id,
-            ),
+            payload={
+                "queue_id": queue_id,
+                "queue_item_id": queue_item_id,
+                "queue_batch_id": queue_batch_id,
+                "graph_execution_state_id": graph_execution_state_id,
+                "node": node,
+                "source_node_id": source_node_id,
+            },
         )
 
     def emit_graph_execution_complete(
@@ -140,12 +140,12 @@ class EventServiceBase:
         """Emitted when a session has completed all invocations"""
         self.__emit_queue_event(
             event_name="graph_execution_state_complete",
-            payload=dict(
-                queue_id=queue_id,
-                queue_item_id=queue_item_id,
-                queue_batch_id=queue_batch_id,
-                graph_execution_state_id=graph_execution_state_id,
-            ),
+            payload={
+                "queue_id": queue_id,
+                "queue_item_id": queue_item_id,
+                "queue_batch_id": queue_batch_id,
+                "graph_execution_state_id": graph_execution_state_id,
+            },
         )
 
     def emit_model_load_started(
@@ -162,16 +162,16 @@ class EventServiceBase:
         """Emitted when a model is requested"""
         self.__emit_queue_event(
             event_name="model_load_started",
-            payload=dict(
-                queue_id=queue_id,
-                queue_item_id=queue_item_id,
-                queue_batch_id=queue_batch_id,
-                graph_execution_state_id=graph_execution_state_id,
-                model_name=model_name,
-                base_model=base_model,
-                model_type=model_type,
-                submodel=submodel,
-            ),
+            payload={
+                "queue_id": queue_id,
+                "queue_item_id": queue_item_id,
+                "queue_batch_id": queue_batch_id,
+                "graph_execution_state_id": graph_execution_state_id,
+                "model_name": model_name,
+                "base_model": base_model,
+                "model_type": model_type,
+                "submodel": submodel,
+            },
         )
 
     def emit_model_load_completed(
@@ -189,19 +189,19 @@ class EventServiceBase:
         """Emitted when a model is correctly loaded (returns model info)"""
         self.__emit_queue_event(
             event_name="model_load_completed",
-            payload=dict(
-                queue_id=queue_id,
-                queue_item_id=queue_item_id,
-                queue_batch_id=queue_batch_id,
-                graph_execution_state_id=graph_execution_state_id,
-                model_name=model_name,
-                base_model=base_model,
-                model_type=model_type,
-                submodel=submodel,
-                hash=model_info.hash,
-                location=str(model_info.location),
-                precision=str(model_info.precision),
-            ),
+            payload={
+                "queue_id": queue_id,
+                "queue_item_id": queue_item_id,
+                "queue_batch_id": queue_batch_id,
+                "graph_execution_state_id": graph_execution_state_id,
+                "model_name": model_name,
+                "base_model": base_model,
+                "model_type": model_type,
+                "submodel": submodel,
+                "hash": model_info.hash,
+                "location": str(model_info.location),
+                "precision": str(model_info.precision),
+            },
         )
 
     def emit_session_retrieval_error(
@@ -216,14 +216,14 @@ class EventServiceBase:
         """Emitted when session retrieval fails"""
         self.__emit_queue_event(
             event_name="session_retrieval_error",
-            payload=dict(
-                queue_id=queue_id,
-                queue_item_id=queue_item_id,
-                queue_batch_id=queue_batch_id,
-                graph_execution_state_id=graph_execution_state_id,
-                error_type=error_type,
-                error=error,
-            ),
+            payload={
+                "queue_id": queue_id,
+                "queue_item_id": queue_item_id,
+                "queue_batch_id": queue_batch_id,
+                "graph_execution_state_id": graph_execution_state_id,
+                "error_type": error_type,
+                "error": error,
+            },
         )
 
     def emit_invocation_retrieval_error(
@@ -239,15 +239,15 @@ class EventServiceBase:
         """Emitted when invocation retrieval fails"""
         self.__emit_queue_event(
             event_name="invocation_retrieval_error",
-            payload=dict(
-                queue_id=queue_id,
-                queue_item_id=queue_item_id,
-                queue_batch_id=queue_batch_id,
-                graph_execution_state_id=graph_execution_state_id,
-                node_id=node_id,
-                error_type=error_type,
-                error=error,
-            ),
+            payload={
+                "queue_id": queue_id,
+                "queue_item_id": queue_item_id,
+                "queue_batch_id": queue_batch_id,
+                "graph_execution_state_id": graph_execution_state_id,
+                "node_id": node_id,
+                "error_type": error_type,
+                "error": error,
+            },
         )
 
     def emit_session_canceled(
@@ -260,12 +260,12 @@ class EventServiceBase:
         """Emitted when a session is canceled"""
         self.__emit_queue_event(
             event_name="session_canceled",
-            payload=dict(
-                queue_id=queue_id,
-                queue_item_id=queue_item_id,
-                queue_batch_id=queue_batch_id,
-                graph_execution_state_id=graph_execution_state_id,
-            ),
+            payload={
+                "queue_id": queue_id,
+                "queue_item_id": queue_item_id,
+                "queue_batch_id": queue_batch_id,
+                "graph_execution_state_id": graph_execution_state_id,
+            },
         )
 
     def emit_queue_item_status_changed(
@@ -277,39 +277,39 @@ class EventServiceBase:
         """Emitted when a queue item's status changes"""
         self.__emit_queue_event(
             event_name="queue_item_status_changed",
-            payload=dict(
-                queue_id=queue_status.queue_id,
-                queue_item=dict(
-                    queue_id=session_queue_item.queue_id,
-                    item_id=session_queue_item.item_id,
-                    status=session_queue_item.status,
-                    batch_id=session_queue_item.batch_id,
-                    session_id=session_queue_item.session_id,
-                    error=session_queue_item.error,
-                    created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
-                    updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
-                    started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
-                    completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
-                ),
-                batch_status=batch_status.model_dump(),
-                queue_status=queue_status.model_dump(),
-            ),
+            payload={
+                "queue_id": queue_status.queue_id,
+                "queue_item": {
+                    "queue_id": session_queue_item.queue_id,
+                    "item_id": session_queue_item.item_id,
+                    "status": session_queue_item.status,
+                    "batch_id": session_queue_item.batch_id,
+                    "session_id": session_queue_item.session_id,
+                    "error": session_queue_item.error,
+                    "created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
+                    "updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
+                    "started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
+                    "completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
+                },
+                "batch_status": batch_status.model_dump(),
+                "queue_status": queue_status.model_dump(),
+            },
         )
 
     def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
         """Emitted when a batch is enqueued"""
         self.__emit_queue_event(
             event_name="batch_enqueued",
-            payload=dict(
-                queue_id=enqueue_result.queue_id,
-                batch_id=enqueue_result.batch.batch_id,
-                enqueued=enqueue_result.enqueued,
-            ),
+            payload={
+                "queue_id": enqueue_result.queue_id,
+                "batch_id": enqueue_result.batch.batch_id,
+                "enqueued": enqueue_result.enqueued,
+            },
         )
 
     def emit_queue_cleared(self, queue_id: str) -> None:
         """Emitted when the queue is cleared"""
         self.__emit_queue_event(
             event_name="queue_cleared",
-            payload=dict(queue_id=queue_id),
+            payload={"queue_id": queue_id},
         )
@@ -25,7 +25,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
     __invoker: Invoker
 
     def __init__(self, output_folder: Union[str, Path]):
-        self.__cache = dict()
+        self.__cache = {}
         self.__cache_ids = Queue()
         self.__max_cache_size = 10  # TODO: get this from config
 
@@ -90,10 +90,9 @@ class ImageRecordDeleteException(Exception):
 
 
 IMAGE_DTO_COLS = ", ".join(
-    list(
-        map(
-            lambda c: "images." + c,
-            [
+    [
+        "images." + c
+        for c in [
             "image_name",
             "image_origin",
             "image_category",
@@ -106,9 +105,8 @@ IMAGE_DTO_COLS = ", ".join(
             "updated_at",
             "deleted_at",
             "starred",
-            ],
-        )
-    )
+        ]
+    ]
 )
 
 
|
@ -263,7 +263,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
|
|
||||||
if categories is not None:
|
if categories is not None:
|
||||||
# Convert the enum values to unique list of strings
|
# Convert the enum values to unique list of strings
|
||||||
category_strings = list(map(lambda c: c.value, set(categories)))
|
category_strings = [c.value for c in set(categories)]
|
||||||
# Create the correct length of placeholders
|
# Create the correct length of placeholders
|
||||||
placeholders = ",".join("?" * len(category_strings))
|
placeholders = ",".join("?" * len(category_strings))
|
||||||
|
|
||||||
@ -307,7 +307,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
# Build the list of images, deserializing each row
|
# Build the list of images, deserializing each row
|
||||||
self._cursor.execute(images_query, images_params)
|
self._cursor.execute(images_query, images_params)
|
||||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
images = [deserialize_image_record(dict(r)) for r in result]
|
||||||
|
|
||||||
# Set up and execute the count query, without pagination
|
# Set up and execute the count query, without pagination
|
||||||
count_query += query_conditions + ";"
|
count_query += query_conditions + ";"
|
||||||
@ -386,7 +386,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
image_names = list(map(lambda r: r[0], result))
|
image_names = [r[0] for r in result]
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
DELETE FROM images
|
DELETE FROM images
|
||||||
|
@@ -21,8 +21,8 @@ class ImageServiceABC(ABC):
     _on_deleted_callbacks: list[Callable[[str], None]]
 
     def __init__(self) -> None:
-        self._on_changed_callbacks = list()
-        self._on_deleted_callbacks = list()
+        self._on_changed_callbacks = []
+        self._on_deleted_callbacks = []
 
     def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
         """Register a callback for when an image is changed"""
@@ -217,18 +217,16 @@ class ImageService(ImageServiceABC):
                 board_id,
             )
 
-            image_dtos = list(
-                map(
-                    lambda r: image_record_to_dto(
-                        image_record=r,
-                        image_url=self.__invoker.services.urls.get_image_url(r.image_name),
-                        thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
-                        board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
-                        workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
-                    ),
-                    results.items,
-                )
-            )
+            image_dtos = [
+                image_record_to_dto(
+                    image_record=r,
+                    image_url=self.__invoker.services.urls.get_image_url(r.image_name),
+                    thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
+                    board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
+                    workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
+                )
+                for r in results.items
+            ]
 
             return OffsetPaginatedResults[ImageDTO](
                 items=image_dtos,
@ -1,5 +1,5 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
|
||||||
|
|
||||||
class InvocationProcessorABC(ABC):
|
class InvocationProcessorABC(ABC): # noqa: B024
|
||||||
pass
|
pass
|
||||||
|
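Note: the `# noqa: B024` above suppresses ruff rule B024, which flags an
abstract base class that declares no abstract methods; the class is kept
as a plain marker type rather than restructured. A sketch of what B024
objects to and the shape it expects, using hypothetical names:

    from abc import ABC, abstractmethod

    class Flagged(ABC):  # B024: inherits ABC but defines no abstract methods
        pass

    class Accepted(ABC):
        @abstractmethod
        def process(self) -> None: ...  # at least one abstract method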
@@ -26,7 +26,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
         self.__invoker_thread = Thread(
             name="invoker_processor",
             target=self.__process,
-            kwargs=dict(stop_event=self.__stop_event),
+            kwargs={"stop_event": self.__stop_event},
         )
         self.__invoker_thread.daemon = True  # TODO: make async and do not use threads
         self.__invoker_thread.start()
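Note: `kwargs=dict(...)` becomes a dict literal here (ruff rule C408). The
literal skips the `dict` name lookup and call, and makes the keys explicit
strings. A minimal sketch:

    before = dict(stop_event=None)    # built through the dict() constructor
    after = {"stop_event": None}      # direct literal, same mapping
    assert before == after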
@@ -14,7 +14,7 @@ class MemoryInvocationQueue(InvocationQueueABC):

     def __init__(self):
         self.__queue = Queue()
-        self.__cancellations = dict()
+        self.__cancellations = {}

     def get(self) -> InvocationQueueItem:
         item = self.__queue.get()
@@ -122,7 +122,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
     def log_stats(self):
         completed = set()
         errored = set()
-        for graph_id, node_log in self._stats.items():
+        for graph_id, _node_log in self._stats.items():
             try:
                 current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
             except Exception:
@@ -142,7 +142,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
             cache_stats = self._cache_stats[graph_id]
             hwm = cache_stats.high_watermark / GIG
             tot = cache_stats.cache_size / GIG
-            loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
+            loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GIG

             logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
             logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
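Note: renaming the loop variable to `_node_log` above addresses ruff rule
B007 (unused loop control variable); the leading underscore marks it as
intentionally ignored. A sketch with a hypothetical stats mapping:

    stats = {"graph-1": ["node-a"], "graph-2": ["node-b"]}
    for graph_id, _node_log in stats.items():
        print(graph_id)  # only the key is used; _node_log is discarded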
@@ -15,8 +15,8 @@ class ItemStorageABC(ABC, Generic[T]):
     _on_deleted_callbacks: list[Callable[[str], None]]

     def __init__(self) -> None:
-        self._on_changed_callbacks = list()
-        self._on_deleted_callbacks = list()
+        self._on_changed_callbacks = []
+        self._on_deleted_callbacks = []

     """Base item storage class"""

@@ -112,7 +112,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
         )
         result = self._cursor.fetchall()

-        items = list(map(lambda r: self._parse_item(r[0]), result))
+        items = [self._parse_item(r[0]) for r in result]

         self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
         count = self._cursor.fetchone()[0]
@@ -132,7 +132,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
         )
         result = self._cursor.fetchall()

-        items = list(map(lambda r: self._parse_item(r[0]), result))
+        items = [self._parse_item(r[0]) for r in result]

         self._cursor.execute(
             f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
@@ -13,8 +13,8 @@ class LatentsStorageBase(ABC):
     _on_deleted_callbacks: list[Callable[[str], None]]

     def __init__(self) -> None:
-        self._on_changed_callbacks = list()
-        self._on_deleted_callbacks = list()
+        self._on_changed_callbacks = []
+        self._on_deleted_callbacks = []

     @abstractmethod
     def get(self, name: str) -> torch.Tensor:
@@ -19,7 +19,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
     def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
         super().__init__()
         self.__underlying_storage = underlying_storage
-        self.__cache = dict()
+        self.__cache = {}
         self.__cache_ids = Queue()
         self.__max_cache_size = max_cache_size

@@ -33,9 +33,11 @@ class DefaultSessionProcessor(SessionProcessorBase):
         self.__thread = Thread(
             name="session_processor",
             target=self.__process,
-            kwargs=dict(
-                stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event
-            ),
+            kwargs={
+                "stop_event": self.__stop_event,
+                "poll_now_event": self.__poll_now_event,
+                "resume_event": self.__resume_event,
+            },
         )
         self.__thread.start()

@@ -129,12 +129,12 @@ class Batch(BaseModel):
         return v

     model_config = ConfigDict(
-        json_schema_extra=dict(
-            required=[
+        json_schema_extra={
+            "required": [
                 "graph",
                 "runs",
             ]
-        )
+        }
     )


@@ -191,8 +191,8 @@ class SessionQueueItemWithoutGraph(BaseModel):
         return SessionQueueItemDTO(**queue_item_dict)

     model_config = ConfigDict(
-        json_schema_extra=dict(
-            required=[
+        json_schema_extra={
+            "required": [
                 "item_id",
                 "status",
                 "batch_id",
@@ -203,7 +203,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
                 "created_at",
                 "updated_at",
             ]
-        )
+        }
     )


@@ -222,8 +222,8 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
         return SessionQueueItem(**queue_item_dict)

     model_config = ConfigDict(
-        json_schema_extra=dict(
-            required=[
+        json_schema_extra={
+            "required": [
                 "item_id",
                 "status",
                 "batch_id",
@@ -235,7 +235,7 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
                 "created_at",
                 "updated_at",
             ]
-        )
+        }
     )

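Note: the ConfigDict hunks above swap `json_schema_extra=dict(required=[...])`
for a plain dict with string keys. Pydantic v2 merges that mapping into the
generated JSON schema, so the literal form shows the emitted keys verbatim.
A minimal sketch, assuming pydantic v2 and a hypothetical model:

    from pydantic import BaseModel, ConfigDict

    class Item(BaseModel):
        name: str = ""
        model_config = ConfigDict(json_schema_extra={"required": ["name"]})

    # the extra mapping is merged into the schema as-is
    assert Item.model_json_schema()["required"] == ["name"]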
@@ -355,7 +355,7 @@ def create_session_nfv_tuples(
                 for item in batch_datum.items
             ]
             node_field_values_to_zip.append(node_field_values)
-        data.append(list(zip(*node_field_values_to_zip)))  # type: ignore [arg-type]
+        data.append(list(zip(*node_field_values_to_zip, strict=True)))  # type: ignore [arg-type]

     # create generator to yield session,nfv tuples
     count = 0
@@ -383,7 +383,7 @@ def calc_session_count(batch: Batch) -> int:
         for batch_datum in batch_datum_list:
             batch_data_items = range(len(batch_datum.items))
             to_zip.append(batch_data_items)
-        data.append(list(zip(*to_zip)))
+        data.append(list(zip(*to_zip, strict=True)))
     data_product = list(product(*data))
     return len(data_product) * batch.runs

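Note: the `strict=True` added to `zip()` in both hunks makes mismatched
argument lengths raise ValueError instead of silently truncating to the
shortest iterable; it requires Python 3.10 or newer. Sketch:

    pairs = list(zip([1, 2], ["a", "b"], strict=True))  # equal lengths: fine
    try:
        list(zip([1, 2, 3], ["a", "b"], strict=True))
    except ValueError:
        pass  # a length mismatch now fails loudly rather than dropping items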
@@ -78,7 +78,7 @@ def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[Li
     """Creates the default system graphs, or adds new versions if the old ones don't match"""

     # TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
-    graphs: list[LibraryGraph] = list()
+    graphs: list[LibraryGraph] = []

     text_to_image = graph_library.get(default_text_to_image_graph_id)

@@ -352,7 +352,7 @@ class Graph(BaseModel):

         # Validate that all node ids are unique
         node_ids = [n.id for n in self.nodes.values()]
-        duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
+        duplicate_node_ids = {node_id for node_id in node_ids if node_ids.count(node_id) >= 2}
         if duplicate_node_ids:
             raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")

@@ -616,7 +616,7 @@ class Graph(BaseModel):
         self, node_path: str, prefix: Optional[str] = None
     ) -> list[tuple["Graph", Union[str, None], Edge]]:
         """Gets all input edges for a node along with the graph they are in and the graph's path"""
-        edges = list()
+        edges = []

         # Return any input edges that appear in this graph
         edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
@@ -658,7 +658,7 @@ class Graph(BaseModel):
         self, node_path: str, prefix: Optional[str] = None
     ) -> list[tuple["Graph", Union[str, None], Edge]]:
         """Gets all output edges for a node along with the graph they are in and the graph's path"""
-        edges = list()
+        edges = []

         # Return any input edges that appear in this graph
         edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
@@ -680,8 +680,8 @@ class Graph(BaseModel):
         new_input: Optional[EdgeConnection] = None,
         new_output: Optional[EdgeConnection] = None,
     ) -> bool:
-        inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
-        outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
+        inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
+        outputs = [e.destination for e in self._get_output_edges(node_path, "item")]

         if new_input is not None:
             inputs.append(new_input)
@@ -694,7 +694,7 @@ class Graph(BaseModel):

         # Get input and output fields (the fields linked to the iterator's input/output)
         input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
-        output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
+        output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]

         # Input type must be a list
         if get_origin(input_field) != list:
@@ -713,8 +713,8 @@ class Graph(BaseModel):
         new_input: Optional[EdgeConnection] = None,
         new_output: Optional[EdgeConnection] = None,
     ) -> bool:
-        inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
-        outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
+        inputs = [e.source for e in self._get_input_edges(node_path, "item")]
+        outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]

         if new_input is not None:
             inputs.append(new_input)
@@ -722,18 +722,16 @@ class Graph(BaseModel):
             outputs.append(new_output)

         # Get input and output fields (the fields linked to the iterator's input/output)
-        input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
-        output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
+        input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
+        output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]

         # Validate that all inputs are derived from or match a single type
-        input_field_types = set(
-            [
+        input_field_types = {
             t
             for input_field in input_fields
             for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
             if t != NoneType
-            ]
-        )  # Get unique types
+        }  # Get unique types
         type_tree = nx.DiGraph()
         type_tree.add_nodes_from(input_field_types)
         type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
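Note: several hunks above replace `set([... for ...])` with a set
comprehension `{... for ...}` (ruff rule C403), dropping the throwaway
intermediate list. Sketch with hypothetical node ids:

    node_ids = ["a", "b", "a"]
    before = set([n for n in node_ids if node_ids.count(n) >= 2])
    after = {n for n in node_ids if node_ids.count(n) >= 2}
    assert before == after == {"a"}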
@@ -761,15 +759,15 @@ class Graph(BaseModel):
         """Returns a NetworkX DiGraph representing the layout of this graph"""
         # TODO: Cache this?
         g = nx.DiGraph()
-        g.add_nodes_from([n for n in self.nodes.keys()])
-        g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
+        g.add_nodes_from(list(self.nodes.keys()))
+        g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
         return g

     def nx_graph_with_data(self) -> nx.DiGraph:
         """Returns a NetworkX DiGraph representing the data and layout of this graph"""
         g = nx.DiGraph()
-        g.add_nodes_from([n for n in self.nodes.items()])
-        g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
+        g.add_nodes_from(list(self.nodes.items()))
+        g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
         return g

     def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
@@ -791,7 +789,7 @@ class Graph(BaseModel):

         # TODO: figure out if iteration nodes need to be expanded

-        unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
+        unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
         g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
         return g

@@ -843,8 +841,8 @@ class GraphExecutionState(BaseModel):
         return v

     model_config = ConfigDict(
-        json_schema_extra=dict(
-            required=[
+        json_schema_extra={
+            "required": [
                 "id",
                 "graph",
                 "execution_graph",
@@ -855,7 +853,7 @@ class GraphExecutionState(BaseModel):
                 "prepared_source_mapping",
                 "source_prepared_mapping",
             ]
-        )
+        }
     )

     def next(self) -> Optional[BaseInvocation]:
@@ -895,7 +893,7 @@ class GraphExecutionState(BaseModel):
             source_node = self.prepared_source_mapping[node_id]
             prepared_nodes = self.source_prepared_mapping[source_node]

-            if all([n in self.executed for n in prepared_nodes]):
+            if all(n in self.executed for n in prepared_nodes):
                 self.executed.add(source_node)
                 self.executed_history.append(source_node)

@@ -930,7 +928,7 @@ class GraphExecutionState(BaseModel):
         input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
         self_iteration_count = len(input_collection)

-        new_nodes: list[str] = list()
+        new_nodes: list[str] = []
         if self_iteration_count == 0:
             # TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
             return new_nodes
@@ -940,7 +938,7 @@ class GraphExecutionState(BaseModel):

         # Create new edges for this iteration
         # For collect nodes, this may contain multiple inputs to the same field
-        new_edges: list[Edge] = list()
+        new_edges: list[Edge] = []
         for edge in input_edges:
             for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
                 new_edge = Edge(
@@ -1034,7 +1032,7 @@ class GraphExecutionState(BaseModel):

         # Create execution nodes
         next_node = self.graph.get_node(next_node_id)
-        new_node_ids = list()
+        new_node_ids = []
         if isinstance(next_node, CollectInvocation):
             # Collapse all iterator input mappings and create a single execution node for the collect invocation
             all_iteration_mappings = list(
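Note: `all([...])` becomes `all(...)` over a bare generator above (ruff
rule C419). The generator short-circuits on the first falsy element
instead of materializing the whole list first. Sketch:

    executed = {"n1", "n2"}
    prepared = ["n1", "n3", "n2"]
    assert all(n in executed for n in prepared) is False  # stops at "n3"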
@@ -1055,7 +1053,10 @@ class GraphExecutionState(BaseModel):
         # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
         # TODO: Handle a node mapping to none
         eg = self.execution_graph.nx_graph_flat()
-        prepared_parent_mappings = [[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations]  # type: ignore
+        prepared_parent_mappings = [
+            [(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
+            for it in iterator_node_prepared_combinations
+        ]  # type: ignore

         # Create execution node for each iteration
         for iteration_mappings in prepared_parent_mappings:
@@ -1121,7 +1122,7 @@ class GraphExecutionState(BaseModel):
                 for edge in input_edges
                 if edge.destination.field == "item"
             ]
-            setattr(node, "collection", output_collection)
+            node.collection = output_collection
         else:
             for edge in input_edges:
                 output_value = getattr(self.results[edge.source.node_id], edge.source.field)
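Note: `setattr(node, "collection", ...)` becomes a direct assignment above
(ruff rule B010): setattr with a constant attribute name is just an indirect
spelling of `obj.attr = value`. Sketch with a stand-in class:

    class Node:
        def __init__(self):
            self.collection = []

    node = Node()
    setattr(node, "collection", [1, 2])  # before: flagged by B010
    node.collection = [1, 2]             # after: same effect, plainly spelled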
@@ -1201,7 +1202,7 @@ class LibraryGraph(BaseModel):

     @field_validator("exposed_inputs", "exposed_outputs")
     def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
-        if len(v) != len(set(i.alias for i in v)):
+        if len(v) != len({i.alias for i in v}):
             raise ValueError("Duplicate exposed alias")
         return v

@@ -59,7 +59,7 @@ def thin_one_time(x, kernels):

 def lvmin_thin(x, prunings=True):
     y = x
-    for i in range(32):
+    for _i in range(32):
         y, is_done = thin_one_time(y, lvmin_kernels)
         if is_done:
             break
@@ -21,11 +21,11 @@ def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:

     # sanity check make sure the graph is at least reasonably shaped
     if (
-        type(graph) is not dict
+        not isinstance(graph, dict)
         or "nodes" not in graph
-        or type(graph["nodes"]) is not dict
+        or not isinstance(graph["nodes"], dict)
         or "edges" not in graph
-        or type(graph["edges"]) is not list
+        or not isinstance(graph["edges"], list)
     ):
         # something has gone terribly awry, return an empty dict
         return None
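Note: the `type(x) is not dict` checks become `not isinstance(x, dict)` above
(ruff rule E721). The practical difference is that isinstance also accepts
subclasses, which is the safer check for loosely shaped JSON data. Sketch:

    from collections import OrderedDict

    graph = OrderedDict(nodes={}, edges=[])
    assert type(graph) is not dict   # exact-type check rejects the subclass
    assert isinstance(graph, dict)   # isinstance accepts it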
@@ -88,7 +88,7 @@ class PromptFormatter:
         t2i = self.t2i
         opt = self.opt

-        switches = list()
+        switches = []
         switches.append(f'"{opt.prompt}"')
         switches.append(f"-s{opt.steps or t2i.steps}")
         switches.append(f"-W{opt.width or t2i.width}")
@@ -88,7 +88,7 @@ class Txt2Mask(object):
         provided image and returns a SegmentedGrayscale object in which the brighter
         pixels indicate where the object is inferred to be.
         """
-        if type(image) is str:
+        if isinstance(image, str):
             image = Image.open(image).convert("RGB")

         image = ImageOps.exif_transpose(image)
@@ -40,7 +40,7 @@ class InitImageResizer:
         (rw, rh) = (int(scale * im.width), int(scale * im.height))

         # round everything to multiples of 64
-        width, height, rw, rh = map(lambda x: x - x % 64, (width, height, rw, rh))
+        width, height, rw, rh = (x - x % 64 for x in (width, height, rw, rh))

         # no resize necessary, but return a copy
         if im.width == width and im.height == height:
@@ -197,7 +197,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th

 def download_conversion_models():
     target_dir = config.models_path / "core/convert"
-    kwargs = dict()  # for future use
+    kwargs = {}  # for future use
     try:
         logger.info("Downloading core tokenizers and text encoders")

@@ -252,26 +252,26 @@ def download_conversion_models():
 def download_realesrgan():
     logger.info("Installing ESRGAN Upscaling models...")
     URLs = [
-        dict(
-            url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
-            dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
-            description="RealESRGAN_x4plus.pth",
-        ),
-        dict(
-            url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
-            dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
-            description="RealESRGAN_x4plus_anime_6B.pth",
-        ),
-        dict(
-            url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
-            dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
-            description="ESRGAN_SRx4_DF2KOST_official.pth",
-        ),
-        dict(
-            url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
-            dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
-            description="RealESRGAN_x2plus.pth",
-        ),
+        {
+            "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+            "dest": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
+            "description": "RealESRGAN_x4plus.pth",
+        },
+        {
+            "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+            "dest": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
+            "description": "RealESRGAN_x4plus_anime_6B.pth",
+        },
+        {
+            "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+            "dest": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+            "description": "ESRGAN_SRx4_DF2KOST_official.pth",
+        },
+        {
+            "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+            "dest": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
+            "description": "RealESRGAN_x2plus.pth",
+        },
     ]
     for model in URLs:
         download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
@@ -680,7 +680,7 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
         if program_opts.default_only
         else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
         if program_opts.yes_to_all
-        else list(),
+        else [],
     )


@@ -123,8 +123,6 @@ class MigrateTo3(object):
                 logger.error(str(e))
             except KeyboardInterrupt:
                 raise
-            except Exception as e:
-                logger.error(str(e))
         for f in files:
             # don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
             # let them be copied as part of a tree copy operation
@@ -143,8 +141,6 @@ class MigrateTo3(object):
                 logger.error(str(e))
             except KeyboardInterrupt:
                 raise
-            except Exception as e:
-                logger.error(str(e))

     def migrate_support_models(self):
         """
@@ -182,10 +178,10 @@ class MigrateTo3(object):
         """

         dest_directory = self.dest_models
-        kwargs = dict(
-            cache_dir=self.root_directory / "models/hub",
+        kwargs = {
+            "cache_dir": self.root_directory / "models/hub",
             # local_files_only = True
-        )
+        }
         try:
             logger.info("Migrating core tokenizers and text encoders")
             target_dir = dest_directory / "core" / "convert"
@@ -316,11 +312,11 @@ class MigrateTo3(object):
         dest_dir = self.dest_models

         cache = self.root_directory / "models/hub"
-        kwargs = dict(
-            cache_dir=cache,
-            safety_checker=None,
+        kwargs = {
+            "cache_dir": cache,
+            "safety_checker": None,
             # local_files_only = True,
-        )
+        }

         owner, repo_name = repo_id.split("/")
         model_name = model_name or repo_name
@@ -120,7 +120,7 @@ class ModelInstall(object):
         be treated uniformly. It also sorts the models alphabetically
         by their name, to improve the display somewhat.
         """
-        model_dict = dict()
+        model_dict = {}

         # first populate with the entries in INITIAL_MODELS.yaml
         for key, value in self.datasets.items():
@@ -134,7 +134,7 @@ class ModelInstall(object):
             model_dict[key] = model_info

         # supplement with entries in models.yaml
-        installed_models = [x for x in self.mgr.list_models()]
+        installed_models = list(self.mgr.list_models())

         for md in installed_models:
             base = md["base_model"]
@@ -176,7 +176,7 @@ class ModelInstall(object):
     # logic here a little reversed to maintain backward compatibility
     def starter_models(self, all_models: bool = False) -> Set[str]:
         models = set()
-        for key, value in self.datasets.items():
+        for key, _value in self.datasets.items():
             name, base, model_type = ModelManager.parse_key(key)
             if all_models or model_type in [ModelType.Main, ModelType.Vae]:
                 models.add(key)
@@ -184,7 +184,7 @@ class ModelInstall(object):

     def recommended_models(self) -> Set[str]:
         starters = self.starter_models(all_models=True)
-        return set([x for x in starters if self.datasets[x].get("recommended", False)])
+        return {x for x in starters if self.datasets[x].get("recommended", False)}

     def default_model(self) -> str:
         starters = self.starter_models()
@@ -234,7 +234,7 @@ class ModelInstall(object):
         """

         if not models_installed:
-            models_installed = dict()
+            models_installed = {}

         model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")

@@ -252,7 +252,6 @@ class ModelInstall(object):

             # folders style or similar
             elif path.is_dir() and any(
-                [
                 (path / x).exists()
                 for x in {
                     "config.json",
@@ -261,7 +260,6 @@ class ModelInstall(object):
                     "pytorch_lora_weights.bin",
                     "pytorch_lora_weights.safetensors",
                 }
-                ]
             ):
                 models_installed.update({str(model_path_id_or_url): self._install_path(path)})

@@ -433,17 +431,17 @@ class ModelInstall(object):

         rel_path = self.relative_to_root(path, self.config.models_path)

-        attributes = dict(
-            path=str(rel_path),
-            description=str(description),
-            model_format=info.format,
-        )
+        attributes = {
+            "path": str(rel_path),
+            "description": str(description),
+            "model_format": info.format,
+        }
         legacy_conf = None
         if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
             attributes.update(
-                dict(
-                    variant=info.variant_type,
-                )
+                {
+                    "variant": info.variant_type,
+                }
             )
         if info.format == "checkpoint":
             try:
@@ -474,7 +472,7 @@ class ModelInstall(object):
             )

         if legacy_conf:
-            attributes.update(dict(config=str(legacy_conf)))
+            attributes.update({"config": str(legacy_conf)})
         return attributes

     def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
@@ -519,7 +517,7 @@ class ModelInstall(object):
     def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
         _, name = repo_id.split("/")
         location = staging / name
-        paths = list()
+        paths = []
         for filename in files:
             filePath = Path(filename)
             p = hf_download_with_resume(
@@ -130,7 +130,9 @@ class IPAttnProcessor2_0(torch.nn.Module):
             assert ip_adapter_image_prompt_embeds is not None
             assert len(ip_adapter_image_prompt_embeds) == len(self._weights)

-            for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
+            for ipa_embed, ipa_weights, scale in zip(
+                ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
+            ):
                 # The batch dimensions should match.
                 assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
                 # The token_len dimensions should match.
@@ -56,7 +56,7 @@ class PerceiverAttention(nn.Module):
         x = self.norm1(x)
         latents = self.norm2(latents)

-        b, l, _ = latents.shape
+        b, L, _ = latents.shape

         q = self.to_q(latents)
         kv_input = torch.cat((x, latents), dim=-2)
@@ -72,7 +72,7 @@ class PerceiverAttention(nn.Module):
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         out = weight @ v

-        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+        out = out.permute(0, 2, 1, 3).reshape(b, L, -1)

         return self.to_out(out)

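Note: renaming `l` to `L` in PerceiverAttention sidesteps ruff rule E741
(ambiguous single-character names `l`, `O`, `I`, which are easy to misread).
The value is the latent sequence length. A sketch of the unpacking, assuming
torch is available:

    import torch

    latents = torch.zeros(2, 16, 64)
    b, L, _ = latents.shape  # batch, sequence length, channel dim
    assert (b, L) == (2, 16)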
@@ -269,7 +269,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
         resolution *= 2

     up_block_types = []
-    for i in range(len(block_out_channels)):
+    for _i in range(len(block_out_channels)):
         block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
         up_block_types.append(block_type)
         resolution //= 2
@@ -1223,7 +1223,7 @@ def download_from_original_stable_diffusion_ckpt(
     # scan model
     scan_result = scan_file_path(checkpoint_path)
     if scan_result.infected_files != 0:
-        raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
+        raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
     checkpoint = torch.load(checkpoint_path, map_location=device)
@@ -1664,7 +1664,7 @@ def download_controlnet_from_original_ckpt(
     # scan model
     scan_result = scan_file_path(checkpoint_path)
     if scan_result.infected_files != 0:
-        raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
+        raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
     checkpoint = torch.load(checkpoint_path, map_location=device)
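Note: both hunks above wrap a bare string in `Exception(...)`. Raising a
string has been a TypeError since Python 3, so the old code could never
surface its malware warning; the new string still lacks an `f` prefix, so
`{checkpoint_path}` is printed literally rather than interpolated. Sketch:

    try:
        raise "boom"  # noqa: B016
    except TypeError:
        pass  # only BaseException subclasses can be raised

    try:
        raise Exception("boom")
    except Exception as e:
        assert str(e) == "boom"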
@ -104,7 +104,7 @@ class ModelPatcher:
|
|||||||
loras: List[Tuple[LoRAModel, float]],
|
loras: List[Tuple[LoRAModel, float]],
|
||||||
prefix: str,
|
prefix: str,
|
||||||
):
|
):
|
||||||
original_weights = dict()
|
original_weights = {}
|
||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for lora, lora_weight in loras:
|
for lora, lora_weight in loras:
|
||||||
@ -242,7 +242,7 @@ class ModelPatcher:
|
|||||||
):
|
):
|
||||||
skipped_layers = []
|
skipped_layers = []
|
||||||
try:
|
try:
|
||||||
for i in range(clip_skip):
|
for _i in range(clip_skip):
|
||||||
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
|
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
|
||||||
|
|
||||||
yield
|
yield
|
||||||
@ -324,7 +324,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
tokenizer: CLIPTokenizer
|
tokenizer: CLIPTokenizer
|
||||||
|
|
||||||
def __init__(self, tokenizer: CLIPTokenizer):
|
def __init__(self, tokenizer: CLIPTokenizer):
|
||||||
self.pad_tokens = dict()
|
self.pad_tokens = {}
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
||||||
@ -385,10 +385,10 @@ class ONNXModelPatcher:
|
|||||||
if not isinstance(model, IAIOnnxRuntimeModel):
|
if not isinstance(model, IAIOnnxRuntimeModel):
|
||||||
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
||||||
|
|
||||||
orig_weights = dict()
|
orig_weights = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
blended_loras = dict()
|
blended_loras = {}
|
||||||
|
|
||||||
for lora, lora_weight in loras:
|
for lora, lora_weight in loras:
|
||||||
for layer_key, layer in lora.layers.items():
|
for layer_key, layer in lora.layers.items():
|
||||||
@ -404,7 +404,7 @@ class ONNXModelPatcher:
|
|||||||
else:
|
else:
|
||||||
blended_loras[layer_key] = layer_weight
|
blended_loras[layer_key] = layer_weight
|
||||||
|
|
||||||
node_names = dict()
|
node_names = {}
|
||||||
for node in model.nodes.values():
|
for node in model.nodes.values():
|
||||||
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
|
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
|
||||||
|
|
||||||
|
@ -66,11 +66,13 @@ class CacheStats(object):
|
|||||||
|
|
||||||
class ModelLocker(object):
|
class ModelLocker(object):
|
||||||
"Forward declaration"
|
"Forward declaration"
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ModelCache(object):
|
class ModelCache(object):
|
||||||
"Forward declaration"
|
"Forward declaration"
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -132,7 +134,7 @@ class ModelCache(object):
|
|||||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||||
behaviour.
|
behaviour.
|
||||||
"""
|
"""
|
||||||
self.model_infos: Dict[str, ModelBase] = dict()
|
self.model_infos: Dict[str, ModelBase] = {}
|
||||||
# allow lazy offloading only when vram cache enabled
|
# allow lazy offloading only when vram cache enabled
|
||||||
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||||
self.precision: torch.dtype = precision
|
self.precision: torch.dtype = precision
|
||||||
@ -147,8 +149,8 @@ class ModelCache(object):
|
|||||||
# used for stats collection
|
# used for stats collection
|
||||||
self.stats = None
|
self.stats = None
|
||||||
|
|
||||||
self._cached_models = dict()
|
self._cached_models = {}
|
||||||
self._cache_stack = list()
|
self._cache_stack = []
|
||||||
|
|
||||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||||
if self._log_memory_usage:
|
if self._log_memory_usage:
|
||||||
|
@ -26,5 +26,5 @@ def skip_torch_weight_init():
|
|||||||
|
|
||||||
yield None
|
yield None
|
||||||
finally:
|
finally:
|
||||||
for torch_module, saved_function in zip(torch_modules, saved_functions):
|
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
|
||||||
torch_module.reset_parameters = saved_function
|
torch_module.reset_parameters = saved_function
|
||||||
|
@ -363,7 +363,7 @@ class ModelManager(object):
|
|||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.models = dict()
|
self.models = {}
|
||||||
for model_key, model_config in config.items():
|
for model_key, model_config in config.items():
|
||||||
if model_key.startswith("_"):
|
if model_key.startswith("_"):
|
||||||
continue
|
continue
|
||||||
@ -374,7 +374,7 @@ class ModelManager(object):
|
|||||||
self.models[model_key] = model_class.create_config(**model_config)
|
self.models[model_key] = model_class.create_config(**model_config)
|
||||||
|
|
||||||
# check config version number and update on disk/RAM if necessary
|
# check config version number and update on disk/RAM if necessary
|
||||||
self.cache_keys = dict()
|
self.cache_keys = {}
|
||||||
|
|
||||||
# add controlnet, lora and textual_inversion models from disk
|
# add controlnet, lora and textual_inversion models from disk
|
||||||
self.scan_models_directory()
|
self.scan_models_directory()
|
||||||
@ -655,7 +655,7 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
# TODO: redo
|
# TODO: redo
|
||||||
for model_dict in self.list_models():
|
for model_dict in self.list_models():
|
||||||
for model_name, model_info in model_dict.items():
|
for _model_name, model_info in model_dict.items():
|
||||||
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
|
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
|
||||||
print(line)
|
print(line)
|
||||||
|
|
||||||
@ -902,7 +902,7 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
Write current configuration out to the indicated file.
|
Write current configuration out to the indicated file.
|
||||||
"""
|
"""
|
||||||
data_to_save = dict()
|
data_to_save = {}
|
||||||
data_to_save["__metadata__"] = self.config_meta.model_dump()
|
data_to_save["__metadata__"] = self.config_meta.model_dump()
|
||||||
|
|
||||||
for model_key, model_config in self.models.items():
|
for model_key, model_config in self.models.items():
|
||||||
@ -1034,7 +1034,7 @@ class ModelManager(object):
|
|||||||
self.ignore = ignore
|
self.ignore = ignore
|
||||||
|
|
||||||
def on_search_started(self):
|
def on_search_started(self):
|
||||||
self.new_models_found = dict()
|
self.new_models_found = {}
|
||||||
|
|
||||||
def on_model_found(self, model: Path):
|
def on_model_found(self, model: Path):
|
||||||
if model not in self.ignore:
|
if model not in self.ignore:
|
||||||
@ -1106,7 +1106,7 @@ class ModelManager(object):
|
|||||||
# avoid circular import here
|
# avoid circular import here
|
||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
|
|
||||||
successfully_installed = dict()
|
successfully_installed = {}
|
||||||
|
|
||||||
installer = ModelInstall(
|
installer = ModelInstall(
|
||||||
config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self
|
config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self
|
||||||
|
@ -92,7 +92,7 @@ class ModelMerger(object):
|
|||||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||||
"""
|
"""
|
||||||
model_paths = list()
|
model_paths = []
|
||||||
config = self.manager.app_config
|
config = self.manager.app_config
|
||||||
base_model = BaseModelType(base_model)
|
base_model = BaseModelType(base_model)
|
||||||
vae = None
|
vae = None
|
||||||
@ -124,13 +124,13 @@ class ModelMerger(object):
|
|||||||
dump_path = (dump_path / merged_model_name).as_posix()
|
dump_path = (dump_path / merged_model_name).as_posix()
|
||||||
|
|
||||||
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||||
attributes = dict(
|
attributes = {
|
||||||
path=dump_path,
|
"path": dump_path,
|
||||||
description=f"Merge of models {', '.join(model_names)}",
|
"description": f"Merge of models {', '.join(model_names)}",
|
||||||
model_format="diffusers",
|
"model_format": "diffusers",
|
||||||
variant=ModelVariantType.Normal.value,
|
"variant": ModelVariantType.Normal.value,
|
||||||
vae=vae,
|
"vae": vae,
|
||||||
)
|
}
|
||||||
return self.manager.add_model(
|
return self.manager.add_model(
|
||||||
merged_model_name,
|
merged_model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
|
@ -237,7 +237,7 @@ class ModelProbe(object):
|
|||||||
# scan model
|
# scan model
|
||||||
scan_result = scan_file_path(checkpoint)
|
scan_result = scan_file_path(checkpoint)
|
||||||
if scan_result.infected_files != 0:
|
if scan_result.infected_files != 0:
|
||||||
raise "The model {model_name} is potentially infected by malware. Aborting import."
|
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||||
|
|
||||||
|
|
||||||
# ##################################################3
|
# ##################################################3
|
||||||
|
@ -59,7 +59,7 @@ class ModelSearch(ABC):
|
|||||||
for root, dirs, files in os.walk(path, followlinks=True):
|
for root, dirs, files in os.walk(path, followlinks=True):
|
||||||
if str(Path(root).name).startswith("."):
|
if str(Path(root).name).startswith("."):
|
||||||
self._pruned_paths.add(root)
|
self._pruned_paths.add(root)
|
||||||
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self._items_scanned += len(dirs) + len(files)
|
self._items_scanned += len(dirs) + len(files)
|
||||||
@ -69,7 +69,6 @@ class ModelSearch(ABC):
|
|||||||
self._scanned_dirs.add(path)
|
self._scanned_dirs.add(path)
|
||||||
continue
|
continue
|
||||||
if any(
|
if any(
|
||||||
[
|
|
||||||
(path / x).exists()
|
(path / x).exists()
|
||||||
for x in {
|
for x in {
|
||||||
"config.json",
|
"config.json",
|
||||||
@ -78,7 +77,6 @@ class ModelSearch(ABC):
|
|||||||
"pytorch_lora_weights.bin",
|
"pytorch_lora_weights.bin",
|
||||||
"image_encoder.txt",
|
"image_encoder.txt",
|
||||||
}
|
}
|
||||||
]
|
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
self.on_model_found(path)
|
self.on_model_found(path)
|
||||||
|
@ -97,8 +97,8 @@ MODEL_CLASSES = {
|
|||||||
# },
|
# },
|
||||||
}
|
}
|
||||||
|
|
||||||
MODEL_CONFIGS = list()
|
MODEL_CONFIGS = []
|
||||||
OPENAPI_MODEL_CONFIGS = list()
|
OPENAPI_MODEL_CONFIGS = []
|
||||||
|
|
||||||
|
|
||||||
class OpenAPIModelInfoBase(BaseModel):
|
class OpenAPIModelInfoBase(BaseModel):
|
||||||
@ -109,7 +109,7 @@ class OpenAPIModelInfoBase(BaseModel):
|
|||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
for base_model, models in MODEL_CLASSES.items():
|
for _base_model, models in MODEL_CLASSES.items():
|
||||||
for model_type, model_class in models.items():
|
for model_type, model_class in models.items():
|
||||||
model_configs = set(model_class._get_configs().values())
|
model_configs = set(model_class._get_configs().values())
|
||||||
model_configs.discard(None)
|
model_configs.discard(None)
@@ -133,7 +133,7 @@ for base_model, models in MODEL_CLASSES.items():
 
 
 def get_model_config_enums():
-    enums = list()
+    enums = []
 
     for model_config in MODEL_CONFIGS:
         if hasattr(inspect, "get_annotations"):
@@ -153,7 +153,7 @@ class ModelBase(metaclass=ABCMeta):
 
         else:
             res_type = sys.modules["diffusers"]
-            res_type = getattr(res_type, "pipelines")
+            res_type = res_type.pipelines
 
         for subtype in subtypes:
             res_type = getattr(res_type, subtype)
@@ -164,7 +164,7 @@ class ModelBase(metaclass=ABCMeta):
         with suppress(Exception):
             return cls.__configs
 
-        configs = dict()
+        configs = {}
         for name in dir(cls):
             if name.startswith("__"):
                 continue
@@ -246,8 +246,8 @@ class DiffusersModel(ModelBase):
     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         super().__init__(model_path, base_model, model_type)
 
-        self.child_types: Dict[str, Type] = dict()
-        self.child_sizes: Dict[str, int] = dict()
+        self.child_types: Dict[str, Type] = {}
+        self.child_sizes: Dict[str, int] = {}
 
         try:
            config_data = DiffusionPipeline.load_config(self.model_path)
@@ -326,8 +326,8 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
     all_files = os.listdir(model_path)
     all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
 
-    fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
-    bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
+    fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
+    bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
     other_files = set(all_files) - fp16_files - bit8_files
 
     if variant is None:
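`set([...])` wrapped around a list comprehension (ruff C403) builds a throwaway list and then copies it into a set; a set comprehension produces the set directly in one pass. Sketch with made-up filenames:

```python
names = ["model.fp16.bin", "model.bin", "vae.fp16.bin"]

# Before: allocates an intermediate list, then copies it into a set.
fp16 = set([n for n in names if ".fp16." in n])

# After: one pass, no intermediate list, same result.
fp16 = {n for n in names if ".fp16." in n}

assert fp16 == {"model.fp16.bin", "vae.fp16.bin"}
```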
@@ -413,7 +413,7 @@ def _calc_onnx_model_by_data(model) -> int:
 
 
 def _fast_safetensors_reader(path: str):
-    checkpoint = dict()
+    checkpoint = {}
     device = torch.device("meta")
     with open(path, "rb") as f:
         definition_len = int.from_bytes(f.read(8), "little")
@@ -483,7 +483,7 @@ class IAIOnnxRuntimeModel:
     class _tensor_access:
         def __init__(self, model):
             self.model = model
-            self.indexes = dict()
+            self.indexes = {}
             for idx, obj in enumerate(self.model.proto.graph.initializer):
                 self.indexes[obj.name] = idx
 
@@ -524,7 +524,7 @@ class IAIOnnxRuntimeModel:
 
     class _access_helper:
         def __init__(self, raw_proto):
-            self.indexes = dict()
+            self.indexes = {}
             self.raw_proto = raw_proto
             for idx, obj in enumerate(raw_proto):
                 self.indexes[obj.name] = idx
@@ -549,7 +549,7 @@ class IAIOnnxRuntimeModel:
             return self.indexes.keys()
 
         def values(self):
-            return [obj for obj in self.raw_proto]
+            return list(self.raw_proto)
 
     def __init__(self, model_path: str, provider: Optional[str]):
         self.path = model_path
@@ -104,7 +104,7 @@ class ControlNetModel(ModelBase):
             return ControlNetModelFormat.Diffusers
 
         if os.path.isfile(path):
-            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]):
+            if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]):
                 return ControlNetModelFormat.Checkpoint
 
         raise InvalidModelException(f"Not a valid model: {path}")
@@ -73,7 +73,7 @@ class LoRAModel(ModelBase):
             return LoRAModelFormat.Diffusers
 
         if os.path.isfile(path):
-            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
+            if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
                 return LoRAModelFormat.LyCORIS
 
         raise InvalidModelException(f"Not a valid model: {path}")
@@ -462,7 +462,7 @@ class LoRAModelRaw:  # (torch.nn.Module):
         dtype: Optional[torch.dtype] = None,
     ):
         # TODO: try revert if exception?
-        for key, layer in self.layers.items():
+        for _key, layer in self.layers.items():
             layer.to(device=device, dtype=dtype)
 
     def calc_size(self) -> int:
@@ -499,7 +499,7 @@ class LoRAModelRaw:  # (torch.nn.Module):
         stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
         stability_unet_keys.sort()
 
-        new_state_dict = dict()
+        new_state_dict = {}
         for full_key, value in state_dict.items():
             if full_key.startswith("lora_unet_"):
                 search_key = full_key.replace("lora_unet_", "")
@@ -545,7 +545,7 @@ class LoRAModelRaw:  # (torch.nn.Module):
 
         model = cls(
             name=file_path.stem,  # TODO:
-            layers=dict(),
+            layers={},
         )
 
         if file_path.suffix == ".safetensors":
@@ -593,12 +593,12 @@ class LoRAModelRaw:  # (torch.nn.Module):
 
     @staticmethod
     def _group_state(state_dict: dict):
-        state_dict_groupped = dict()
+        state_dict_groupped = {}
 
         for key, value in state_dict.items():
             stem, leaf = key.split(".", 1)
             if stem not in state_dict_groupped:
-                state_dict_groupped[stem] = dict()
+                state_dict_groupped[stem] = {}
             state_dict_groupped[stem][leaf] = value
 
         return state_dict_groupped
@@ -110,7 +110,7 @@ class StableDiffusion1Model(DiffusersModel):
             return StableDiffusion1ModelFormat.Diffusers
 
         if os.path.isfile(model_path):
-            if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
+            if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
                 return StableDiffusion1ModelFormat.Checkpoint
 
         raise InvalidModelException(f"Not a valid model: {model_path}")
@@ -221,7 +221,7 @@ class StableDiffusion2Model(DiffusersModel):
             return StableDiffusion2ModelFormat.Diffusers
 
         if os.path.isfile(model_path):
-            if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
+            if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
                 return StableDiffusion2ModelFormat.Checkpoint
 
         raise InvalidModelException(f"Not a valid model: {model_path}")
@@ -71,7 +71,7 @@ class TextualInversionModel(ModelBase):
             return None  # diffusers-ti
 
         if os.path.isfile(path):
-            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
+            if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]):
                 return None
 
         raise InvalidModelException(f"Not a valid model: {path}")
@@ -89,7 +89,7 @@ class VaeModel(ModelBase):
             return VaeModelFormat.Diffusers
 
         if os.path.isfile(path):
-            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
+            if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
                 return VaeModelFormat.Checkpoint
 
         raise InvalidModelException(f"Not a valid model: {path}")
@@ -114,8 +114,10 @@ class ModelConfigBase(BaseModel):
     current_hash: Optional[str] = Field(
         description="current fasthash of model contents", default=None
     )  # if model is converted or otherwise modified, this will hold updated hash
-    description: Optional[str] = Field(None)
-    source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
+    description: Optional[str] = Field(default=None)
+    source: Optional[str] = Field(
+        description="Model download source (URL or repo_id)", default=None
+    )
 
     model_config = ConfigDict(
         use_enum_values=False,
@@ -249,12 +251,19 @@ class T2IConfig(ModelConfigBase):
     format: Literal[ModelFormat.Diffusers]
 
 
-_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
-_ControlNetConfig = Annotated[
-    Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator="format")
-]
-_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
-_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
+_ONNXConfig = Annotated[
+    Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")
+]
+_ControlNetConfig = Annotated[
+    Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
+    Field(discriminator="format"),
+]
+_VaeConfig = Annotated[
+    Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")
+]
+_MainModelConfig = Annotated[
+    Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")
+]
 
 
 AnyModelConfig = Union[
     _MainModelConfig,
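These `Annotated[Union[...], Field(discriminator=...)]` aliases are Pydantic's tagged-union pattern: the discriminator field tells Pydantic exactly which branch to validate against instead of trying each member in turn. A minimal, self-contained sketch; the model names here are illustrative, not from this codebase:

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter

class DiffusersConfig(BaseModel):
    format: Literal["diffusers"]
    path: str

class CheckpointConfig(BaseModel):
    format: Literal["checkpoint"]
    path: str

AnyConfig = Annotated[
    Union[DiffusersConfig, CheckpointConfig], Field(discriminator="format")
]

# The "format" value alone selects which branch validates the payload.
cfg = TypeAdapter(AnyConfig).validate_python({"format": "checkpoint", "path": "m.ckpt"})
assert isinstance(cfg, CheckpointConfig)
```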
@@ -49,7 +49,7 @@ class FastModelHash(object):
     def _hash_dir(cls, model_location: Union[str, Path]) -> str:
         components: Dict[str, str] = {}
 
-        for root, dirs, files in os.walk(model_location):
+        for root, _dirs, files in os.walk(model_location):
             for file in files:
                 # only tally tensor files because diffusers config files change slightly
                 # depending on how the model was downloaded/converted.
@@ -61,6 +61,6 @@ class FastModelHash(object):
 
         # hash all the model hashes together, using alphabetic file order
         md5 = hashlib.md5()
-        for path, fast_hash in sorted(components.items()):
+        for _path, fast_hash in sorted(components.items()):
             md5.update(fast_hash.encode("utf-8"))
         return md5.hexdigest()
@@ -7,9 +7,16 @@ from omegaconf import DictConfig, OmegaConf
 from pydantic import TypeAdapter
 
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceSQL
+from invokeai.app.services.model_records import (
+    DuplicateModelException,
+    ModelRecordServiceSQL,
+)
 from invokeai.app.services.shared.sqlite import SqliteDatabase
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
+from invokeai.backend.model_manager.config import (
+    AnyModelConfig,
+    BaseModelType,
+    ModelType,
+)
 from invokeai.backend.model_manager.hash import FastModelHash
 from invokeai.backend.util.logging import InvokeAILogger
 
@@ -193,6 +193,7 @@ class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
         attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
             after generation completes. Optional.
     """
 
     attention_map_saver: Optional[AttentionMapSaver]
 
 
@@ -54,13 +54,13 @@ class Context:
         self.clear_requests(cleanup=True)
 
     def register_cross_attention_modules(self, model):
-        for name, module in get_cross_attention_modules(model, CrossAttentionType.SELF):
+        for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
             if name in self.self_cross_attention_module_identifiers:
-                assert False, f"name {name} cannot appear more than once"
+                raise AssertionError(f"name {name} cannot appear more than once")
             self.self_cross_attention_module_identifiers.append(name)
-        for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
+        for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
             if name in self.tokens_cross_attention_module_identifiers:
-                assert False, f"name {name} cannot appear more than once"
+                raise AssertionError(f"name {name} cannot appear more than once")
             self.tokens_cross_attention_module_identifiers.append(name)
 
     def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
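`assert False, msg` → `raise AssertionError(msg)` is ruff's B011 fix: assertions are stripped entirely when Python runs with `-O`, so an `assert False` guard silently disappears in optimized mode, while an explicit raise always fires. A minimal sketch:

```python
def guard_assert(name: str, seen: set) -> None:
    assert name not in seen, f"{name} duplicated"  # vanishes under `python -O`

def guard_raise(name: str, seen: set) -> None:
    if name in seen:
        raise AssertionError(f"{name} duplicated")  # always enforced

# Under `python -O`, guard_assert passes duplicates silently;
# guard_raise still raises, which is the behavior the code above relies on.
```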
@@ -170,7 +170,7 @@ class Context:
         self.saved_cross_attention_maps = {}
 
     def offload_saved_attention_slices_to_cpu(self):
-        for key, map_dict in self.saved_cross_attention_maps.items():
+        for _key, map_dict in self.saved_cross_attention_maps.items():
             for offset, slice in map_dict["slices"].items():
                 map_dict[offset] = slice.to("cpu")
 
@@ -433,7 +433,7 @@ def inject_attention_function(unet, context: Context):
         module.identifier = identifier
         try:
             module.set_attention_slice_wrangler(attention_slice_wrangler)
-            module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))
+            module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))  # noqa: B023
         except AttributeError as e:
             if is_attribute_error_about(e, "set_attention_slice_wrangler"):
                 print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")  # TODO
@@ -445,7 +445,7 @@ def remove_attention_function(unet):
     cross_attention_modules = get_cross_attention_modules(
         unet, CrossAttentionType.TOKENS
     ) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
-    for identifier, module in cross_attention_modules:
+    for _identifier, module in cross_attention_modules:
         try:
             # clear wrangler callback
             module.set_attention_slice_wrangler(None)
@@ -56,7 +56,7 @@ class AttentionMapSaver:
 
         merged = None
 
-        for key, maps in self.collated_maps.items():
+        for _key, maps in self.collated_maps.items():
             # maps has shape [(H*W), N] for N tokens
             # but we want [N, H, W]
             this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
@@ -123,7 +123,7 @@ class InvokeAIDiffuserComponent:
         # control_data should be type List[ControlNetData]
         # this loop covers both ControlNet (one ControlNetData in list)
         # and MultiControlNet (multiple ControlNetData in list)
-        for i, control_datum in enumerate(control_data):
+        for _i, control_datum in enumerate(control_data):
             control_mode = control_datum.control_mode
             # soft_injection and cfg_injection are the two ControlNet control_mode booleans
             # that are combined at higher level to make control_mode enum
@@ -214,7 +214,7 @@ class InvokeAIDiffuserComponent:
                 # add controlnet outputs together if have multiple controlnets
                 down_block_res_samples = [
                     samples_prev + samples_curr
-                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples, strict=True)
                 ]
                 mid_block_res_sample += mid_sample
 
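`zip(..., strict=True)` (ruff B905, available since Python 3.10) turns silent truncation into an error when the two sequences disagree in length, which is almost always a bug when summing residuals pairwise. Sketch:

```python
a = [1, 2, 3]
b = [10, 20]  # one element short

list(zip(a, b))               # [(1, 10), (2, 20)] -- the 3 is dropped silently
list(zip(a, b, strict=True))  # raises ValueError: argument 2 is shorter
```

Where truncation is genuinely intended (as in the `scales` hunk further down, which passes `strict=False`), spelling it out documents the intent.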
@@ -642,7 +642,9 @@ class InvokeAIDiffuserComponent:
 
         deltas = None
         uncond_latents = None
-        weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
+        weighted_cond_list = (
+            c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
+        )
 
         # below is fugly omg
         conditionings = [uc] + [c for c, weight in weighted_cond_list]
@@ -16,28 +16,28 @@ from diffusers import (
     UniPCMultistepScheduler,
 )
 
-SCHEDULER_MAP = dict(
-    ddim=(DDIMScheduler, dict()),
-    ddpm=(DDPMScheduler, dict()),
-    deis=(DEISMultistepScheduler, dict()),
-    lms=(LMSDiscreteScheduler, dict(use_karras_sigmas=False)),
-    lms_k=(LMSDiscreteScheduler, dict(use_karras_sigmas=True)),
-    pndm=(PNDMScheduler, dict()),
-    heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
-    heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
-    euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
-    euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
-    euler_a=(EulerAncestralDiscreteScheduler, dict()),
-    kdpm_2=(KDPM2DiscreteScheduler, dict()),
-    kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
-    dpmpp_2s=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=False)),
-    dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)),
-    dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
-    dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
-    dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type="sde-dpmsolver++")),
-    dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")),
-    dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)),
-    dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)),
-    unipc=(UniPCMultistepScheduler, dict(cpu_only=True)),
-    lcm=(LCMScheduler, dict()),
-)
+SCHEDULER_MAP = {
+    "ddim": (DDIMScheduler, {}),
+    "ddpm": (DDPMScheduler, {}),
+    "deis": (DEISMultistepScheduler, {}),
+    "lms": (LMSDiscreteScheduler, {"use_karras_sigmas": False}),
+    "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
+    "pndm": (PNDMScheduler, {}),
+    "heun": (HeunDiscreteScheduler, {"use_karras_sigmas": False}),
+    "heun_k": (HeunDiscreteScheduler, {"use_karras_sigmas": True}),
+    "euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
+    "euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
+    "euler_a": (EulerAncestralDiscreteScheduler, {}),
+    "kdpm_2": (KDPM2DiscreteScheduler, {}),
+    "kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
+    "dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
+    "dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
+    "dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False}),
+    "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
+    "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
+    "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
+    "dpmpp_sde": (DPMSolverSDEScheduler, {"use_karras_sigmas": False, "noise_sampler_seed": 0}),
+    "dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}),
+    "unipc": (UniPCMultistepScheduler, {"cpu_only": True}),
+    "lcm": (LCMScheduler, {}),
+}
@@ -615,7 +615,7 @@ def do_textual_inversion_training(
     vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
     unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
 
-    pipeline_args = dict(local_files_only=True)
+    pipeline_args = {"local_files_only": True}
     if tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
     else:
@@ -732,7 +732,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
 
         controlnet_down_block_res_samples = ()
 
-        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+        for down_block_res_sample, controlnet_block in zip(
+            down_block_res_samples, self.controlnet_down_blocks, strict=True
+        ):
             down_block_res_sample = controlnet_block(down_block_res_sample)
             controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
 
@@ -745,7 +747,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
 
             scales = scales * conditioning_scale
-            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+            down_block_res_samples = [
+                sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=False)
+            ]
            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
         else:
             down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
@@ -225,34 +225,34 @@ def basicConfig(**kwargs):
 
 
 _FACILITY_MAP = (
-    dict(
-        LOG_KERN=syslog.LOG_KERN,
-        LOG_USER=syslog.LOG_USER,
-        LOG_MAIL=syslog.LOG_MAIL,
-        LOG_DAEMON=syslog.LOG_DAEMON,
-        LOG_AUTH=syslog.LOG_AUTH,
-        LOG_LPR=syslog.LOG_LPR,
-        LOG_NEWS=syslog.LOG_NEWS,
-        LOG_UUCP=syslog.LOG_UUCP,
-        LOG_CRON=syslog.LOG_CRON,
-        LOG_SYSLOG=syslog.LOG_SYSLOG,
-        LOG_LOCAL0=syslog.LOG_LOCAL0,
-        LOG_LOCAL1=syslog.LOG_LOCAL1,
-        LOG_LOCAL2=syslog.LOG_LOCAL2,
-        LOG_LOCAL3=syslog.LOG_LOCAL3,
-        LOG_LOCAL4=syslog.LOG_LOCAL4,
-        LOG_LOCAL5=syslog.LOG_LOCAL5,
-        LOG_LOCAL6=syslog.LOG_LOCAL6,
-        LOG_LOCAL7=syslog.LOG_LOCAL7,
-    )
+    {
+        "LOG_KERN": syslog.LOG_KERN,
+        "LOG_USER": syslog.LOG_USER,
+        "LOG_MAIL": syslog.LOG_MAIL,
+        "LOG_DAEMON": syslog.LOG_DAEMON,
+        "LOG_AUTH": syslog.LOG_AUTH,
+        "LOG_LPR": syslog.LOG_LPR,
+        "LOG_NEWS": syslog.LOG_NEWS,
+        "LOG_UUCP": syslog.LOG_UUCP,
+        "LOG_CRON": syslog.LOG_CRON,
+        "LOG_SYSLOG": syslog.LOG_SYSLOG,
+        "LOG_LOCAL0": syslog.LOG_LOCAL0,
+        "LOG_LOCAL1": syslog.LOG_LOCAL1,
+        "LOG_LOCAL2": syslog.LOG_LOCAL2,
+        "LOG_LOCAL3": syslog.LOG_LOCAL3,
+        "LOG_LOCAL4": syslog.LOG_LOCAL4,
+        "LOG_LOCAL5": syslog.LOG_LOCAL5,
+        "LOG_LOCAL6": syslog.LOG_LOCAL6,
+        "LOG_LOCAL7": syslog.LOG_LOCAL7,
+    }
     if SYSLOG_AVAILABLE
-    else dict()
+    else {}
 )
 
-_SOCK_MAP = dict(
-    SOCK_STREAM=socket.SOCK_STREAM,
-    SOCK_DGRAM=socket.SOCK_DGRAM,
-)
+_SOCK_MAP = {
+    "SOCK_STREAM": socket.SOCK_STREAM,
+    "SOCK_DGRAM": socket.SOCK_DGRAM,
+}
 
 
 class InvokeAIFormatter(logging.Formatter):
@@ -344,7 +344,7 @@ LOG_FORMATTERS = {
 
 
 class InvokeAILogger(object):
-    loggers = dict()
+    loggers = {}
 
     @classmethod
     def get_logger(
@@ -364,7 +364,7 @@ class InvokeAILogger(object):
     @classmethod
     def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
         handler_strs = config.log_handlers
-        handlers = list()
+        handlers = []
         for handler in handler_strs:
             handler_name, *args = handler.split("=", 2)
             args = args[0] if len(args) > 0 else None
@@ -398,7 +398,7 @@ class InvokeAILogger(object):
             raise ValueError("syslog is not available on this system")
         if not args:
             args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514"
-        syslog_args = dict()
+        syslog_args = {}
         try:
             for a in args.split(","):
                 arg_name, *arg_value = a.split(":", 2)
@@ -434,7 +434,7 @@ class InvokeAILogger(object):
             path = url.path
             port = url.port or 80
 
-            syslog_args = dict()
+            syslog_args = {}
             for a in arg_list:
                 arg_name, *arg_value = a.split(":", 2)
                 if arg_name == "method":
@@ -29,7 +29,7 @@ def log_txt_as_img(wh, xc, size=10):
     # wh a tuple of (width, height)
     # xc a list of captions to plot
     b = len(xc)
-    txts = list()
+    txts = []
     for bi in range(b):
         txt = Image.new("RGB", wh, color="white")
         draw = ImageDraw.Draw(txt)
@@ -93,7 +93,7 @@ def instantiate_from_config(config, **kwargs):
         elif config == "__is_unconditional__":
             return None
         raise KeyError("Expected key `target` to instantiate.")
-    return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
+    return get_obj_from_str(config["target"])(**config.get("params", {}), **kwargs)
 
 
 def get_obj_from_str(string, reload=False):
@@ -231,8 +231,9 @@ def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10
     angles = 2 * math.pi * rand_val
     gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)
 
-    tile_grads = (
-        lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
+    def tile_grads(slice1, slice2):
+        return (
+            gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
             .repeat_interleave(d[0], 0)
             .repeat_interleave(d[1], 1)
         )
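Rewriting an assigned lambda as a `def` (ruff E731), as in the `tile_grads` hunk above, costs nothing at runtime and buys a real function name in tracebacks plus room for annotations and a docstring. Minimal before/after:

```python
# Before: anonymous in tracebacks ("<lambda>"), awkward to annotate.
square = lambda x: x * x

# After: same callable, but named, annotatable, and documentable.
def square(x: float) -> float:
    """Return x squared."""
    return x * x
```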
@@ -341,19 +341,19 @@ class InvokeAIMetadataParser:
         # this was more elegant as a case statement, but that's not available in python 3.9
         if old_scheduler is None:
             return None
-        scheduler_map = dict(
-            ddim="ddim",
-            plms="pnmd",
-            k_lms="lms",
-            k_dpm_2="kdpm_2",
-            k_dpm_2_a="kdpm_2_a",
-            dpmpp_2="dpmpp_2s",
-            k_dpmpp_2="dpmpp_2m",
-            k_dpmpp_2_a=None,  # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
-            k_euler="euler",
-            k_euler_a="euler_a",
-            k_heun="heun",
-        )
+        scheduler_map = {
+            "ddim": "ddim",
+            "plms": "pnmd",
+            "k_lms": "lms",
+            "k_dpm_2": "kdpm_2",
+            "k_dpm_2_a": "kdpm_2_a",
+            "dpmpp_2": "dpmpp_2s",
+            "k_dpmpp_2": "dpmpp_2m",
+            "k_dpmpp_2_a": None,  # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
+            "k_euler": "euler",
+            "k_euler_a": "euler_a",
+            "k_heun": "heun",
+        }
         return scheduler_map.get(old_scheduler)
 
     def split_prompt(self, raw_prompt: str):
@@ -72,7 +72,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
     def __init__(self, parentApp, name, multipage=False, *args, **keywords):
         self.multipage = multipage
         self.subprocess = None
-        super().__init__(parentApp=parentApp, name=name, *args, **keywords)
+        super().__init__(parentApp=parentApp, name=name, *args, **keywords)  # noqa: B026 # TODO: maybe this is bad?
 
     def create(self):
         self.keypress_timeout = 10
@@ -203,14 +203,14 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
         )
 
         # This restores the selected page on return from an installation
-        for i in range(1, self.current_tab + 1):
+        for _i in range(1, self.current_tab + 1):
             self.tabs.h_cursor_line_down(1)
         self._toggle_tables([self.current_tab])
 
     ############# diffusers tab ##########
     def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
         """Add widgets responsible for selecting diffusers models"""
-        widgets = dict()
+        widgets = {}
         models = self.all_models
         starters = self.starter_models
         starter_model_labels = self.model_labels
@@ -258,10 +258,12 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
         model_type: ModelType,
         window_width: int = 120,
         install_prompt: str = None,
-        exclude: set = set(),
+        exclude: set = None,
     ) -> dict[str, npyscreen.widget]:
         """Generic code to create model selection widgets"""
-        widgets = dict()
+        if exclude is None:
+            exclude = set()
+        widgets = {}
         model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
         model_labels = [self.model_labels[x] for x in model_list]
 
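The `exclude: set = set()` → `None` rewrite above (and the `values: list = []` rewrites further down) is ruff's B006 fix for mutable default arguments: defaults are evaluated once at function definition, so a mutable default is shared across every call that omits the argument. Demonstration:

```python
def add_item_buggy(item, bucket: list = []):
    bucket.append(item)  # the same list object on every default call
    return bucket

print(add_item_buggy("a"))  # ['a']
print(add_item_buggy("b"))  # ['a', 'b'] -- state leaked between calls

def add_item(item, bucket: list = None):
    if bucket is None:
        bucket = []  # a fresh list per call
    bucket.append(item)
    return bucket
```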
@@ -366,13 +368,13 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
         ]
 
         for group in widgets:
-            for k, v in group.items():
+            for _k, v in group.items():
                 try:
                     v.hidden = True
                     v.editable = False
                 except Exception:
                     pass
-        for k, v in widgets[selected_tab].items():
+        for _k, v in widgets[selected_tab].items():
             try:
                 v.hidden = False
                 if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
@@ -391,7 +393,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
         label_width = max([len(models[x].name) for x in models])
         description_width = window_width - label_width - checkbox_width - spacing_width
 
-        result = dict()
+        result = {}
         for x in models.keys():
             description = models[x].description
             description = (
@@ -433,11 +435,11 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
         parent_conn, child_conn = Pipe()
         p = Process(
             target=process_and_execute,
-            kwargs=dict(
-                opt=app.program_opts,
-                selections=app.install_selections,
-                conn_out=child_conn,
-            ),
+            kwargs={
+                "opt": app.program_opts,
+                "selections": app.install_selections,
+                "conn_out": child_conn,
+            },
         )
         p.start()
         child_conn.close()
@@ -558,7 +560,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
         for section in ui_sections:
             if "models_selected" not in section:
                 continue
-            selected = set([section["models"][x] for x in section["models_selected"].value])
+            selected = {section["models"][x] for x in section["models_selected"].value}
             models_to_install = [x for x in selected if not self.all_models[x].installed]
             models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
             selections.remove_models.extend(models_to_remove)
@@ -11,6 +11,7 @@ import sys
 import textwrap
 from curses import BUTTON2_CLICKED, BUTTON3_CLICKED
 from shutil import get_terminal_size
+from typing import Optional
 
 import npyscreen
 import npyscreen.wgmultiline as wgmultiline
@@ -243,7 +244,9 @@ class SelectColumnBase:
 
 
 class MultiSelectColumns(SelectColumnBase, npyscreen.MultiSelect):
-    def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
+    def __init__(self, screen, columns: int = 1, values: Optional[list] = None, **keywords):
+        if values is None:
+            values = []
         self.columns = columns
         self.value_cnt = len(values)
         self.rows = math.ceil(self.value_cnt / self.columns)
@@ -267,7 +270,9 @@ class SingleSelectWithChanged(npyscreen.SelectOne):
 class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged):
     """Row of radio buttons. Spacebar to select."""
 
-    def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
+    def __init__(self, screen, columns: int = 1, values: list = None, **keywords):
+        if values is None:
+            values = []
         self.columns = columns
         self.value_cnt = len(values)
         self.rows = math.ceil(self.value_cnt / self.columns)
@@ -275,14 +275,14 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
         interp = self.interpolations[self.merge_method.value[0]]
 
         bases = ["sd-1", "sd-2", "sdxl"]
-        args = dict(
-            model_names=models,
-            base_model=BaseModelType(bases[self.base_select.value[0]]),
-            alpha=self.alpha.value,
-            interp=interp,
-            force=self.force.value,
-            merged_model_name=self.merged_model_name.value,
-        )
+        args = {
+            "model_names": models,
+            "base_model": BaseModelType(bases[self.base_select.value[0]]),
+            "alpha": self.alpha.value,
+            "interp": interp,
+            "force": self.force.value,
+            "merged_model_name": self.merged_model_name.value,
+        }
         return args
 
     def check_for_overwrite(self) -> bool:
@@ -297,7 +297,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
     def validate_field_values(self) -> bool:
         bad_fields = []
         model_names = self.model_names
-        selected_models = set((model_names[self.model1.value[0]], model_names[self.model2.value[0]]))
+        selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]}
         if self.model3.value[0] > 0:
             selected_models.add(model_names[self.model3.value[0] - 1])
         if len(selected_models) < 2:
@@ -276,13 +276,13 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
 
     def get_model_names(self) -> Tuple[List[str], int]:
         conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
-        model_names = [idx for idx in sorted(list(conf.keys())) if conf[idx].get("format", None) == "diffusers"]
+        model_names = [idx for idx in sorted(conf.keys()) if conf[idx].get("format", None) == "diffusers"]
         defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]]
         default = defaults[0] if len(defaults) > 0 else 0
         return (model_names, default)
 
     def marshall_arguments(self) -> dict:
-        args = dict()
+        args = {}
 
         # the choices
         args.update(
@@ -24,6 +24,7 @@ module.exports = {
   root: true,
   rules: {
     curly: 'error',
+    'react/jsx-no-bind': ['error', { allowBind: true }],
     'react/jsx-curly-brace-presence': [
       'error',
       { props: 'never', children: 'never' },
@@ -583,6 +583,7 @@
     "strength": "Image to image strength",
     "Threshold": "Noise Threshold",
     "variations": "Seed-weight pairs",
+    "vae": "VAE",
     "width": "Width",
     "workflow": "Workflow"
   },
@@ -1487,5 +1487,18 @@
     "scheduler": "Campionatore",
     "recallParameters": "Richiama i parametri",
     "noRecallParameters": "Nessun parametro da richiamare trovato"
+  },
+  "hrf": {
+    "enableHrf": "Abilita Correzione Alta Risoluzione",
+    "upscaleMethod": "Metodo di ampliamento",
+    "enableHrfTooltip": "Genera con una risoluzione iniziale inferiore, esegue l'ampliamento alla risoluzione di base, quindi esegue Immagine a Immagine.",
+    "metadata": {
+      "strength": "Forza della Correzione Alta Risoluzione",
+      "enabled": "Correzione Alta Risoluzione Abilitata",
+      "method": "Metodo della Correzione Alta Risoluzione"
+    },
+    "hrf": "Correzione Alta Risoluzione",
+    "hrfStrength": "Forza della Correzione Alta Risoluzione",
+    "strengthTooltip": "Valori più bassi comportano meno dettagli, il che può ridurre potenziali artefatti."
   }
 }
@@ -19,7 +19,7 @@ import sdxlReducer from 'features/sdxl/store/sdxlSlice';
 import configReducer from 'features/system/store/configSlice';
 import systemReducer from 'features/system/store/systemSlice';
 import queueReducer from 'features/queue/store/queueSlice';
-import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
+import modelmanagerReducer from 'features/modelManager/store/modelManagerSlice';
 import hotkeysReducer from 'features/ui/store/hotkeysSlice';
 import uiReducer from 'features/ui/store/uiSlice';
 import dynamicMiddlewares from 'redux-dynamic-middlewares';
@@ -8,7 +8,14 @@ import {
   forwardRef,
   useDisclosure,
 } from '@chakra-ui/react';
-import { cloneElement, memo, ReactElement, ReactNode, useRef } from 'react';
+import {
+  cloneElement,
+  memo,
+  ReactElement,
+  ReactNode,
+  useCallback,
+  useRef,
+} from 'react';
 import { useTranslation } from 'react-i18next';
 import IAIButton from './IAIButton';
 
@@ -38,15 +45,15 @@ const IAIAlertDialog = forwardRef((props: Props, ref) => {
   const { isOpen, onOpen, onClose } = useDisclosure();
   const cancelRef = useRef<HTMLButtonElement | null>(null);
 
-  const handleAccept = () => {
+  const handleAccept = useCallback(() => {
     acceptCallback();
     onClose();
-  };
+  }, [acceptCallback, onClose]);
 
-  const handleCancel = () => {
+  const handleCancel = useCallback(() => {
     cancelCallback && cancelCallback();
     onClose();
-  };
+  }, [cancelCallback, onClose]);
 
   return (
     <>
@@ -1,9 +1,12 @@
-import { Box, ChakraProps } from '@chakra-ui/react';
-import { memo } from 'react';
+import { ChakraProps, Flex } from '@chakra-ui/react';
+import { memo, useCallback } from 'react';
 import { RgbaColorPicker } from 'react-colorful';
 import { ColorPickerBaseProps, RgbaColor } from 'react-colorful/dist/types';
+import IAINumberInput from './IAINumberInput';
 
-type IAIColorPickerProps = ColorPickerBaseProps<RgbaColor>;
+type IAIColorPickerProps = ColorPickerBaseProps<RgbaColor> & {
+  withNumberInput?: boolean;
+};
 
 const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
   width: 6,
@@ -11,17 +14,84 @@ const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
   borderColor: 'base.100',
 };
 
-const sx = {
+const sx: ChakraProps['sx'] = {
   '.react-colorful__hue-pointer': colorPickerStyles,
   '.react-colorful__saturation-pointer': colorPickerStyles,
   '.react-colorful__alpha-pointer': colorPickerStyles,
+  gap: 2,
+  flexDir: 'column',
 };
 
+const numberInputWidth: ChakraProps['w'] = '4.2rem';
+
 const IAIColorPicker = (props: IAIColorPickerProps) => {
+  const { color, onChange, withNumberInput, ...rest } = props;
+  const handleChangeR = useCallback(
+    (r: number) => onChange({ ...color, r }),
+    [color, onChange]
+  );
+  const handleChangeG = useCallback(
+    (g: number) => onChange({ ...color, g }),
+    [color, onChange]
+  );
+  const handleChangeB = useCallback(
+    (b: number) => onChange({ ...color, b }),
+    [color, onChange]
+  );
+  const handleChangeA = useCallback(
+    (a: number) => onChange({ ...color, a }),
+    [color, onChange]
+  );
   return (
-    <Box sx={sx}>
-      <RgbaColorPicker {...props} />
-    </Box>
+    <Flex sx={sx}>
+      <RgbaColorPicker
+        color={color}
+        onChange={onChange}
+        style={{ width: '100%' }}
+        {...rest}
+      />
+      {withNumberInput && (
+        <Flex>
+          <IAINumberInput
+            value={color.r}
+            onChange={handleChangeR}
+            min={0}
+            max={255}
+            step={1}
+            label="Red"
+            w={numberInputWidth}
+          />
+          <IAINumberInput
+            value={color.g}
+            onChange={handleChangeG}
+            min={0}
+            max={255}
+            step={1}
+            label="Green"
+            w={numberInputWidth}
+          />
+          <IAINumberInput
+            value={color.b}
+            onChange={handleChangeB}
+            min={0}
+            max={255}
+            step={1}
+            label="Blue"
+            w={numberInputWidth}
+          />
+          <IAINumberInput
+            value={color.a}
+            onChange={handleChangeA}
+            step={0.1}
+            min={0}
+            max={1}
+            label="Alpha"
+            w={numberInputWidth}
+            isInteger={false}
+          />
+        </Flex>
+      )}
+    </Flex>
   );
 };
 
@@ -1,43 +0,0 @@
-import { Box, Flex, Icon } from '@chakra-ui/react';
-import { memo } from 'react';
-import { FaExclamation } from 'react-icons/fa';
-
-const IAIErrorLoadingImageFallback = () => {
-  return (
-    <Box
-      sx={{
-        position: 'relative',
-        height: 'full',
-        width: 'full',
-        '::before': {
-          content: "''",
-          display: 'block',
-          pt: '100%',
-        },
-      }}
-    >
-      <Flex
-        sx={{
-          position: 'absolute',
-          top: 0,
-          insetInlineStart: 0,
-          height: 'full',
-          width: 'full',
-          alignItems: 'center',
-          justifyContent: 'center',
-          borderRadius: 'base',
-          bg: 'base.100',
-          color: 'base.500',
-          _dark: {
-            color: 'base.700',
-            bg: 'base.850',
-          },
-        }}
-      >
-        <Icon as={FaExclamation} boxSize={16} opacity={0.7} />
-      </Flex>
-    </Box>
-  );
-};
-
-export default memo(IAIErrorLoadingImageFallback);
@@ -1,8 +0,0 @@
-import { chakra } from '@chakra-ui/react';
-
-/**
- * Chakra-enabled <form />
- */
-const IAIForm = chakra.form;
-
-export default IAIForm;
@@ -1,15 +0,0 @@
-import { FormErrorMessage, FormErrorMessageProps } from '@chakra-ui/react';
-import { ReactNode } from 'react';
-
-type IAIFormErrorMessageProps = FormErrorMessageProps & {
-  children: ReactNode | string;
-};
-
-export default function IAIFormErrorMessage(props: IAIFormErrorMessageProps) {
-  const { children, ...rest } = props;
-  return (
-    <FormErrorMessage color="error.400" {...rest}>
-      {children}
-    </FormErrorMessage>
-  );
-}
@@ -1,15 +0,0 @@
-import { FormHelperText, FormHelperTextProps } from '@chakra-ui/react';
-import { ReactNode } from 'react';
-
-type IAIFormHelperTextProps = FormHelperTextProps & {
-  children: ReactNode | string;
-};
-
-export default function IAIFormHelperText(props: IAIFormHelperTextProps) {
-  const { children, ...rest } = props;
-  return (
-    <FormHelperText margin={0} color="base.400" {...rest}>
-      {children}
-    </FormHelperText>
-  );
-}
@@ -1,25 +0,0 @@
-import { Flex, useColorMode } from '@chakra-ui/react';
-import { ReactElement } from 'react';
-import { mode } from 'theme/util/mode';
-
-export function IAIFormItemWrapper({
-  children,
-}: {
-  children: ReactElement | ReactElement[];
-}) {
-  const { colorMode } = useColorMode();
-  return (
-    <Flex
-      sx={{
-        flexDirection: 'column',
-        padding: 4,
-        rowGap: 4,
-        borderRadius: 'base',
-        width: 'full',
-        bg: mode('base.100', 'base.900')(colorMode),
-      }}
-    >
-      {children}
-    </Flex>
-  );
-}
@@ -1,25 +0,0 @@
-import {
-  Checkbox,
-  CheckboxProps,
-  FormControl,
-  FormControlProps,
-  FormLabel,
-} from '@chakra-ui/react';
-import { memo, ReactNode } from 'react';
-
-type IAIFullCheckboxProps = CheckboxProps & {
-  label: string | ReactNode;
-  formControlProps?: FormControlProps;
-};
-
-const IAIFullCheckbox = (props: IAIFullCheckboxProps) => {
-  const { label, formControlProps, ...rest } = props;
-  return (
-    <FormControl {...formControlProps}>
-      <FormLabel>{label}</FormLabel>
-      <Checkbox colorScheme="accent" {...rest} />
-    </FormControl>
-  );
-};
-
-export default memo(IAIFullCheckbox);
|
|
@@ -1,6 +1,7 @@
 import { useColorMode } from '@chakra-ui/react';
 import { TextInput, TextInputProps } from '@mantine/core';
 import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
+import { useCallback } from 'react';
 import { mode } from 'theme/util/mode';
 
 type IAIMantineTextInputProps = TextInputProps;
@@ -20,9 +21,8 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
   } = useChakraThemeTokens();
   const { colorMode } = useColorMode();
 
-  return (
-    <TextInput
-      styles={() => ({
+  const stylesFunc = useCallback(
+    () => ({
       input: {
         color: mode(base900, base100)(colorMode),
         backgroundColor: mode(base50, base900)(colorMode),
@@ -35,11 +35,23 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
       },
       label: {
         color: mode(base700, base300)(colorMode),
-        fontWeight: 'normal',
+        fontWeight: 'normal' as const,
         marginBottom: 4,
       },
-      })}
-      {...rest}
-    />
+    }),
+    [
+      accent300,
+      accent500,
+      base100,
+      base200,
+      base300,
+      base50,
+      base700,
+      base800,
+      base900,
+      colorMode,
+    ]
   );
+
+  return <TextInput styles={stylesFunc} {...rest} />;
 }
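Note: the IAIMantineTextInput refactor above replaces an inline styles function with one memoized through useCallback, so TextInput receives a referentially stable styles prop and every theme token the function reads appears in an explicit dependency list. A reduced sketch of the pattern, with a hypothetical accent prop standing in for the Chakra theme tokens:

    import { useCallback } from 'react';
    import { TextInput } from '@mantine/core';

    const ExampleInput = ({ accent }: { accent: string }) => {
      // Recreated only when `accent` changes; otherwise TextInput sees the
      // exact same function across renders.
      const stylesFunc = useCallback(
        () => ({
          input: { borderColor: accent },
          // `as const` narrows the literal so it satisfies the CSS typings,
          // mirroring the fontWeight fix in the hunk above.
          label: { fontWeight: 'normal' as const },
        }),
        [accent]
      );

      return <TextInput label="Example" styles={stylesFunc} />;
    };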
@@ -98,20 +98,24 @@ const IAINumberInput = forwardRef((props: Props, ref) => {
     }
   }, [value, valueAsString]);
 
-  const handleOnChange = (v: string) => {
+  const handleOnChange = useCallback(
+    (v: string) => {
       setValueAsString(v);
       // This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
       if (!v.match(numberStringRegex)) {
         // Cast the value to number. Floor it if it should be an integer.
         onChange(isInteger ? Math.floor(Number(v)) : Number(v));
       }
-  };
+    },
+    [isInteger, onChange]
+  );
 
   /**
    * Clicking the steppers allows the value to go outside bounds; we need to
    * clamp it on blur and floor it if needed.
    */
-  const handleBlur = (e: FocusEvent<HTMLInputElement>) => {
+  const handleBlur = useCallback(
+    (e: FocusEvent<HTMLInputElement>) => {
       const clamped = clamp(
         isInteger ? Math.floor(Number(e.target.value)) : Number(e.target.value),
         min,
@@ -119,7 +123,9 @@ const IAINumberInput = forwardRef((props: Props, ref) => {
       );
       setValueAsString(String(clamped));
       onChange(clamped);
-  };
+    },
+    [isInteger, max, min, onChange]
+  );
 
   const handleKeyDown = useCallback(
     (e: KeyboardEvent<HTMLInputElement>) => {
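Note: both IAINumberInput handlers follow the same recipe: wrap the function in useCallback and list everything it closes over (isInteger, onChange, and for the blur handler min and max) as dependencies. A stripped-down sketch of the idea with illustrative names, not the component's real props:

    import { useCallback, useState } from 'react';

    type Props = { isInteger: boolean; onChange: (v: number) => void };

    const ExampleNumberField = ({ isInteger, onChange }: Props) => {
      const [raw, setRaw] = useState('');

      // Stable identity as long as `isInteger` and `onChange` are stable.
      const handleChange = useCallback(
        (v: string) => {
          setRaw(v);
          const n = isInteger ? Math.floor(Number(v)) : Number(v);
          if (v !== '' && !Number.isNaN(n)) {
            onChange(n);
          }
        },
        [isInteger, onChange]
      );

      return <input value={raw} onChange={(e) => handleChange(e.target.value)} />;
    };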
@@ -6,7 +6,7 @@ import {
   Tooltip,
   TooltipProps,
 } from '@chakra-ui/react';
-import { memo, MouseEvent } from 'react';
+import { memo, MouseEvent, useCallback } from 'react';
 import IAIOption from './IAIOption';
 
 type IAISelectProps = SelectProps & {
@@ -33,15 +33,16 @@ const IAISelect = (props: IAISelectProps) => {
     spaceEvenly,
     ...rest
   } = props;
-  return (
-    <FormControl
-      isDisabled={isDisabled}
-      onClick={(e: MouseEvent<HTMLDivElement>) => {
+  const handleClick = useCallback((e: MouseEvent<HTMLDivElement>) => {
     e.stopPropagation();
     e.nativeEvent.stopImmediatePropagation();
     e.nativeEvent.stopPropagation();
     e.nativeEvent.cancelBubble = true;
-      }}
+  }, []);
+  return (
+    <FormControl
+      isDisabled={isDisabled}
+      onClick={handleClick}
       sx={
         horizontal
           ? {
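Note: the IAISelect hunk hoists an inline arrow prop out of the JSX into a named useCallback with an empty dependency list, which is valid here because the handler reads only the event object. The shape of the refactor, reduced to a hypothetical component:

    import { MouseEvent, ReactNode, useCallback } from 'react';

    const ExampleControl = ({ children }: { children: ReactNode }) => {
      // Empty deps: nothing from the render scope is captured.
      const handleClick = useCallback((e: MouseEvent<HTMLDivElement>) => {
        e.stopPropagation();
      }, []);

      return <div onClick={handleClick}>{children}</div>;
    };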
@@ -186,6 +186,13 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
     [dispatch]
   );
 
+  const handleMouseEnter = useCallback(() => setShowTooltip(true), []);
+  const handleMouseLeave = useCallback(() => setShowTooltip(false), []);
+  const handleStepperClick = useCallback(
+    () => onChange(Number(localInputValue)),
+    [localInputValue, onChange]
+  );
+
   return (
     <FormControl
       ref={ref}
@@ -219,8 +226,8 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
         max={max}
         step={step}
         onChange={handleSliderChange}
-        onMouseEnter={() => setShowTooltip(true)}
-        onMouseLeave={() => setShowTooltip(false)}
+        onMouseEnter={handleMouseEnter}
+        onMouseLeave={handleMouseLeave}
         focusThumbOnChange={false}
         isDisabled={isDisabled}
         {...rest}
@@ -332,12 +339,8 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
             {...sliderNumberInputFieldProps}
           />
           <NumberInputStepper {...sliderNumberInputStepperProps}>
-            <NumberIncrementStepper
-              onClick={() => onChange(Number(localInputValue))}
-            />
-            <NumberDecrementStepper
-              onClick={() => onChange(Number(localInputValue))}
-            />
+            <NumberIncrementStepper onClick={handleStepperClick} />
+            <NumberDecrementStepper onClick={handleStepperClick} />
           </NumberInputStepper>
         </NumberInput>
       )}
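Note: in IAISlider the two steppers previously built separate inline closures over localInputValue; the refactor defines a single memoized handleStepperClick and passes it to both, keeping one dependency list for the shared behavior. A minimal sketch with illustrative names:

    import { useCallback, useState } from 'react';

    const ExampleStepper = ({ onChange }: { onChange: (v: number) => void }) => {
      const [localValue, setLocalValue] = useState('0');

      // One handler serves both buttons.
      const handleStepperClick = useCallback(
        () => onChange(Number(localValue)),
        [localValue, onChange]
      );

      return (
        <div>
          <input value={localValue} onChange={(e) => setLocalValue(e.target.value)} />
          <button onClick={handleStepperClick}>+</button>
          <button onClick={handleStepperClick}>-</button>
        </div>
      );
    };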
@@ -146,16 +146,15 @@ const ImageUploader = (props: ImageUploaderProps) => {
     };
   }, [inputRef]);
 
-  return (
-    <Box
-      {...getRootProps({ style: {} })}
-      onKeyDown={(e: KeyboardEvent) => {
+  const handleKeyDown = useCallback((e: KeyboardEvent) => {
     // Bail out if user hits spacebar - do not open the uploader
     if (e.key === ' ') {
       return;
     }
-      }}
-    >
+  }, []);
+
+  return (
+    <Box {...getRootProps({ style: {} })} onKeyDown={handleKeyDown}>
       <input {...getInputProps()} />
       {children}
       <AnimatePresence>
@@ -1,23 +0,0 @@
-import { Flex, Icon } from '@chakra-ui/react';
-import { memo } from 'react';
-import { FaImage } from 'react-icons/fa';
-
-const SelectImagePlaceholder = () => {
-  return (
-    <Flex
-      sx={{
-        w: 'full',
-        h: 'full',
-        // bg: 'base.800',
-        borderRadius: 'base',
-        alignItems: 'center',
-        justifyContent: 'center',
-        aspectRatio: '1/1',
-      }}
-    >
-      <Icon color="base.400" boxSize={32} as={FaImage}></Icon>
-    </Flex>
-  );
-};
-
-export default memo(SelectImagePlaceholder);
@@ -1,24 +0,0 @@
-import { useBreakpoint } from '@chakra-ui/react';
-
-export default function useResolution():
-  | 'mobile'
-  | 'tablet'
-  | 'desktop'
-  | 'unknown' {
-  const breakpointValue = useBreakpoint();
-
-  const mobileResolutions = ['base', 'sm'];
-  const tabletResolutions = ['md', 'lg'];
-  const desktopResolutions = ['xl', '2xl'];
-
-  if (mobileResolutions.includes(breakpointValue)) {
-    return 'mobile';
-  }
-  if (tabletResolutions.includes(breakpointValue)) {
-    return 'tablet';
-  }
-  if (desktopResolutions.includes(breakpointValue)) {
-    return 'desktop';
-  }
-  return 'unknown';
-}
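Note: the deleted useResolution hook bucketed Chakra's useBreakpoint value into coarse device classes. If the pattern is ever needed again, a minimal sketch under the same Chakra breakpoint names (the hook name here is hypothetical):

    import { useBreakpoint } from '@chakra-ui/react';

    type Resolution = 'mobile' | 'tablet' | 'desktop' | 'unknown';

    // Maps the active Chakra breakpoint onto a coarse device class.
    export function useResolutionSketch(): Resolution {
      const bp = useBreakpoint();
      if (['base', 'sm'].includes(bp)) return 'mobile';
      if (['md', 'lg'].includes(bp)) return 'tablet';
      if (['xl', '2xl'].includes(bp)) return 'desktop';
      return 'unknown';
    }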
@@ -1,7 +0,0 @@
-import dateFormat from 'dateformat';
-
-/**
- * Get a `now` timestamp with 1s precision, formatted as ISO datetime.
- */
-export const getTimestamp = () =>
-  dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);
@@ -1,71 +0,0 @@
-// TODO: Restore variations
-// Support code from v2.3 in here.
-
-// export const stringToSeedWeights = (
-//   string: string
-// ): InvokeAI.SeedWeights | boolean => {
-//   const stringPairs = string.split(',');
-//   const arrPairs = stringPairs.map((p) => p.split(':'));
-//   const pairs = arrPairs.map((p: Array<string>): InvokeAI.SeedWeightPair => {
-//     return { seed: Number(p[0]), weight: Number(p[1]) };
-//   });
-
-//   if (!validateSeedWeights(pairs)) {
-//     return false;
-//   }
-
-//   return pairs;
-// };
-
-// export const validateSeedWeights = (
-//   seedWeights: InvokeAI.SeedWeights | string
-// ): boolean => {
-//   return typeof seedWeights === 'string'
-//     ? Boolean(stringToSeedWeights(seedWeights))
-//     : Boolean(
-//         seedWeights.length &&
-//           !seedWeights.some((pair: InvokeAI.SeedWeightPair) => {
-//             const { seed, weight } = pair;
-//             const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
-//             const isWeightValid =
-//               !isNaN(parseInt(weight.toString(), 10)) &&
-//               weight >= 0 &&
-//               weight <= 1;
-//             return !(isSeedValid && isWeightValid);
-//           })
-//       );
-// };
-
-// export const seedWeightsToString = (
-//   seedWeights: InvokeAI.SeedWeights
-// ): string => {
-//   return seedWeights.reduce((acc, pair, i, arr) => {
-//     const { seed, weight } = pair;
-//     acc += `${seed}:${weight}`;
-//     if (i !== arr.length - 1) {
-//       acc += ',';
-//     }
-//     return acc;
-//   }, '');
-// };
-
-// export const seedWeightsToArray = (
-//   seedWeights: InvokeAI.SeedWeights
-// ): Array<Array<number>> => {
-//   return seedWeights.map((pair: InvokeAI.SeedWeightPair) => [
-//     pair.seed,
-//     pair.weight,
-//   ]);
-// };
-
-// export const stringToSeedWeightsArray = (
-//   string: string
-// ): Array<Array<number>> => {
-//   const stringPairs = string.split(',');
-//   const arrPairs = stringPairs.map((p) => p.split(':'));
-//   return arrPairs.map(
-//     (p: Array<string>): Array<number> => [parseInt(p[0], 10), parseFloat(p[1])]
-//   );
-// };
-
-export default {};
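Note: the commented-out helpers removed above round-trip variation seeds between a 'seed:weight,seed:weight' string and structured pairs. For reference, a standalone sketch of that round trip (types simplified from the InvokeAI.SeedWeightPair shape):

    type SeedWeightPair = { seed: number; weight: number };

    // '123:0.5,456:0.25' -> [{ seed: 123, weight: 0.5 }, { seed: 456, weight: 0.25 }]
    const parseSeedWeights = (s: string): SeedWeightPair[] =>
      s.split(',').map((pair) => {
        const [seed, weight] = pair.split(':');
        return { seed: parseInt(seed, 10), weight: parseFloat(weight) };
      });

    // Inverse: [{ seed: 123, weight: 0.5 }] -> '123:0.5'
    const seedWeightsToString = (pairs: SeedWeightPair[]): string =>
      pairs.map(({ seed, weight }) => `${seed}:${weight}`).join(',');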
Some files were not shown because too many files have changed in this diff.