mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.

- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1

**Big Changes**

There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.

**Invocations**

The biggest change relates to invocation creation, instantiation and validation. Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.

Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.

With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.

This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and the `BaseInvocation.invoke_internal()` method. In the end, this implementation is cleaner.

**Invocation Fields**

In pydantic v2, you can no longer directly add or remove fields from a model. Previously, we did this to add the `type` field to invocations.

**Invocation Decorators**

With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper. A similar technique is used for `invocation_output()`. A minimal sketch of this technique is included at the end of this message.

**Minor Changes**

There are a number of minor changes around the pydantic v2 models API.

**Protected `model_` Namespace**

All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_". Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.

```py
class IPAdapterModelField(BaseModel):
    model_name: str = Field(description="Name of the IP-Adapter model")
    base_model: BaseModelType = Field(description="Base model")

    model_config = ConfigDict(protected_namespaces=())
```

**Model Serialization**

Pydantic models no longer have `Model.dict()` or `Model.json()`. Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.

**Model Deserialization**

Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions. Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.

```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```

**Field Customisation**

Pydantic `Field`s no longer accept arbitrary args. Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
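For instance, a UI-facing attribute now has to ride along in `json_schema_extra` rather than as a loose kwarg. A minimal sketch of the pattern - the `DenoiseSettings` model and `num_steps` field are made up for illustration:

```py
from pydantic import BaseModel, Field


class DenoiseSettings(BaseModel):
    # pydantic v2 rejects arbitrary kwargs like ui_hidden=False on Field();
    # anything non-standard must be packed into json_schema_extra instead
    num_steps: int = Field(
        default=20,
        ge=1,
        description="Number of denoising steps",
        json_schema_extra={"ui_hidden": False},
    )
```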
**Schema Customisation**

FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec. This necessitates two changes:

- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised

The specifics aren't important, but this does present additional surface area for bugs.

**Performance Improvements**

Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very, very often - a couple of times per node. I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
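To illustrate the `create_model()` technique described under **Invocation Decorators**, here is a minimal, self-contained sketch - the `WidgetBase` model and `"widget"` type are invented for illustration; the real logic lives in `baseinvocation.py`:

```py
from typing import Literal

from pydantic import BaseModel, create_model


class WidgetBase(BaseModel):
    name: str


# pydantic v2 doesn't let us mutate WidgetBase.model_fields in place, so we
# derive a new model that carries the extra literal "type" field instead,
# then use the derived class anywhere the original was expected.
Widget = create_model(
    "Widget",
    __base__=WidgetBase,
    type=(Literal["widget"], "widget"),
)

print(Widget(name="a").type)  # -> "widget"
```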
parent 19c5435332
commit c238a7f18b
```diff
@@ -42,7 +42,7 @@ async def upload_image(
     crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
 ) -> ImageDTO:
     """Uploads an image"""
-    if not file.content_type.startswith("image"):
+    if not file.content_type or not file.content_type.startswith("image"):
         raise HTTPException(status_code=415, detail="Not an image")

     contents = await file.read()
```
```diff
@@ -2,11 +2,11 @@


 import pathlib
-from typing import List, Literal, Optional, Union
+from typing import Annotated, List, Literal, Optional, Union

 from fastapi import Body, Path, Query, Response
 from fastapi.routing import APIRouter
-from pydantic import BaseModel, parse_obj_as
+from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
 from starlette.exceptions import HTTPException

 from invokeai.backend import BaseModelType, ModelType
@@ -23,8 +23,14 @@ from ..dependencies import ApiDependencies
 models_router = APIRouter(prefix="/v1/models", tags=["models"])

 UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
+update_models_response_adapter = TypeAdapter(UpdateModelResponse)
+
 ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
+import_models_response_adapter = TypeAdapter(ImportModelResponse)
+
 ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
+convert_models_response_adapter = TypeAdapter(ConvertModelResponse)
+
 MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]

@@ -32,6 +38,11 @@ ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 class ModelsList(BaseModel):
     models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
+
+    model_config = ConfigDict(use_enum_values=True)


+models_list_adapter = TypeAdapter(ModelsList)
+
+
 @models_router.get(
     "/",
@@ -49,7 +60,7 @@ async def list_models(
         models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
     else:
         models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
-    models = parse_obj_as(ModelsList, {"models": models_raw})
+    models = models_list_adapter.validate_python({"models": models_raw})
     return models


@@ -105,11 +116,14 @@ async def update_model(
             info.path = new_info.get("path")

         # replace empty string values with None/null to avoid phenomenon of vae: ''
-        info_dict = info.dict()
+        info_dict = info.model_dump()
         info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}

         ApiDependencies.invoker.services.model_manager.update_model(
-            model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+            model_attributes=info_dict,
         )

         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
@@ -117,7 +131,7 @@ async def update_model(
             base_model=base_model,
             model_type=model_type,
         )
-        model_response = parse_obj_as(UpdateModelResponse, model_raw)
+        model_response = update_models_response_adapter.validate_python(model_raw)
     except ModelNotFoundException as e:
         raise HTTPException(status_code=404, detail=str(e))
     except ValueError as e:
@@ -159,7 +173,8 @@ async def import_model(

     try:
         installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
-            items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
+            items_to_import=items_to_import,
+            prediction_type_helper=lambda x: prediction_types.get(prediction_type),
         )
         info = installed_models.get(location)

@@ -171,7 +186,7 @@ async def import_model(
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
             model_name=info.name, base_model=info.base_model, model_type=info.model_type
         )
-        return parse_obj_as(ImportModelResponse, model_raw)
+        return import_models_response_adapter.validate_python(model_raw)

     except ModelNotFoundException as e:
         logger.error(str(e))
@@ -205,13 +220,18 @@ async def add_model(

     try:
         ApiDependencies.invoker.services.model_manager.add_model(
-            info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
+            info.model_name,
+            info.base_model,
+            info.model_type,
+            model_attributes=info.model_dump(),
         )
         logger.info(f"Successfully added {info.model_name}")
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
-            model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
+            model_name=info.model_name,
+            base_model=info.base_model,
+            model_type=info.model_type,
         )
-        return parse_obj_as(ImportModelResponse, model_raw)
+        return import_models_response_adapter.validate_python(model_raw)
     except ModelNotFoundException as e:
         logger.error(str(e))
         raise HTTPException(status_code=404, detail=str(e))
@@ -223,7 +243,10 @@ async def add_model(
 @models_router.delete(
     "/{base_model}/{model_type}/{model_name}",
     operation_id="del_model",
-    responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
+    responses={
+        204: {"description": "Model deleted successfully"},
+        404: {"description": "Model not found"},
+    },
     status_code=204,
     response_model=None,
 )
@@ -279,7 +302,7 @@ async def convert_model(
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
             model_name, base_model=base_model, model_type=model_type
         )
-        response = parse_obj_as(ConvertModelResponse, model_raw)
+        response = convert_models_response_adapter.validate_python(model_raw)
     except ModelNotFoundException as e:
         raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
     except ValueError as e:
@@ -302,7 +325,8 @@ async def search_for_models(
 ) -> List[pathlib.Path]:
     if not search_path.is_dir():
         raise HTTPException(
-            status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
+            status_code=404,
+            detail=f"The search path '{search_path}' does not exist or is not directory",
         )
     return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)

@@ -337,6 +361,26 @@ async def sync_to_config() -> bool:
     return True


+# There's some weird pydantic-fastapi behaviour that requires this to be a separate class
+# TODO: After a few updates, see if it works inside the route operation handler?
+class MergeModelsBody(BaseModel):
+    model_names: List[str] = Field(description="model name", min_length=2, max_length=3)
+    merged_model_name: Optional[str] = Field(description="Name of destination model")
+    alpha: Optional[float] = Field(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5)
+    interp: Optional[MergeInterpolationMethod] = Field(description="Interpolation method")
+    force: Optional[bool] = Field(
+        description="Force merging of models created with different versions of diffusers",
+        default=False,
+    )
+
+    merge_dest_directory: Optional[str] = Field(
+        description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
+        default=None,
+    )
+
+    model_config = ConfigDict(protected_namespaces=())
+
+
 @models_router.put(
     "/merge/{base_model}",
     operation_id="merge_models",
@@ -349,31 +393,23 @@ async def sync_to_config() -> bool:
     response_model=MergeModelResponse,
 )
 async def merge_models(
+    body: Annotated[MergeModelsBody, Body(description="Model configuration", embed=True)],
     base_model: BaseModelType = Path(description="Base model"),
-    model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
-    merged_model_name: Optional[str] = Body(description="Name of destination model"),
-    alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
-    interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
-    force: Optional[bool] = Body(
-        description="Force merging of models created with different versions of diffusers", default=False
-    ),
-    merge_dest_directory: Optional[str] = Body(
-        description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
-        default=None,
-    ),
 ) -> MergeModelResponse:
     """Convert a checkpoint model into a diffusers model"""
     logger = ApiDependencies.invoker.services.logger
     try:
-        logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
-        dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
+        logger.info(
+            f"Merging models: {body.model_names} into {body.merge_dest_directory or '<MODELS>'}/{body.merged_model_name}"
+        )
+        dest = pathlib.Path(body.merge_dest_directory) if body.merge_dest_directory else None
         result = ApiDependencies.invoker.services.model_manager.merge_models(
-            model_names,
-            base_model,
-            merged_model_name=merged_model_name or "+".join(model_names),
-            alpha=alpha,
-            interp=interp,
-            force=force,
+            model_names=body.model_names,
+            base_model=base_model,
+            merged_model_name=body.merged_model_name or "+".join(body.model_names),
+            alpha=body.alpha,
+            interp=body.interp,
+            force=body.force,
             merge_dest_directory=dest,
         )
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
@@ -381,9 +417,12 @@ async def merge_models(
             base_model=base_model,
             model_type=ModelType.Main,
         )
-        response = parse_obj_as(ConvertModelResponse, model_raw)
+        response = convert_models_response_adapter.validate_python(model_raw)
     except ModelNotFoundException:
-        raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
+        raise HTTPException(
+            status_code=404,
+            detail=f"One or more of the models '{body.model_names}' not found",
+        )
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
     return response
```
```diff
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union

 from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
 from fastapi import Body
@@ -27,6 +27,7 @@ async def parse_dynamicprompts(
     combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
 ) -> DynamicPromptsResponse:
     """Creates a batch process"""
+    generator: Union[RandomPromptGenerator, CombinatorialPromptGenerator]
     try:
         error: Optional[str] = None
         if combinatorial:
```
```diff
@@ -22,7 +22,7 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
     from fastapi.staticfiles import StaticFiles
     from fastapi_events.handlers.local import local_handler
     from fastapi_events.middleware import EventHandlerASGIMiddleware
-    from pydantic.schema import schema
+    from pydantic.json_schema import models_json_schema

     # noinspection PyUnresolvedReferences
     import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)
@@ -31,7 +31,7 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c

     from ..backend.util.logging import InvokeAILogger
     from .api.dependencies import ApiDependencies
-    from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
+    from .api.routers import app_info, board_images, boards, images, models, session_queue, utilities
     from .api.sockets import SocketIO
     from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField

@@ -51,7 +51,7 @@ mimetypes.add_type("text/css", ".css")

 # Create the app
 # TODO: create this all in a method so configuration/etc. can be passed in?
-app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
+app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None, separate_input_output_schemas=False)

 # Add event handler
 event_handler_id: int = id(app)
@@ -63,18 +63,18 @@ app.add_middleware(

 socket_io = SocketIO(app)

-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=app_config.allow_origins,
-    allow_credentials=app_config.allow_credentials,
-    allow_methods=app_config.allow_methods,
-    allow_headers=app_config.allow_headers,
-)
-

 # Add startup event to load dependencies
 @app.on_event("startup")
 async def startup_event():
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=app_config.allow_origins,
+        allow_credentials=app_config.allow_credentials,
+        allow_methods=app_config.allow_methods,
+        allow_headers=app_config.allow_headers,
+    )
+
     ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)


@@ -85,12 +85,7 @@ async def shutdown_event():


 # Include all routers
-# TODO: REMOVE
-# app.include_router(
-#     invocation.invocation_router,
-#     prefix = '/api')
-
-app.include_router(sessions.session_router, prefix="/api")
+# app.include_router(sessions.session_router, prefix="/api")

 app.include_router(utilities.utilities_router, prefix="/api")

@@ -117,6 +112,7 @@ def custom_openapi():
         description="An API for invoking AI image operations",
         version="1.0.0",
         routes=app.routes,
+        separate_input_output_schemas=False,  # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
     )

     # Add all outputs
@@ -127,29 +123,32 @@ def custom_openapi():
         output_type = signature(invoker.invoke).return_annotation
         output_types.add(output_type)

-    output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
-    for schema_key, output_schema in output_schemas["definitions"].items():
-        output_schema["class"] = "output"
-        openapi_schema["components"]["schemas"][schema_key] = output_schema
-
+    output_schemas = models_json_schema(
+        models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
+    )
+    for schema_key, output_schema in output_schemas[1]["$defs"].items():
         # TODO: note that we assume the schema_key here is the TYPE.__name__
         # This could break in some cases, figure out a better way to do it
         output_type_titles[schema_key] = output_schema["title"]

     # Add Node Editor UI helper schemas
-    ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
-    for schema_key, ui_config_schema in ui_config_schemas["definitions"].items():
+    ui_config_schemas = models_json_schema(
+        [(UIConfigBase, "serialization"), (_InputField, "serialization"), (_OutputField, "serialization")],
+        ref_template="#/components/schemas/{model}",
+    )
+    for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
         openapi_schema["components"]["schemas"][schema_key] = ui_config_schema

     # Add a reference to the output type to additionalProperties of the invoker schema
     for invoker in all_invocations:
         invoker_name = invoker.__name__
-        output_type = signature(invoker.invoke).return_annotation
+        output_type = signature(obj=invoker.invoke).return_annotation
         output_type_title = output_type_titles[output_type.__name__]
-        invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
+        invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
         outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
         invoker_schema["output"] = outputs_ref
         invoker_schema["class"] = "invocation"
+        openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"

     from invokeai.backend.model_management.models import get_model_config_enums

@@ -172,7 +171,7 @@ def custom_openapi():
     return app.openapi_schema


-app.openapi = custom_openapi
+app.openapi = custom_openapi  # type: ignore [method-assign] # this is a valid assignment

 # Override API doc favicons
 app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static")
```
```diff
@@ -24,8 +24,8 @@ def add_field_argument(command_parser, name: str, field, default_override=None):
         if field.default_factory is None
         else field.default_factory()
     )
-    if get_origin(field.type_) == Literal:
-        allowed_values = get_args(field.type_)
+    if get_origin(field.annotation) == Literal:
+        allowed_values = get_args(field.annotation)
         allowed_types = set()
         for val in allowed_values:
             allowed_types.add(type(val))
@@ -38,15 +38,15 @@ def add_field_argument(command_parser, name: str, field, default_override=None):
             type=field_type,
             default=default,
             choices=allowed_values,
-            help=field.field_info.description,
+            help=field.description,
         )
     else:
         command_parser.add_argument(
             f"--{name}",
             dest=name,
-            type=field.type_,
+            type=field.annotation,
             default=default,
-            help=field.field_info.description,
+            help=field.description,
         )


@@ -142,7 +142,6 @@ class BaseCommand(ABC, BaseModel):
     """A CLI command"""

     # All commands must include a type name like this:
     # type: Literal['your_command_name'] = 'your_command_name'
-
     @classmethod
     def get_all_subclasses(cls):
```
```diff
@@ -7,28 +7,16 @@ import re
 from abc import ABC, abstractmethod
 from enum import Enum
 from inspect import signature
-from typing import (
-    TYPE_CHECKING,
-    AbstractSet,
-    Any,
-    Callable,
-    ClassVar,
-    Literal,
-    Mapping,
-    Optional,
-    Type,
-    TypeVar,
-    Union,
-    get_args,
-    get_type_hints,
-)
+from types import UnionType
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union

 import semver
-from pydantic import BaseModel, Field, validator
-from pydantic.fields import ModelField, Undefined
-from pydantic.typing import NoArgAnyCallable
+from pydantic import BaseModel, ConfigDict, Field, create_model, field_validator
+from pydantic.fields import _Unset
+from pydantic_core import PydanticUndefined

 from invokeai.app.services.config.config_default import InvokeAIAppConfig
 from invokeai.app.util.misc import uuid_string

 if TYPE_CHECKING:
     from ..services.invocation_services import InvocationServices
```
```diff
@@ -211,6 +199,11 @@ class _InputField(BaseModel):
     ui_choice_labels: Optional[dict[str, str]]
     item_default: Optional[Any]

+    model_config = ConfigDict(
+        validate_assignment=True,
+        json_schema_serialization_defaults_required=True,
+    )
+

 class _OutputField(BaseModel):
     """
@@ -224,34 +217,36 @@ class _OutputField(BaseModel):
     ui_type: Optional[UIType]
     ui_order: Optional[int]

+    model_config = ConfigDict(
+        validate_assignment=True,
+        json_schema_serialization_defaults_required=True,
+    )
+
+
+def get_type(klass: BaseModel) -> str:
+    """Helper function to get an invocation or invocation output's type. This is the default value of the `type` field."""
+    return klass.model_fields["type"].default
+

 def InputField(
-    *args: Any,
-    default: Any = Undefined,
-    default_factory: Optional[NoArgAnyCallable] = None,
-    alias: Optional[str] = None,
-    title: Optional[str] = None,
-    description: Optional[str] = None,
-    exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
-    include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
-    const: Optional[bool] = None,
-    gt: Optional[float] = None,
-    ge: Optional[float] = None,
-    lt: Optional[float] = None,
-    le: Optional[float] = None,
-    multiple_of: Optional[float] = None,
-    allow_inf_nan: Optional[bool] = None,
-    max_digits: Optional[int] = None,
-    decimal_places: Optional[int] = None,
-    min_items: Optional[int] = None,
-    max_items: Optional[int] = None,
-    unique_items: Optional[bool] = None,
-    min_length: Optional[int] = None,
-    max_length: Optional[int] = None,
-    allow_mutation: bool = True,
-    regex: Optional[str] = None,
-    discriminator: Optional[str] = None,
-    repr: bool = True,
+    # copied from pydantic's Field
+    default: Any = _Unset,
+    default_factory: Callable[[], Any] | None = _Unset,
+    title: str | None = _Unset,
+    description: str | None = _Unset,
+    pattern: str | None = _Unset,
+    strict: bool | None = _Unset,
+    gt: float | None = _Unset,
+    ge: float | None = _Unset,
+    lt: float | None = _Unset,
+    le: float | None = _Unset,
+    multiple_of: float | None = _Unset,
+    allow_inf_nan: bool | None = _Unset,
+    max_digits: int | None = _Unset,
+    decimal_places: int | None = _Unset,
+    min_length: int | None = _Unset,
+    max_length: int | None = _Unset,
+    # custom
     input: Input = Input.Any,
     ui_type: Optional[UIType] = None,
     ui_component: Optional[UIComponent] = None,
@@ -259,7 +254,6 @@ def InputField(
     ui_order: Optional[int] = None,
     ui_choice_labels: Optional[dict[str, str]] = None,
     item_default: Optional[Any] = None,
-    **kwargs: Any,
 ) -> Any:
     """
     Creates an input field for an invocation.
@@ -289,18 +283,26 @@ def InputField(
     : param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \

     : param bool item_default: [None] Specifies the default item value, if this is a collection input. \
-      Ignored for non-collection fields..
+      Ignored for non-collection fields.
     """
-    return Field(
-        *args,
+
+    json_schema_extra_: dict[str, Any] = dict(
+        input=input,
+        ui_type=ui_type,
+        ui_component=ui_component,
+        ui_hidden=ui_hidden,
+        ui_order=ui_order,
+        item_default=item_default,
+        ui_choice_labels=ui_choice_labels,
+    )
+
+    field_args = dict(
         default=default,
         default_factory=default_factory,
-        alias=alias,
         title=title,
         description=description,
-        exclude=exclude,
-        include=include,
-        const=const,
+        pattern=pattern,
+        strict=strict,
         gt=gt,
         ge=ge,
         lt=lt,
```
```diff
@@ -309,57 +311,92 @@ def InputField(
         allow_inf_nan=allow_inf_nan,
         max_digits=max_digits,
         decimal_places=decimal_places,
-        min_items=min_items,
-        max_items=max_items,
-        unique_items=unique_items,
         min_length=min_length,
         max_length=max_length,
-        allow_mutation=allow_mutation,
-        regex=regex,
-        discriminator=discriminator,
-        repr=repr,
-        input=input,
-        ui_type=ui_type,
-        ui_component=ui_component,
-        ui_hidden=ui_hidden,
-        ui_order=ui_order,
-        item_default=item_default,
-        ui_choice_labels=ui_choice_labels,
-        **kwargs,
     )

+    """
+    Invocation definitions have their fields typed correctly for their `invoke()` functions.
+    This typing is often more specific than the actual invocation definition requires, because
+    fields may have values provided only by connections.
+
+    For example, consider an ResizeImageInvocation with an `image: ImageField` field.
+
+    `image` is required during the call to `invoke()`, but when the python class is instantiated,
+    the field may not be present. This is fine, because that image field will be provided by a
+    an ancestor node that outputs the image.
+
+    So we'd like to type that `image` field as `Optional[ImageField]`. If we do that, however, then
+    we need to handle a lot of extra logic in the `invoke()` function to check if the field has a
+    value or not. This is very tedious.
+
+    Ideally, the invocation definition would be able to specify that the field is required during
+    invocation, but optional during instantiation. So the field would be typed as `image: ImageField`,
+    but when calling the `invoke()` function, we raise an error if the field is not present.
+
+    To do this, we need to do a bit of fanagling to make the pydantic field optional, and then do
+    extra validation when calling `invoke()`.
+
+    There is some additional logic here to cleaning create the pydantic field via the wrapper.
+    """
+
+    # Filter out field args not provided
+    provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}
+
+    if (default is not PydanticUndefined) and (default_factory is not PydanticUndefined):
+        raise ValueError("Cannot specify both default and default_factory")
+
+    # because we are manually making fields optional, we need to store the original required bool for reference later
+    if default is PydanticUndefined and default_factory is PydanticUndefined:
+        json_schema_extra_.update(dict(orig_required=True))
+    else:
+        json_schema_extra_.update(dict(orig_required=False))
+
+    # make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
+    if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
+        default_ = None if default is PydanticUndefined else default
+        provided_args.update(dict(default=default_))
+        if default is not PydanticUndefined:
+            # before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
+            json_schema_extra_.update(dict(default=default))
+            json_schema_extra_.update(dict(orig_default=default))
+    elif default is not PydanticUndefined and default_factory is PydanticUndefined:
+        default_ = default
+        provided_args.update(dict(default=default_))
+        json_schema_extra_.update(dict(orig_default=default_))
+    elif default_factory is not PydanticUndefined:
+        provided_args.update(dict(default_factory=default_factory))
+        # TODO: cannot serialize default_factory...
+        # json_schema_extra_.update(dict(orig_default_factory=default_factory))
+
+    return Field(
+        **provided_args,
+        json_schema_extra=json_schema_extra_,
+    )
+

 def OutputField(
-    *args: Any,
-    default: Any = Undefined,
-    default_factory: Optional[NoArgAnyCallable] = None,
-    alias: Optional[str] = None,
-    title: Optional[str] = None,
-    description: Optional[str] = None,
-    exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
-    include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
-    const: Optional[bool] = None,
-    gt: Optional[float] = None,
-    ge: Optional[float] = None,
-    lt: Optional[float] = None,
-    le: Optional[float] = None,
-    multiple_of: Optional[float] = None,
-    allow_inf_nan: Optional[bool] = None,
-    max_digits: Optional[int] = None,
-    decimal_places: Optional[int] = None,
-    min_items: Optional[int] = None,
-    max_items: Optional[int] = None,
-    unique_items: Optional[bool] = None,
-    min_length: Optional[int] = None,
-    max_length: Optional[int] = None,
-    allow_mutation: bool = True,
-    regex: Optional[str] = None,
-    discriminator: Optional[str] = None,
-    repr: bool = True,
+    # copied from pydantic's Field
+    default: Any = _Unset,
+    default_factory: Callable[[], Any] | None = _Unset,
+    title: str | None = _Unset,
+    description: str | None = _Unset,
+    pattern: str | None = _Unset,
+    strict: bool | None = _Unset,
+    gt: float | None = _Unset,
+    ge: float | None = _Unset,
+    lt: float | None = _Unset,
+    le: float | None = _Unset,
+    multiple_of: float | None = _Unset,
+    allow_inf_nan: bool | None = _Unset,
+    max_digits: int | None = _Unset,
+    decimal_places: int | None = _Unset,
+    min_length: int | None = _Unset,
+    max_length: int | None = _Unset,
+    # custom
     ui_type: Optional[UIType] = None,
     ui_hidden: bool = False,
     ui_order: Optional[int] = None,
-    **kwargs: Any,
 ) -> Any:
     """
     Creates an output field for an invocation output.
@@ -379,15 +416,12 @@ def OutputField(
     : param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
     """
     return Field(
-        *args,
         default=default,
         default_factory=default_factory,
-        alias=alias,
         title=title,
         description=description,
-        exclude=exclude,
-        include=include,
-        const=const,
+        pattern=pattern,
+        strict=strict,
         gt=gt,
         ge=ge,
         lt=lt,
@@ -396,19 +430,13 @@ def OutputField(
         allow_inf_nan=allow_inf_nan,
         max_digits=max_digits,
         decimal_places=decimal_places,
-        min_items=min_items,
-        max_items=max_items,
-        unique_items=unique_items,
         min_length=min_length,
         max_length=max_length,
-        allow_mutation=allow_mutation,
-        regex=regex,
-        discriminator=discriminator,
-        repr=repr,
-        ui_type=ui_type,
-        ui_hidden=ui_hidden,
-        ui_order=ui_order,
-        **kwargs,
+        json_schema_extra=dict(
+            ui_type=ui_type,
+            ui_hidden=ui_hidden,
+            ui_order=ui_order,
+        ),
     )

```
```diff
@@ -422,7 +450,13 @@ class UIConfigBase(BaseModel):
     title: Optional[str] = Field(default=None, description="The node's display name")
     category: Optional[str] = Field(default=None, description="The node's category")
     version: Optional[str] = Field(
-        default=None, description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".'
+        default=None,
+        description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
     )

+    model_config = ConfigDict(
+        validate_assignment=True,
+        json_schema_serialization_defaults_required=True,
+    )
+

@@ -457,23 +491,38 @@ class BaseInvocationOutput(BaseModel):
     All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
     """

-    @classmethod
-    def get_all_subclasses_tuple(cls):
-        subclasses = []
-        toprocess = [cls]
-        while len(toprocess) > 0:
-            next = toprocess.pop(0)
-            next_subclasses = next.__subclasses__()
-            subclasses.extend(next_subclasses)
-            toprocess.extend(next_subclasses)
-        return tuple(subclasses)
+    _output_classes: ClassVar[set[BaseInvocationOutput]] = set()

-    class Config:
-        @staticmethod
-        def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
-            if "required" not in schema or not isinstance(schema["required"], list):
-                schema["required"] = list()
-            schema["required"].extend(["type"])
+    @classmethod
+    def register_output(cls, output: BaseInvocationOutput) -> None:
+        cls._output_classes.add(output)
+
+    @classmethod
+    def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
+        return cls._output_classes
+
+    @classmethod
+    def get_outputs_union(cls) -> UnionType:
+        outputs_union = Union[tuple(cls._output_classes)]  # type: ignore [valid-type]
+        return outputs_union  # type: ignore [return-value]
+
+    @classmethod
+    def get_output_types(cls) -> Iterable[str]:
+        return map(lambda i: get_type(i), BaseInvocationOutput.get_outputs())
+
+    @staticmethod
+    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
+        # Because we use a pydantic Literal field with default value for the invocation type,
+        # it will be typed as optional in the OpenAPI schema. Make it required manually.
+        if "required" not in schema or not isinstance(schema["required"], list):
+            schema["required"] = list()
+        schema["required"].extend(["type"])
+
+    model_config = ConfigDict(
+        validate_assignment=True,
+        json_schema_serialization_defaults_required=True,
+        json_schema_extra=json_schema_extra,
+    )


 class RequiredConnectionException(Exception):
```
```diff
@@ -498,104 +547,91 @@ class BaseInvocation(ABC, BaseModel):
     All invocations must use the `@invocation` decorator to provide their unique type.
     """

+    _invocation_classes: ClassVar[set[BaseInvocation]] = set()
+
     @classmethod
-    def get_all_subclasses(cls):
+    def register_invocation(cls, invocation: BaseInvocation) -> None:
+        cls._invocation_classes.add(invocation)
+
+    @classmethod
+    def get_invocations_union(cls) -> UnionType:
+        invocations_union = Union[tuple(cls._invocation_classes)]  # type: ignore [valid-type]
+        return invocations_union  # type: ignore [return-value]
+
+    @classmethod
+    def get_invocations(cls) -> Iterable[BaseInvocation]:
         app_config = InvokeAIAppConfig.get_config()
-        subclasses = []
-        toprocess = [cls]
-        while len(toprocess) > 0:
-            next = toprocess.pop(0)
-            next_subclasses = next.__subclasses__()
-            subclasses.extend(next_subclasses)
-            toprocess.extend(next_subclasses)
-        allowed_invocations = []
-        for sc in subclasses:
+        allowed_invocations: set[BaseInvocation] = set()
+        for sc in cls._invocation_classes:
+            invocation_type = get_type(sc)
             is_in_allowlist = (
-                sc.__fields__.get("type").default in app_config.allow_nodes
-                if isinstance(app_config.allow_nodes, list)
-                else True
+                invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
             )
-
             is_in_denylist = (
-                sc.__fields__.get("type").default in app_config.deny_nodes
-                if isinstance(app_config.deny_nodes, list)
-                else False
+                invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
             )
-
             if is_in_allowlist and not is_in_denylist:
-                allowed_invocations.append(sc)
+                allowed_invocations.add(sc)
         return allowed_invocations

     @classmethod
-    def get_invocations(cls):
-        return tuple(BaseInvocation.get_all_subclasses())
-
-    @classmethod
-    def get_invocations_map(cls):
+    def get_invocations_map(cls) -> dict[str, BaseInvocation]:
         # Get the type strings out of the literals and into a dictionary
         return dict(
             map(
-                lambda t: (get_args(get_type_hints(t)["type"])[0], t),
-                BaseInvocation.get_all_subclasses(),
+                lambda i: (get_type(i), i),
+                BaseInvocation.get_invocations(),
             )
         )

     @classmethod
-    def get_output_type(cls):
+    def get_invocation_types(cls) -> Iterable[str]:
+        return map(lambda i: get_type(i), BaseInvocation.get_invocations())
+
+    @classmethod
+    def get_output_type(cls) -> BaseInvocationOutput:
         return signature(cls.invoke).return_annotation

-    class Config:
-        validate_assignment = True
-        validate_all = True
-
-        @staticmethod
-        def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
-            uiconfig = getattr(model_class, "UIConfig", None)
-            if uiconfig and hasattr(uiconfig, "title"):
-                schema["title"] = uiconfig.title
-            if uiconfig and hasattr(uiconfig, "tags"):
-                schema["tags"] = uiconfig.tags
-            if uiconfig and hasattr(uiconfig, "category"):
-                schema["category"] = uiconfig.category
-            if uiconfig and hasattr(uiconfig, "version"):
-                schema["version"] = uiconfig.version
-            if "required" not in schema or not isinstance(schema["required"], list):
-                schema["required"] = list()
-            schema["required"].extend(["type", "id"])
+    @staticmethod
+    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
+        # Add the various UI-facing attributes to the schema. These are used to build the invocation templates.
+        uiconfig = getattr(model_class, "UIConfig", None)
+        if uiconfig and hasattr(uiconfig, "title"):
+            schema["title"] = uiconfig.title
+        if uiconfig and hasattr(uiconfig, "tags"):
+            schema["tags"] = uiconfig.tags
+        if uiconfig and hasattr(uiconfig, "category"):
+            schema["category"] = uiconfig.category
+        if uiconfig and hasattr(uiconfig, "version"):
+            schema["version"] = uiconfig.version
+        if "required" not in schema or not isinstance(schema["required"], list):
+            schema["required"] = list()
+        schema["required"].extend(["type", "id"])

     @abstractmethod
     def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
         """Invoke with provided context and return outputs."""
         pass

-    def __init__(self, **data):
-        # nodes may have required fields, that can accept input from connections
-        # on instantiation of the model, we need to exclude these from validation
-        restore = dict()
-        try:
-            field_names = list(self.__fields__.keys())
-            for field_name in field_names:
-                # if the field is required and may get its value from a connection, exclude it from validation
-                field = self.__fields__[field_name]
-                _input = field.field_info.extra.get("input", None)
-                if _input in [Input.Connection, Input.Any] and field.required:
-                    if field_name not in data:
-                        restore[field_name] = self.__fields__.pop(field_name)
-            # instantiate the node, which will validate the data
-            super().__init__(**data)
-        finally:
-            # restore the removed fields
-            for field_name, field in restore.items():
-                self.__fields__[field_name] = field
-
     def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
-        for field_name, field in self.__fields__.items():
-            _input = field.field_info.extra.get("input", None)
-            if field.required and not hasattr(self, field_name):
-                if _input == Input.Connection:
-                    raise RequiredConnectionException(self.__fields__["type"].default, field_name)
-                elif _input == Input.Any:
-                    raise MissingInputException(self.__fields__["type"].default, field_name)
+        for field_name, field in self.model_fields.items():
+            if not field.json_schema_extra or callable(field.json_schema_extra):
+                # something has gone terribly awry, we should always have this and it should be a dict
+                continue
+
+            # Here we handle the case where the field is optional in the pydantic class, but required
+            # in the `invoke()` method.
+
+            orig_default = field.json_schema_extra.get("orig_default", PydanticUndefined)
+            orig_required = field.json_schema_extra.get("orig_required", True)
+            input_ = field.json_schema_extra.get("input", None)
+            if orig_default is not PydanticUndefined and not hasattr(self, field_name):
+                setattr(self, field_name, orig_default)
+            if orig_required and orig_default is PydanticUndefined and getattr(self, field_name) is None:
+                if input_ == Input.Connection:
+                    raise RequiredConnectionException(self.model_fields["type"].default, field_name)
+                elif input_ == Input.Any:
+                    raise MissingInputException(self.model_fields["type"].default, field_name)

         # skip node cache codepath if it's disabled
         if context.services.configuration.node_cache_size == 0:
@@ -618,23 +654,31 @@ class BaseInvocation(ABC, BaseModel):
         return self.invoke(context)

     def get_type(self) -> str:
-        return self.__fields__["type"].default
+        return self.model_fields["type"].default

     id: str = Field(
-        description="The id of this instance of an invocation. Must be unique among all instances of invocations."
+        default_factory=uuid_string,
+        description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
     )
-    is_intermediate: bool = InputField(
-        default=False, description="Whether or not this is an intermediate invocation.", ui_type=UIType.IsIntermediate
+    is_intermediate: Optional[bool] = Field(
+        default=False,
+        description="Whether or not this is an intermediate invocation.",
+        json_schema_extra=dict(ui_type=UIType.IsIntermediate),
     )
-    workflow: Optional[str] = InputField(
+    workflow: Optional[str] = Field(
         default=None,
         description="The workflow to save with the image",
-        ui_type=UIType.WorkflowField,
+        json_schema_extra=dict(ui_type=UIType.WorkflowField),
     )
-    use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
+    use_cache: Optional[bool] = Field(
+        default=True,
+        description="Whether or not to use the cache",
+    )

-    @validator("workflow", pre=True)
+    @field_validator("workflow", mode="before")
+    @classmethod
     def validate_workflow_is_json(cls, v):
+        """We don't have a workflow schema in the backend, so we just check that it's valid JSON"""
         if v is None:
             return None
         try:
```
```diff
@@ -645,8 +689,14 @@ class BaseInvocation(ABC, BaseModel):

     UIConfig: ClassVar[Type[UIConfigBase]]

+    model_config = ConfigDict(
+        validate_assignment=True,
+        json_schema_extra=json_schema_extra,
+        json_schema_serialization_defaults_required=True,
+    )
+

-GenericBaseInvocation = TypeVar("GenericBaseInvocation", bound=BaseInvocation)
+TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)


 def invocation(
@@ -656,7 +706,7 @@ def invocation(
     category: Optional[str] = None,
     version: Optional[str] = None,
     use_cache: Optional[bool] = True,
-) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
+) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
     """
     Adds metadata to an invocation.

@@ -668,12 +718,15 @@ def invocation(
     :param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
     """

-    def wrapper(cls: Type[GenericBaseInvocation]) -> Type[GenericBaseInvocation]:
+    def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
         # Validate invocation types on creation of invocation classes
         # TODO: ensure unique?
         if re.compile(r"^\S+$").match(invocation_type) is None:
             raise ValueError(f'"invocation_type" must consist of non-whitespace characters, got "{invocation_type}"')

+        if invocation_type in BaseInvocation.get_invocation_types():
+            raise ValueError(f'Invocation type "{invocation_type}" already exists')
+
         # Add OpenAPI schema extras
         uiconf_name = cls.__qualname__ + ".UIConfig"
         if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
@@ -691,59 +744,83 @@ def invocation(
                 raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
             cls.UIConfig.version = version
         if use_cache is not None:
-            cls.__fields__["use_cache"].default = use_cache
+            cls.model_fields["use_cache"].default = use_cache
+
+        # Add the invocation type to the model.
+
+        # You'd be tempted to just add the type field and rebuild the model, like this:
+        # cls.model_fields.update(type=FieldInfo.from_annotated_attribute(Literal[invocation_type], invocation_type))
+        # cls.model_rebuild() or cls.model_rebuild(force=True)
+
+        # Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does
+        # not work. Instead, we have to create a new class with the type field and patch the original class with it.

-        # Add the invocation type to the pydantic model of the invocation
         invocation_type_annotation = Literal[invocation_type]  # type: ignore
-        invocation_type_field = ModelField.infer(
-            name="type",
-            value=invocation_type,
-            annotation=invocation_type_annotation,
-            class_validators=None,
-            config=cls.__config__,
+        invocation_type_field = Field(
+            title="type",
+            default=invocation_type,
         )
-        cls.__fields__.update({"type": invocation_type_field})
-        # to support 3.9, 3.10 and 3.11, as described in https://docs.python.org/3/howto/annotations.html
-        if annotations := cls.__dict__.get("__annotations__", None):
-            annotations.update({"type": invocation_type_annotation})

+        docstring = cls.__doc__
+        cls = create_model(
+            cls.__qualname__,
+            __base__=cls,
+            __module__=cls.__module__,
+            type=(invocation_type_annotation, invocation_type_field),
+        )
+        cls.__doc__ = docstring
+
+        # TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic
+        BaseInvocation.register_invocation(cls)  # type: ignore
+
         return cls

     return wrapper


-GenericBaseInvocationOutput = TypeVar("GenericBaseInvocationOutput", bound=BaseInvocationOutput)
+TBaseInvocationOutput = TypeVar("TBaseInvocationOutput", bound=BaseInvocationOutput)


 def invocation_output(
     output_type: str,
-) -> Callable[[Type[GenericBaseInvocationOutput]], Type[GenericBaseInvocationOutput]]:
+) -> Callable[[Type[TBaseInvocationOutput]], Type[TBaseInvocationOutput]]:
     """
     Adds metadata to an invocation output.

     :param str output_type: The type of the invocation output. Must be unique among all invocation outputs.
     """

-    def wrapper(cls: Type[GenericBaseInvocationOutput]) -> Type[GenericBaseInvocationOutput]:
+    def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]:
         # Validate output types on creation of invocation output classes
         # TODO: ensure unique?
         if re.compile(r"^\S+$").match(output_type) is None:
             raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')

-        # Add the output type to the pydantic model of the invocation output
-        output_type_annotation = Literal[output_type]  # type: ignore
-        output_type_field = ModelField.infer(
-            name="type",
-            value=output_type,
-            annotation=output_type_annotation,
-            class_validators=None,
-            config=cls.__config__,
-        )
-        cls.__fields__.update({"type": output_type_field})
-        # to support 3.9, 3.10 and 3.11, as described in https://docs.python.org/3/howto/annotations.html
-        if annotations := cls.__dict__.get("__annotations__", None):
-            annotations.update({"type": output_type_annotation})
+        if output_type in BaseInvocationOutput.get_output_types():
+            raise ValueError(f'Invocation type "{output_type}" already exists')
+
+        # Add the output type to the model.
+
+        output_type_annotation = Literal[output_type]  # type: ignore
+        output_type_field = Field(
+            title="type",
+            default=output_type,
+        )
+
+        docstring = cls.__doc__
+        cls = create_model(
+            cls.__qualname__,
+            __base__=cls,
+            __module__=cls.__module__,
+            type=(output_type_annotation, output_type_field),
+        )
+        cls.__doc__ = docstring
+
+        BaseInvocationOutput.register_output(cls)  # type: ignore # TODO: how to type this correctly?

         return cls

     return wrapper


 GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
```
```diff
@@ -2,7 +2,7 @@


 import numpy as np
-from pydantic import validator
+from pydantic import ValidationInfo, field_validator

 from invokeai.app.invocations.primitives import IntegerCollectionOutput
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
@@ -20,9 +20,9 @@ class RangeInvocation(BaseInvocation):
     stop: int = InputField(default=10, description="The stop of the range")
     step: int = InputField(default=1, description="The step of the range")

-    @validator("stop")
-    def stop_gt_start(cls, v, values):
-        if "start" in values and v <= values["start"]:
+    @field_validator("stop")
+    def stop_gt_start(cls, v: int, info: ValidationInfo):
+        if "start" in info.data and v <= info.data["start"]:
             raise ValueError("stop must be greater than start")
         return v

```
@ -1,6 +1,6 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
@ -43,7 +43,13 @@ class ConditioningFieldData:
|
||||
# PerpNeg = "perp_neg"
|
||||
|
||||
|
||||
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning", version="1.0.0")
|
||||
@invocation(
|
||||
"compel",
|
||||
title="Prompt",
|
||||
tags=["prompt", "compel"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
@ -61,17 +67,19 @@ class CompelInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
**self.clip.tokenizer.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**self.clip.text_encoder.dict(),
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.model_dump(exclude={"weight"}), context=context
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return

@@ -160,11 +168,11 @@ class SDXLPromptInvocationBase:
zero_on_empty: bool,
):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
**clip_field.tokenizer.model_dump(),
context=context,
)
text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(),
**clip_field.text_encoder.model_dump(),
context=context,
)

@@ -172,7 +180,11 @@ class SDXLPromptInvocationBase:
if prompt == "" and zero_on_empty:
cpu_text_encoder = text_encoder_info.context.model
c = torch.zeros(
(1, cpu_text_encoder.config.max_position_embeddings, cpu_text_encoder.config.hidden_size),
(
1,
cpu_text_encoder.config.max_position_embeddings,
cpu_text_encoder.config.hidden_size,
),
dtype=text_encoder_info.context.cache.precision,
)
if get_pooled:

@@ -186,7 +198,9 @@ class SDXLPromptInvocationBase:

def _lora_loader():
for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
lora_info = context.services.model_manager.get_model(
**lora.model_dump(exclude={"weight"}), context=context
)
yield (lora_info.context.model, lora.weight)
del lora_info
return

@@ -273,8 +287,16 @@ class SDXLPromptInvocationBase:
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""

prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
prompt: str = InputField(
default="",
description=FieldDescriptions.compel_prompt,
ui_component=UIComponent.Textarea,
)
style: str = InputField(
default="",
description=FieldDescriptions.compel_prompt,
ui_component=UIComponent.Textarea,
)
original_width: int = InputField(default=1024, description="")
original_height: int = InputField(default=1024, description="")
crop_top: int = InputField(default=0, description="")

@@ -310,7 +332,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
[
c1,
torch.zeros(
(c1.shape[0], c2.shape[1] - c1.shape[1], c1.shape[2]), device=c1.device, dtype=c1.dtype
(c1.shape[0], c2.shape[1] - c1.shape[1], c1.shape[2]),
device=c1.device,
dtype=c1.dtype,
),
],
dim=1,

@@ -321,7 +345,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
[
c2,
torch.zeros(
(c2.shape[0], c1.shape[1] - c2.shape[1], c2.shape[2]), device=c2.device, dtype=c2.dtype
(c2.shape[0], c1.shape[1] - c2.shape[1], c2.shape[2]),
device=c2.device,
dtype=c2.dtype,
),
],
dim=1,

@@ -359,7 +385,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
"""Parse prompt using compel package to conditioning."""

style: str = InputField(
default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
default="",
description=FieldDescriptions.compel_prompt,
ui_component=UIComponent.Textarea,
) # TODO: ?
original_width: int = InputField(default=1024, description="")
original_height: int = InputField(default=1024, description="")

@@ -403,10 +431,16 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output"""

clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
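This one-line change is easy to miss but matters: in pydantic v2, `Optional[X]` no longer implies `default=None` — a field is only optional if it has an explicit default, and conversely the annotation must admit `None` for a `None` default to validate. That is why this output gains the `Optional[...]` annotation alongside `default=None` (and why `NoiseOutput.noise` later in this diff drops its `default=None` instead). A sketch of the v2 rule, with a placeholder type standing in for `ClipField`:

```py
from typing import Optional

from pydantic import BaseModel

class Output(BaseModel):
    # Without `= None` this field would be required in v2, despite Optional[...].
    clip: Optional[str] = None
```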


@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning", version="1.0.0")
@invocation(
"clip_skip",
title="CLIP Skip",
tags=["clipskip", "clip", "skip"],
category="conditioning",
version="1.0.0",
)
class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""

@@ -421,7 +455,9 @@ class ClipSkipInvocation(BaseInvocation):


def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
tokenizer,
prompt: Union[FlattenedPrompt, Blend, Conjunction],
truncate_if_too_long=False,
) -> int:
if type(prompt) is Blend:
blend: Blend = prompt

@@ -2,7 +2,7 @@
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
from typing import Dict, List, Literal, Optional, Union
from typing import Dict, List, Literal, Union

import cv2
import numpy as np

@@ -24,7 +24,7 @@ from controlnet_aux import (
)
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin

@@ -57,6 +57,8 @@ class ControlNetModelField(BaseModel):
model_name: str = Field(description="Name of the ControlNet model")
base_model: BaseModelType = Field(description="Base model")

model_config = ConfigDict(protected_namespaces=())


class ControlField(BaseModel):
image: ImageField = Field(description="The control image")

@@ -71,7 +73,7 @@ class ControlField(BaseModel):
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

@validator("control_weight")
@field_validator("control_weight")
def validate_control_weight(cls, v):
"""Validate that all control weights in the valid range"""
if isinstance(v, list):

@@ -124,9 +126,7 @@ class ControlNetInvocation(BaseInvocation):
)


@invocation(
"image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet", version="1.0.0"
)
# This invocation exists for other invocations to subclass it - do not register with @invocation!
class ImageProcessorInvocation(BaseInvocation):
"""Base class for invocations that preprocess images for ControlNet"""

@@ -393,9 +393,9 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):

detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
h: int = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")

def run_processor(self, image):
content_shuffle_processor = ContentShuffleDetector()

@@ -575,14 +575,14 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):

def run_processor(self, image: Image.Image):
image = image.convert("RGB")
image = np.array(image, dtype=np.uint8)
height, width = image.shape[:2]
np_image = np.array(image, dtype=np.uint8)
height, width = np_image.shape[:2]

width_tile_size = min(self.color_map_tile_size, width)
height_tile_size = min(self.color_map_tile_size, height)

color_map = cv2.resize(
image,
np_image,
(width // width_tile_size, height // height_tile_size),
interpolation=cv2.INTER_CUBIC,
)

@@ -8,7 +8,7 @@ import numpy as np
from mediapipe.python.solutions.face_mesh import FaceMesh # type: ignore[import]
from PIL import Image, ImageDraw, ImageFilter, ImageFont, ImageOps
from PIL.Image import Image as ImageType
from pydantic import validator
from pydantic import field_validator

import invokeai.assets.fonts as font_assets
from invokeai.app.invocations.baseinvocation import (

@@ -550,7 +550,7 @@ class FaceMaskInvocation(BaseInvocation):
)
invert_mask: bool = InputField(default=False, description="Toggle to invert the mask")

@validator("face_ids")
@field_validator("face_ids")
def validate_comma_separated_ints(cls, v) -> str:
comma_separated_ints_regex = re.compile(r"^\d*(,\d+)*$")
if comma_separated_ints_regex.match(v) is None:

@@ -36,7 +36,13 @@ class ShowImageInvocation(BaseInvocation):
)


@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.0.0")
@invocation(
"blank_image",
title="Blank Image",
tags=["image"],
category="image",
version="1.0.0",
)
class BlankImageInvocation(BaseInvocation):
"""Creates a blank image and forwards it to the pipeline"""

@@ -65,7 +71,13 @@ class BlankImageInvocation(BaseInvocation):
)


@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.0.0")
@invocation(
"img_crop",
title="Crop Image",
tags=["image", "crop"],
category="image",
version="1.0.0",
)
class ImageCropInvocation(BaseInvocation):
"""Crops an image to a specified box. The box can be outside of the image."""

@@ -98,7 +110,13 @@ class ImageCropInvocation(BaseInvocation):
)


@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.0.1")
@invocation(
"img_paste",
title="Paste Image",
tags=["image", "paste"],
category="image",
version="1.0.1",
)
class ImagePasteInvocation(BaseInvocation):
"""Pastes an image into another image."""

@@ -151,7 +169,13 @@ class ImagePasteInvocation(BaseInvocation):
)


@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.0.0")
@invocation(
"tomask",
title="Mask from Alpha",
tags=["image", "mask"],
category="image",
version="1.0.0",
)
class MaskFromAlphaInvocation(BaseInvocation):
"""Extracts the alpha channel of an image as a mask."""

@@ -182,7 +206,13 @@ class MaskFromAlphaInvocation(BaseInvocation):
)


@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.0.0")
@invocation(
"img_mul",
title="Multiply Images",
tags=["image", "multiply"],
category="image",
version="1.0.0",
)
class ImageMultiplyInvocation(BaseInvocation):
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""

@@ -215,7 +245,13 @@ class ImageMultiplyInvocation(BaseInvocation):
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]


@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.0.0")
@invocation(
"img_chan",
title="Extract Image Channel",
tags=["image", "channel"],
category="image",
version="1.0.0",
)
class ImageChannelInvocation(BaseInvocation):
"""Gets a channel from an image."""

@@ -247,7 +283,13 @@ class ImageChannelInvocation(BaseInvocation):
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]


@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.0.0")
@invocation(
"img_conv",
title="Convert Image Mode",
tags=["image", "convert"],
category="image",
version="1.0.0",
)
class ImageConvertInvocation(BaseInvocation):
"""Converts an image to a different mode."""

@@ -276,7 +318,13 @@ class ImageConvertInvocation(BaseInvocation):
)


@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.0.0")
@invocation(
"img_blur",
title="Blur Image",
tags=["image", "blur"],
category="image",
version="1.0.0",
)
class ImageBlurInvocation(BaseInvocation):
"""Blurs an image"""

@@ -330,7 +378,13 @@ PIL_RESAMPLING_MAP = {
}


@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.0.0")
@invocation(
"img_resize",
title="Resize Image",
tags=["image", "resize"],
category="image",
version="1.0.0",
)
class ImageResizeInvocation(BaseInvocation):
"""Resizes an image to specific dimensions"""

@@ -359,7 +413,7 @@ class ImageResizeInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
metadata=self.metadata.model_dump() if self.metadata else None,
workflow=self.workflow,
)

@@ -370,7 +424,13 @@ class ImageResizeInvocation(BaseInvocation):
)


@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.0.0")
@invocation(
"img_scale",
title="Scale Image",
tags=["image", "scale"],
category="image",
version="1.0.0",
)
class ImageScaleInvocation(BaseInvocation):
"""Scales an image by a factor"""

@@ -411,7 +471,13 @@ class ImageScaleInvocation(BaseInvocation):
)


@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.0.0")
@invocation(
"img_lerp",
title="Lerp Image",
tags=["image", "lerp"],
category="image",
version="1.0.0",
)
class ImageLerpInvocation(BaseInvocation):
"""Linear interpolation of all pixels of an image"""

@@ -444,7 +510,13 @@ class ImageLerpInvocation(BaseInvocation):
)


@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.0.0")
@invocation(
"img_ilerp",
title="Inverse Lerp Image",
tags=["image", "ilerp"],
category="image",
version="1.0.0",
)
class ImageInverseLerpInvocation(BaseInvocation):
"""Inverse linear interpolation of all pixels of an image"""

@@ -456,7 +528,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
image = context.services.images.get_pil_image(self.image.image_name)

image_arr = numpy.asarray(image, dtype=numpy.float32)
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255 # type: ignore [assignment]

ilerp_image = Image.fromarray(numpy.uint8(image_arr))

@@ -477,7 +549,13 @@ class ImageInverseLerpInvocation(BaseInvocation):
)


@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.0.0")
@invocation(
"img_nsfw",
title="Blur NSFW Image",
tags=["image", "nsfw"],
category="image",
version="1.0.0",
)
class ImageNSFWBlurInvocation(BaseInvocation):
"""Add blur to NSFW-flagged images"""

@@ -505,7 +583,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
metadata=self.metadata.model_dump() if self.metadata else None,
workflow=self.workflow,
)

@@ -515,7 +593,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
height=image_dto.height,
)

def _get_caution_img(self) -> Image:
def _get_caution_img(self) -> Image.Image:
import invokeai.app.assets.images as image_assets

caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
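The return-annotation fix above is subtle: in PIL, `Image` is the *module* and `Image.Image` is the image class, so only the latter is a valid type annotation. A minimal illustration (the file name is hypothetical):

```py
from PIL import Image

def load_caution() -> Image.Image:  # `-> Image` would annotate with the module
    return Image.open("caution.png")
```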

@@ -523,7 +601,11 @@ class ImageNSFWBlurInvocation(BaseInvocation):


@invocation(
"img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image", version="1.0.0"
"img_watermark",
title="Add Invisible Watermark",
tags=["image", "watermark"],
category="image",
version="1.0.0",
)
class ImageWatermarkInvocation(BaseInvocation):
"""Add an invisible watermark to an image"""

@@ -544,7 +626,7 @@ class ImageWatermarkInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
metadata=self.metadata.model_dump() if self.metadata else None,
workflow=self.workflow,
)

@@ -555,7 +637,13 @@ class ImageWatermarkInvocation(BaseInvocation):
)


@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.0.0")
@invocation(
"mask_edge",
title="Mask Edge",
tags=["image", "mask", "inpaint"],
category="image",
version="1.0.0",
)
class MaskEdgeInvocation(BaseInvocation):
"""Applies an edge mask to an image"""

@@ -601,7 +689,11 @@ class MaskEdgeInvocation(BaseInvocation):


@invocation(
"mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image", version="1.0.0"
"mask_combine",
title="Combine Masks",
tags=["image", "mask", "multiply"],
category="image",
version="1.0.0",
)
class MaskCombineInvocation(BaseInvocation):
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""

@@ -632,7 +724,13 @@ class MaskCombineInvocation(BaseInvocation):
)


@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.0.0")
@invocation(
"color_correct",
title="Color Correct",
tags=["image", "color"],
category="image",
version="1.0.0",
)
class ColorCorrectInvocation(BaseInvocation):
"""
Shifts the colors of a target image to match the reference image, optionally

@@ -742,7 +840,13 @@ class ColorCorrectInvocation(BaseInvocation):
)


@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.0.0")
@invocation(
"img_hue_adjust",
title="Adjust Image Hue",
tags=["image", "hue"],
category="image",
version="1.0.0",
)
class ImageHueAdjustmentInvocation(BaseInvocation):
"""Adjusts the Hue of an image."""

@@ -980,7 +1084,7 @@ class SaveImageInvocation(BaseInvocation):

image: ImageField = InputField(description=FieldDescriptions.image)
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
metadata: CoreMetadata = InputField(
metadata: Optional[CoreMetadata] = InputField(
default=None,
description=FieldDescriptions.core_metadata,
ui_hidden=True,

@@ -997,7 +1101,7 @@ class SaveImageInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
metadata=self.metadata.model_dump() if self.metadata else None,
workflow=self.workflow,
)

@@ -2,7 +2,7 @@ import os
from builtins import float
from typing import List, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from invokeai.app.invocations.baseinvocation import (
BaseInvocation,

@@ -25,11 +25,15 @@ class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")

model_config = ConfigDict(protected_namespaces=())


class CLIPVisionModelField(BaseModel):
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
base_model: BaseModelType = Field(description="Base model (usually 'Any')")

model_config = ConfigDict(protected_namespaces=())


class IPAdapterField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.")

@@ -19,7 +19,7 @@ from diffusers.models.attention_processor import (
)
from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import validator
from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize

from invokeai.app.invocations.ip_adapter import IPAdapterField

@@ -84,12 +84,20 @@ class SchedulerOutput(BaseInvocationOutput):
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)


@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents", version="1.0.0")
@invocation(
"scheduler",
title="Scheduler",
tags=["scheduler"],
category="latents",
version="1.0.0",
)
class SchedulerInvocation(BaseInvocation):
"""Selects a scheduler."""

scheduler: SAMPLER_NAME_VALUES = InputField(
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)

def invoke(self, context: InvocationContext) -> SchedulerOutput:

@@ -97,7 +105,11 @@ class SchedulerInvocation(BaseInvocation):


@invocation(
"create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents", version="1.0.0"
"create_denoise_mask",
title="Create Denoise Mask",
tags=["mask", "denoise"],
category="latents",
version="1.0.0",
)
class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""

@@ -106,7 +118,11 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32, ui_order=4)
fp32: bool = InputField(
default=DEFAULT_PRECISION == "float32",
description=FieldDescriptions.fp32,
ui_order=4,
)

def prep_mask_tensor(self, mask_image):
if mask_image.mode != "L":

@@ -134,7 +150,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):

if image is not None:
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
**self.vae.vae.model_dump(),
context=context,
)

@@ -167,7 +183,7 @@ def get_scheduler(
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict(),
**scheduler_info.model_dump(),
context=context,
)
with orig_scheduler_info as orig_scheduler:

@@ -209,34 +225,64 @@ class DenoiseLatentsInvocation(BaseInvocation):
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
)
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=3)
noise: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
ui_order=3,
)
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: Union[float, List[float]] = InputField(
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale"
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_start: float = InputField(
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SAMPLER_NAME_VALUES = InputField(
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
control: Union[ControlField, list[ControlField]] = InputField(
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
ui_order=2,
)
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
default=None,
input=Input.Connection,
ui_order=5,
)
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
description=FieldDescriptions.ip_adapter,
title="IP-Adapter",
default=None,
input=Input.Connection,
ui_order=6,
)
t2i_adapter: Union[T2IAdapterField, list[T2IAdapterField]] = InputField(
description=FieldDescriptions.t2i_adapter, title="T2I-Adapter", default=None, input=Input.Connection, ui_order=7
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
description=FieldDescriptions.t2i_adapter,
title="T2I-Adapter",
default=None,
input=Input.Connection,
ui_order=7,
)
latents: Optional[LatentsField] = InputField(
default=None, description=FieldDescriptions.latents, input=Input.Connection
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=8
default=None,
description=FieldDescriptions.mask,
input=Input.Connection,
ui_order=8,
)

@validator("cfg_scale")
@field_validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):

@@ -259,7 +305,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
node=self.model_dump(),
source_node_id=source_node_id,
base_model=base_model,
)

@@ -451,9 +497,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
input_image, image_encoder_model
)
(
image_prompt_embeds,
uncond_image_prompt_embeds,
) = ip_adapter_model.get_image_embeds(input_image, image_encoder_model)
conditioning_data.ip_adapter_conditioning.append(
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
)

@@ -628,7 +675,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate.
t2i_adapter_data = self.run_t2i_adapters(
context, self.t2i_adapter, latents.shape, do_classifier_free_guidance=True
context,
self.t2i_adapter,
latents.shape,
do_classifier_free_guidance=True,
)

# Get the source node id (we are invoking the prepared node)

@@ -641,7 +691,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
**lora.model_dump(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)

@@ -649,7 +699,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
return

unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
**self.unet.unet.model_dump(),
context=context,
)
with (

@@ -700,7 +750,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end,
)

result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
(
result_latents,
result_attention_map_saver,
) = pipeline.latents_from_embeddings(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,

@@ -728,7 +781,11 @@ class DenoiseLatentsInvocation(BaseInvocation):


@invocation(
"l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", version="1.0.0"
"l2i",
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.0.0",
)
class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""

@@ -743,7 +800,7 @@ class LatentsToImageInvocation(BaseInvocation):
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
metadata: CoreMetadata = InputField(
metadata: Optional[CoreMetadata] = InputField(
default=None,
description=FieldDescriptions.core_metadata,
ui_hidden=True,

@@ -754,7 +811,7 @@ class LatentsToImageInvocation(BaseInvocation):
latents = context.services.latents.get(self.latents.latents_name)

vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
**self.vae.vae.model_dump(),
context=context,
)

@@ -816,7 +873,7 @@ class LatentsToImageInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
metadata=self.metadata.model_dump() if self.metadata else None,
workflow=self.workflow,
)

@@ -830,7 +887,13 @@ class LatentsToImageInvocation(BaseInvocation):
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]


@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
@invocation(
"lresize",
title="Resize Latents",
tags=["latents", "resize"],
category="latents",
version="1.0.0",
)
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""

@@ -876,7 +939,13 @@ class ResizeLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)


@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
@invocation(
"lscale",
title="Scale Latents",
tags=["latents", "resize"],
category="latents",
version="1.0.0",
)
class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor."""

@@ -915,7 +984,11 @@ class ScaleLatentsInvocation(BaseInvocation):


@invocation(
"i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", version="1.0.0"
"i2l",
title="Image to Latents",
tags=["latents", "image", "vae", "i2l"],
category="latents",
version="1.0.0",
)
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""

@@ -979,7 +1052,7 @@ class ImageToLatentsInvocation(BaseInvocation):
image = context.services.images.get_pil_image(self.image.image_name)

vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
**self.vae.vae.model_dump(),
context=context,
)

@@ -1007,7 +1080,13 @@ class ImageToLatentsInvocation(BaseInvocation):
return vae.encode(image_tensor).latents


@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
@invocation(
"lblend",
title="Blend Latents",
tags=["latents", "blend"],
category="latents",
version="1.0.0",
)
class BlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha. Latents must have same size."""

@@ -3,7 +3,7 @@
from typing import Literal

import numpy as np
from pydantic import validator
from pydantic import field_validator

from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput

@@ -72,7 +72,14 @@ class RandomIntInvocation(BaseInvocation):
return IntegerOutput(value=np.random.randint(self.low, self.high))


@invocation("rand_float", title="Random Float", tags=["math", "float", "random"], category="math", version="1.0.0")
@invocation(
"rand_float",
title="Random Float",
tags=["math", "float", "random"],
category="math",
version="1.0.1",
use_cache=False,
)
class RandomFloatInvocation(BaseInvocation):
"""Outputs a single random float"""

@@ -178,7 +185,7 @@ class IntegerMathInvocation(BaseInvocation):
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)

@validator("b")
@field_validator("b")
def no_unrepresentable_results(cls, v, values):
if values["operation"] == "DIV" and v == 0:
raise ValueError("Cannot divide by zero")

@@ -252,7 +259,7 @@ class FloatMathInvocation(BaseInvocation):
a: float = InputField(default=0, description=FieldDescriptions.num_1)
b: float = InputField(default=0, description=FieldDescriptions.num_2)

@validator("b")
@field_validator("b")
def no_unrepresentable_results(cls, v, values):
if values["operation"] == "DIV" and v == 0:
raise ValueError("Cannot divide by zero")
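One caveat: these two validators keep the v1-style `values` parameter even though the decorator is now `@field_validator`. In v2 the second positional argument is a `ValidationInfo` object, not a dict, so `values["operation"]` presumably needs to become an `info.data` lookup, as was done for `RangeInvocation.stop_gt_start` earlier in this diff. The v2 form would look like:

```py
from pydantic import BaseModel, ValidationInfo, field_validator

class FloatMath(BaseModel):
    operation: str = "ADD"
    a: float = 0
    b: float = 0

    @field_validator("b")
    def no_unrepresentable_results(cls, v: float, info: ValidationInfo) -> float:
        # `operation` is declared before `b`, so it is available in info.data.
        if info.data.get("operation") == "DIV" and v == 0:
            raise ValueError("Cannot divide by zero")
        return v
```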

@@ -223,4 +223,4 @@ class MetadataAccumulatorInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
"""Collects and outputs a CoreMetadata object"""

return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))
return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.model_dump()))

@@ -1,7 +1,7 @@
import copy
from typing import List, Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from ...backend.model_management import BaseModelType, ModelType, SubModelType
from .baseinvocation import (

@@ -24,6 +24,8 @@ class ModelInfo(BaseModel):
model_type: ModelType = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")

model_config = ConfigDict(protected_namespaces=())


class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model")

@@ -65,6 +67,8 @@ class MainModelField(BaseModel):
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")

model_config = ConfigDict(protected_namespaces=())


class LoRAModelField(BaseModel):
"""LoRA model field"""

@@ -72,8 +76,16 @@ class LoRAModelField(BaseModel):
model_name: str = Field(description="Name of the LoRA model")
base_model: BaseModelType = Field(description="Base model")

model_config = ConfigDict(protected_namespaces=())

@invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0")

@invocation(
"main_model_loader",
title="Main Model",
tags=["model"],
category="model",
version="1.0.0",
)
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""

@@ -180,10 +192,16 @@ class LoraLoaderInvocation(BaseInvocation):
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
default=None,
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
clip: Optional[ClipField] = InputField(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP"
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
)

def invoke(self, context: InvocationContext) -> LoraLoaderOutput:

@@ -244,20 +262,35 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")


@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model", version="1.0.0")
@invocation(
"sdxl_lora_loader",
title="SDXL LoRA",
tags=["lora", "model"],
category="model",
version="1.0.0",
)
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""

lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
default=None,
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
clip: Optional[ClipField] = InputField(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP 1",
)
clip2: Optional[ClipField] = InputField(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP 2",
)

def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:

@@ -330,6 +363,8 @@ class VAEModelField(BaseModel):
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")

model_config = ConfigDict(protected_namespaces=())


@invocation_output("vae_loader_output")
class VaeLoaderOutput(BaseInvocationOutput):

@@ -343,7 +378,10 @@ class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""

vae_model: VAEModelField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
description=FieldDescriptions.vae_model,
input=Input.Direct,
ui_type=UIType.VaeModel,
title="VAE",
)

def invoke(self, context: InvocationContext) -> VaeLoaderOutput:

@@ -372,19 +410,31 @@ class VaeLoaderInvocation(BaseInvocation):
class SeamlessModeOutput(BaseInvocationOutput):
"""Modified Seamless Model output"""

unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")


@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model", version="1.0.0")
@invocation(
"seamless",
title="Seamless",
tags=["seamless", "model"],
category="model",
version="1.0.0",
)
class SeamlessModeInvocation(BaseInvocation):
"""Applies the seamless transformation to the Model UNet and VAE."""

unet: Optional[UNetField] = InputField(
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
default=None,
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
vae: Optional[VaeField] = InputField(
default=None, description=FieldDescriptions.vae_model, input=Input.Connection, title="VAE"
default=None,
description=FieldDescriptions.vae_model,
input=Input.Connection,
title="VAE",
)
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")

@@ -2,7 +2,7 @@

import torch
from pydantic import validator
from pydantic import field_validator

from invokeai.app.invocations.latent import LatentsField
from invokeai.app.util.misc import SEED_MAX, get_random_seed

@@ -65,7 +65,7 @@ Nodes
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""

noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
noise: LatentsField = OutputField(description=FieldDescriptions.noise)
width: int = OutputField(description=FieldDescriptions.width)
height: int = OutputField(description=FieldDescriptions.height)

@@ -78,7 +78,13 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
)


@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents", version="1.0.0")
@invocation(
"noise",
title="Noise",
tags=["latents", "noise"],
category="latents",
version="1.0.0",
)
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""

@@ -105,7 +111,7 @@ class NoiseInvocation(BaseInvocation):
description="Use CPU for noise generation (for reproducible results across platforms)",
)

@validator("seed", pre=True)
@field_validator("seed", mode="before")
def modulo_seed(cls, v):
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
return v % (SEED_MAX + 1)
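`pre=True` maps to `mode="before"`: the validator runs on the raw input, before pydantic's own parsing and constraint checks. A sketch, with an illustrative `SEED_MAX`:

```py
from pydantic import BaseModel, field_validator

SEED_MAX = 2**32 - 1  # illustrative value, not necessarily InvokeAI's constant

class Noise(BaseModel):
    seed: int = 0

    # v1: @validator("seed", pre=True)
    @field_validator("seed", mode="before")
    def modulo_seed(cls, v):
        return v % (SEED_MAX + 1)
```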

@@ -9,7 +9,7 @@ from typing import List, Literal, Optional, Union
import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from tqdm import tqdm

from invokeai.app.invocations.metadata import CoreMetadata

@@ -63,14 +63,17 @@ class ONNXPromptInvocation(BaseInvocation):

def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(),
**self.clip.tokenizer.model_dump(),
)
text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(),
**self.clip.text_encoder.model_dump(),
)
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
loras = [
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
(
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
lora.weight,
)
for lora in self.clip.loras
]

@@ -175,14 +178,14 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
description=FieldDescriptions.unet,
input=Input.Connection,
)
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
control: Union[ControlField, list[ControlField]] = InputField(
default=None,
description=FieldDescriptions.control,
)
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")

@validator("cfg_scale")
@field_validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):

@@ -241,7 +244,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
node=self.model_dump(),
source_node_id=source_node_id,
)

@@ -254,12 +257,15 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
eta=0.0,
)

unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump())

with unet_info as unet: # , ExitStack() as stack:
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
loras = [
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
(
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
lora.weight,
)
for lora in self.unet.loras
]

@@ -346,7 +352,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")

vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
**self.vae.vae.model_dump(),
)

# clear memory as vae decode can request a lot

@@ -375,7 +381,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
metadata=self.metadata.model_dump() if self.metadata else None,
workflow=self.workflow,
)

@@ -403,6 +409,8 @@ class OnnxModelField(BaseModel):
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")

model_config = ConfigDict(protected_namespaces=())


@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
class OnnxModelLoaderInvocation(BaseInvocation):

@@ -44,13 +44,22 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation


@invocation("float_range", title="Float Range", tags=["math", "range"], category="math", version="1.0.0")
@invocation(
"float_range",
title="Float Range",
tags=["math", "range"],
category="math",
version="1.0.0",
)
class FloatLinearRangeInvocation(BaseInvocation):
"""Creates a range"""

start: float = InputField(default=5, description="The first value of the range")
stop: float = InputField(default=10, description="The last value of the range")
steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
steps: int = InputField(
default=30,
description="number of values to interpolate over (including start and stop)",
)

def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
param_list = list(np.linspace(self.start, self.stop, self.steps))

@@ -95,7 +104,13 @@ EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]


# actually I think for now could just use CollectionOutput (which is list[Any]
@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step", version="1.0.0")
@invocation(
"step_param_easing",
title="Step Param Easing",
tags=["step", "easing"],
category="step",
version="1.0.0",
)
class StepParamEasingInvocation(BaseInvocation):
"""Experimental per-step parameter easing for denoising steps"""

@@ -159,7 +174,9 @@ class StepParamEasingInvocation(BaseInvocation):
context.services.logger.debug("base easing duration: " + str(base_easing_duration))
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
easing_function = easing_class(
start=self.start_value, end=self.end_value, duration=base_easing_duration - 1
start=self.start_value,
end=self.end_value,
duration=base_easing_duration - 1,
)
base_easing_vals = list()
for step_index in range(base_easing_duration):

@@ -199,7 +216,11 @@ class StepParamEasingInvocation(BaseInvocation):
#

else: # no mirroring (default)
easing_function = easing_class(start=self.start_value, end=self.end_value, duration=num_easing_steps - 1)
easing_function = easing_class(
start=self.start_value,
end=self.end_value,
duration=num_easing_steps - 1,
)
for step_index in range(num_easing_steps):
step_val = easing_function.ease(step_index)
easing_list.append(step_val)

@@ -3,7 +3,7 @@ from typing import Optional, Union

import numpy as np
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
from pydantic import validator
from pydantic import field_validator

from invokeai.app.invocations.primitives import StringCollectionOutput

@@ -21,7 +21,10 @@ from .baseinvocation import BaseInvocation, InputField, InvocationContext, UICom
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""

prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
prompt: str = InputField(
description="The prompt to parse with dynamicprompts",
ui_component=UIComponent.Textarea,
)
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")

@@ -36,21 +39,31 @@ class DynamicPromptInvocation(BaseInvocation):
return StringCollectionOutput(collection=prompts)


@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt", version="1.0.0")
@invocation(
"prompt_from_file",
title="Prompts from File",
tags=["prompt", "file"],
category="prompt",
version="1.0.0",
)
class PromptsFromFileInvocation(BaseInvocation):
"""Loads prompts from a text file"""

file_path: str = InputField(description="Path to prompt text file")
pre_prompt: Optional[str] = InputField(
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
default=None,
description="String to prepend to each prompt",
ui_component=UIComponent.Textarea,
)
post_prompt: Optional[str] = InputField(
default=None, description="String to append to each prompt", ui_component=UIComponent.Textarea
default=None,
description="String to append to each prompt",
ui_component=UIComponent.Textarea,
)
start_line: int = InputField(default=1, ge=1, description="Line in the file to start from")
|
||||
max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||
|
||||
@validator("file_path")
|
||||
@field_validator("file_path")
|
||||
def file_path_exists(cls, v):
|
||||
if not exists(v):
|
||||
raise ValueError(FileNotFoundError)
|
||||
@ -79,6 +92,10 @@ class PromptsFromFileInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
prompts = self.promptsFromFile(
|
||||
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
|
||||
self.file_path,
|
||||
self.pre_prompt,
|
||||
self.post_prompt,
|
||||
self.start_line,
|
||||
self.max_prompts,
|
||||
)
|
||||
return StringCollectionOutput(collection=prompts)
|
||||
|
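Note on the hunk above: this is the v1 `@validator` → v2 `@field_validator` rename in practice. A minimal sketch of the v2 pattern, using a hypothetical model that is not part of this diff:

```py
from pydantic import BaseModel, field_validator

class PromptFile(BaseModel):
    file_path: str

    @field_validator("file_path")
    @classmethod
    def file_path_must_be_txt(cls, v: str) -> str:
        # field validators receive the field value and must return it
        if not v.endswith(".txt"):
            raise ValueError("expected a .txt file")
        return v
```

As in v1, raising `ValueError` inside the validator surfaces to the caller as a `ValidationError`.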
@@ -1,6 +1,6 @@
 from typing import Union
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
@@ -23,6 +23,8 @@ class T2IAdapterModelField(BaseModel):
     model_name: str = Field(description="Name of the T2I-Adapter model")
     base_model: BaseModelType = Field(description="Base model")
 
+    model_config = ConfigDict(protected_namespaces=())
+
 
 class T2IAdapterField(BaseModel):
     image: ImageField = Field(description="The T2I-Adapter image prompt.")
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 from basicsr.archs.rrdbnet_arch import RRDBNet
 from PIL import Image
+from pydantic import ConfigDict
 from realesrgan import RealESRGANer
 
 from invokeai.app.invocations.primitives import ImageField, ImageOutput
@@ -38,6 +39,8 @@ class ESRGANInvocation(BaseInvocation):
         default=400, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
     )
 
+    model_config = ConfigDict(protected_namespaces=())
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.services.images.get_pil_image(self.image.image_name)
         models_path = context.services.configuration.models_path
@@ -1,7 +1,7 @@
 from datetime import datetime
 from typing import Optional, Union
 
-from pydantic import BaseModel, Extra, Field
+from pydantic import BaseModel, Field
 
 from invokeai.app.util.misc import get_iso_timestamp
 from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
@@ -18,9 +18,9 @@ class BoardRecord(BaseModelExcludeNull):
     """The created timestamp of the image."""
     updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
     """The updated timestamp of the image."""
-    deleted_at: Union[datetime, str, None] = Field(description="The deleted timestamp of the board.")
+    deleted_at: Optional[Union[datetime, str]] = Field(default=None, description="The deleted timestamp of the board.")
     """The updated timestamp of the image."""
-    cover_image_name: Optional[str] = Field(description="The name of the cover image of the board.")
+    cover_image_name: Optional[str] = Field(default=None, description="The name of the cover image of the board.")
     """The name of the cover image of the board."""
 
 
@@ -46,9 +46,9 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
     )
 
 
-class BoardChanges(BaseModel, extra=Extra.forbid):
-    board_name: Optional[str] = Field(description="The board's new name.")
-    cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.")
+class BoardChanges(BaseModel, extra="forbid"):
+    board_name: Optional[str] = Field(default=None, description="The board's new name.")
+    cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
 
 
 class BoardRecordNotFoundException(Exception):
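Two v2 behaviors drive the `BoardChanges`/`BoardRecord` changes above: the `Extra` enum is replaced by a plain string (`extra="forbid"`), and `Optional[...]` fields no longer receive an implicit `None` default, so `default=None` must be spelled out. A minimal sketch with a hypothetical model:

```py
from typing import Optional
from pydantic import BaseModel, Field

class RecordChanges(BaseModel, extra="forbid"):
    # in v2, Optional[...] alone does not make a field optional at
    # instantiation; an explicit default is required
    name: Optional[str] = Field(default=None, description="The new name")

RecordChanges()             # ok: name defaults to None
RecordChanges(unknown="x")  # raises ValidationError because extra="forbid"
```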
@@ -17,7 +17,7 @@ class BoardDTO(BoardRecord):
 def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
     """Converts a board record to a board DTO."""
     return BoardDTO(
-        **board_record.dict(exclude={"cover_image_name"}),
+        **board_record.model_dump(exclude={"cover_image_name"}),
         cover_image_name=cover_image_name,
         image_count=image_count,
     )
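The DTO conversion above is the serialization rename in practice. A short sketch with a hypothetical model:

```py
from typing import Optional
from pydantic import BaseModel

class Record(BaseModel):
    id: str
    cover_image_name: Optional[str] = None

record = Record(id="abc")
data = record.model_dump(exclude={"cover_image_name"})  # replaces Record.dict()
json_str = record.model_dump_json()                     # replaces Record.json()
```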
@@ -18,7 +18,7 @@ from pathlib import Path
 from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
 
 from omegaconf import DictConfig, ListConfig, OmegaConf
-from pydantic import BaseSettings
+from pydantic_settings import BaseSettings, SettingsConfigDict
 
 from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
 
@@ -32,12 +32,14 @@ class InvokeAISettings(BaseSettings):
     initconf: ClassVar[Optional[DictConfig]] = None
     argparse_groups: ClassVar[Dict] = {}
 
+    model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
+
     def parse_args(self, argv: Optional[list] = sys.argv[1:]):
         parser = self.get_parser()
         opt, unknown_opts = parser.parse_known_args(argv)
         if len(unknown_opts) > 0:
             print("Unknown args:", unknown_opts)
-        for name in self.__fields__:
+        for name in self.model_fields:
             if name not in self._excluded():
                 value = getattr(opt, name)
                 if isinstance(value, ListConfig):
@@ -54,10 +56,12 @@ class InvokeAISettings(BaseSettings):
         cls = self.__class__
        type = get_args(get_type_hints(cls)["type"])[0]
         field_dict = dict({type: dict()})
-        for name, field in self.__fields__.items():
+        for name, field in self.model_fields.items():
             if name in cls._excluded_from_yaml():
                 continue
-            category = field.field_info.extra.get("category") or "Uncategorized"
+            category = (
+                field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized"
+            )
             value = getattr(self, name)
             if category not in field_dict[type]:
                 field_dict[type][category] = dict()
@@ -73,7 +77,7 @@ class InvokeAISettings(BaseSettings):
         else:
             settings_stanza = "Uncategorized"
 
-        env_prefix = getattr(cls.Config, "env_prefix", None)
+        env_prefix = getattr(cls.model_config, "env_prefix", None)
         env_prefix = env_prefix if env_prefix is not None else settings_stanza.upper()
 
         initconf = (
@@ -89,14 +93,18 @@ class InvokeAISettings(BaseSettings):
         for key, value in os.environ.items():
             upcase_environ[key.upper()] = value
 
-        fields = cls.__fields__
+        fields = cls.model_fields
         cls.argparse_groups = {}
 
         for name, field in fields.items():
             if name not in cls._excluded():
                 current_default = field.default
 
-                category = field.field_info.extra.get("category", "Uncategorized")
+                category = (
+                    field.json_schema_extra.get("category", "Uncategorized")
+                    if field.json_schema_extra
+                    else "Uncategorized"
+                )
                 env_name = env_prefix + "_" + name
                 if category in initconf and name in initconf.get(category):
                     field.default = initconf.get(category).get(name)
@@ -146,11 +154,6 @@ class InvokeAISettings(BaseSettings):
             "tiled_decode",
         ]
 
-    class Config:
-        env_file_encoding = "utf-8"
-        arbitrary_types_allowed = True
-        case_sensitive = True
-
     @classmethod
     def add_field_argument(cls, command_parser, name: str, field, default_override=None):
         field_type = get_type_hints(cls).get(name)
@@ -161,7 +164,7 @@ class InvokeAISettings(BaseSettings):
             if field.default_factory is None
             else field.default_factory()
         )
-        if category := field.field_info.extra.get("category"):
+        if category := (field.json_schema_extra.get("category", None) if field.json_schema_extra else None):
            if category not in cls.argparse_groups:
                 cls.argparse_groups[category] = command_parser.add_argument_group(category)
             argparse_group = cls.argparse_groups[category]
@@ -169,7 +172,7 @@ class InvokeAISettings(BaseSettings):
             argparse_group = command_parser
 
         if get_origin(field_type) == Literal:
-            allowed_values = get_args(field.type_)
+            allowed_values = get_args(field.annotation)
             allowed_types = set()
             for val in allowed_values:
                 allowed_types.add(type(val))
@@ -182,7 +185,7 @@ class InvokeAISettings(BaseSettings):
                 type=field_type,
                 default=default,
                 choices=allowed_values,
-                help=field.field_info.description,
+                help=field.description,
             )
 
         elif get_origin(field_type) == Union:
@@ -191,7 +194,7 @@ class InvokeAISettings(BaseSettings):
                 dest=name,
                 type=int_or_float_or_str,
                 default=default,
-                help=field.field_info.description,
+                help=field.description,
             )
 
         elif get_origin(field_type) == list:
@@ -199,17 +202,17 @@ class InvokeAISettings(BaseSettings):
                 f"--{name}",
                 dest=name,
                 nargs="*",
-                type=field.type_,
+                type=field.annotation,
                 default=default,
-                action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
-                help=field.field_info.description,
+                action=argparse.BooleanOptionalAction if field.annotation == bool else "store",
+                help=field.description,
             )
         else:
             argparse_group.add_argument(
                 f"--{name}",
                 dest=name,
-                type=field.type_,
+                type=field.annotation,
                 default=default,
-                action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
-                help=field.field_info.description,
+                action=argparse.BooleanOptionalAction if field.annotation == bool else "store",
+                help=field.description,
             )
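The config_base changes above bundle several v2 moves: `BaseSettings` now lives in the separate `pydantic-settings` package, the nested `class Config` becomes `model_config = SettingsConfigDict(...)`, `__fields__` becomes `model_fields`, and `FieldInfo` exposes `annotation` and `description` directly (the old `field.type_` and `field.field_info.*` are gone). A minimal sketch, assuming a hypothetical settings class:

```py
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

class AppSettings(BaseSettings):
    model_config = SettingsConfigDict(env_prefix="MYAPP_", case_sensitive=True)

    host: str = Field(default="127.0.0.1", description="IP address to bind to")
    port: int = Field(default=9090, description="Port to bind to")

# v2 FieldInfo: .annotation replaces .type_; .description is a direct attribute
for name, field in AppSettings.model_fields.items():
    print(name, field.annotation, field.default, field.description)
```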
@@ -144,8 +144,8 @@ which is set to the desired top-level name. For example, to create a
 
 class InvokeBatch(InvokeAISettings):
     type: Literal["InvokeBatch"] = "InvokeBatch"
-    node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources')
-    cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", category='Resources')
+    node_count : int = Field(default=1, description="Number of nodes to run on", json_schema_extra=dict(category='Resources'))
+    cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", json_schema_extra=dict(category='Resources'))
 
 This will now read and write from the "InvokeBatch" section of the
 config file, look for environment variables named INVOKEBATCH_*, and
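Since v2 `Field` rejects unknown keyword arguments, the ad-hoc `category=...` kwarg shown above has to move into `json_schema_extra`. Reading it back goes through `FieldInfo.json_schema_extra`, which may be `None` (or a callable) and should be guarded. A sketch with a hypothetical model:

```py
from pydantic import BaseModel, Field

class BatchSettings(BaseModel):
    node_count: int = Field(default=1, json_schema_extra={"category": "Resources"})

field = BatchSettings.model_fields["node_count"]
# json_schema_extra may be None or a callable, so guard before .get()
extra = field.json_schema_extra if isinstance(field.json_schema_extra, dict) else {}
print(extra.get("category", "Uncategorized"))  # -> "Resources"
```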
@@ -175,7 +175,8 @@ from pathlib import Path
 from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
 
 from omegaconf import DictConfig, OmegaConf
-from pydantic import Field, parse_obj_as
+from pydantic import Field, TypeAdapter
+from pydantic_settings import SettingsConfigDict
 
 from .config_base import InvokeAISettings
 
@@ -185,6 +186,21 @@ LEGACY_INIT_FILE = Path("invokeai.init")
 DEFAULT_MAX_VRAM = 0.5
 
 
+class Categories(object):
+    WebServer = dict(category="Web Server")
+    Features = dict(category="Features")
+    Paths = dict(category="Paths")
+    Logging = dict(category="Logging")
+    Development = dict(category="Development")
+    Other = dict(category="Other")
+    ModelCache = dict(category="Model Cache")
+    Device = dict(category="Device")
+    Generation = dict(category="Generation")
+    Queue = dict(category="Queue")
+    Nodes = dict(category="Nodes")
+    MemoryPerformance = dict(category="Memory/Performance")
+
+
 class InvokeAIAppConfig(InvokeAISettings):
     """
     Generate images using Stable Diffusion. Use "invokeai" to launch
@@ -201,86 +217,88 @@ class InvokeAIAppConfig(InvokeAISettings):
     type: Literal["InvokeAI"] = "InvokeAI"
 
     # WEB
-    host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
-    port : int = Field(default=9090, description="Port to bind to", category='Web Server')
-    allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
-    allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Web Server')
-    allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
-    allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
+    host : str = Field(default="127.0.0.1", description="IP address to bind to", json_schema_extra=Categories.WebServer)
+    port : int = Field(default=9090, description="Port to bind to", json_schema_extra=Categories.WebServer)
+    allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", json_schema_extra=Categories.WebServer)
+    allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
+    allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
+    allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
 
     # FEATURES
-    esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
-    internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
-    log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
-    patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
-    ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
+    esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)
+    internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", json_schema_extra=Categories.Features)
+    log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", json_schema_extra=Categories.Features)
+    patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", json_schema_extra=Categories.Features)
+    ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', json_schema_extra=Categories.Features)
 
     # PATHS
-    root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
-    autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
-    lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
-    embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
-    controlnet_dir : Path = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths')
-    conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
-    models_dir : Path = Field(default='models', description='Path to the models directory', category='Paths')
-    legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
-    db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
-    outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
-    use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
-    from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
+    root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
+    autoimport_dir : Optional[Path] = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
+    lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
+    embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
+    controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
+    conf_path : Optional[Path] = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
+    models_dir : Optional[Path] = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
+    legacy_conf_dir : Optional[Path] = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
+    db_dir : Optional[Path] = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
+    outdir : Optional[Path] = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
+    use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
+    from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
 
     # LOGGING
-    log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
+    log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', json_schema_extra=Categories.Logging)
     # note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
-    log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
-    log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
-    log_sql : bool = Field(default=False, description="Log SQL queries", category="Logging")
+    log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', json_schema_extra=Categories.Logging)
+    log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", json_schema_extra=Categories.Logging)
+    log_sql : bool = Field(default=False, description="Log SQL queries", json_schema_extra=Categories.Logging)
 
-    dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", category="Development")
+    dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", json_schema_extra=Categories.Development)
 
-    version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
+    version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
 
     # CACHE
-    ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", category="Model Cache", )
-    vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", category="Model Cache", )
-    lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
+    ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
+    vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
+    lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
 
     # DEVICE
-    device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", category="Device", )
-    precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", category="Device", )
+    device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
+    precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
 
     # GENERATION
-    sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", )
-    attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", )
-    attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
-    force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
-    force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
-    png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", category="Generation", )
+    sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", json_schema_extra=Categories.Generation)
+    attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", json_schema_extra=Categories.Generation)
+    attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', json_schema_extra=Categories.Generation)
+    force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Generation)
+    png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
 
     # QUEUE
-    max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
+    max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", json_schema_extra=Categories.Queue)
 
     # NODES
-    allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")
-    deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes")
-    node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", category="Nodes", )
+    allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", json_schema_extra=Categories.Nodes)
+    deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
+    node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
 
     # DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
-    always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
-    free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
-    max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
-    max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
-    xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
-    tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
+    always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
+    free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
+    max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
+    max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.MemoryPerformance)
+    xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.MemoryPerformance)
+    tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.MemoryPerformance)
 
     # See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
     # fmt: on
 
-    class Config:
-        validate_assignment = True
-        env_prefix = "INVOKEAI"
+    model_config = SettingsConfigDict(validate_assignment=True, env_prefix="INVOKEAI")
 
-    def parse_args(self, argv: Optional[list[str]] = None, conf: Optional[DictConfig] = None, clobber=False):
+    def parse_args(
+        self,
+        argv: Optional[list[str]] = None,
+        conf: Optional[DictConfig] = None,
+        clobber=False,
+    ):
         """
         Update settings with contents of init file, environment, and
         command-line settings.
@@ -308,7 +326,11 @@ class InvokeAIAppConfig(InvokeAISettings):
         if self.singleton_init and not clobber:
             hints = get_type_hints(self.__class__)
             for k in self.singleton_init:
-                setattr(self, k, parse_obj_as(hints[k], self.singleton_init[k]))
+                setattr(
+                    self,
+                    k,
+                    TypeAdapter(hints[k]).validate_python(self.singleton_init[k]),
+                )
 
     @classmethod
     def get_config(cls, **kwargs) -> InvokeAIAppConfig:
@@ -2,7 +2,6 @@
 
 from typing import Any, Optional
 
-from invokeai.app.invocations.model import ModelInfo
 from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
 from invokeai.app.services.session_queue.session_queue_common import (
     BatchStatus,
@@ -11,6 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
     SessionQueueStatus,
 )
 from invokeai.app.util.misc import get_timestamp
+from invokeai.backend.model_management.model_manager import ModelInfo
 from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
 
 
@@ -55,7 +55,7 @@ class EventServiceBase:
                 graph_execution_state_id=graph_execution_state_id,
                 node_id=node.get("id"),
                 source_node_id=source_node_id,
-                progress_image=progress_image.dict() if progress_image is not None else None,
+                progress_image=progress_image.model_dump() if progress_image is not None else None,
                 step=step,
                 order=order,
                 total_steps=total_steps,
@@ -291,8 +291,8 @@ class EventServiceBase:
                     started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
                     completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
                 ),
-                batch_status=batch_status.dict(),
-                queue_status=queue_status.dict(),
+                batch_status=batch_status.model_dump(),
+                queue_status=queue_status.model_dump(),
             ),
         )
 
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import Optional
 
 from PIL.Image import Image as PILImageType
@@ -13,7 +14,7 @@ class ImageFileStorageBase(ABC):
         pass
 
     @abstractmethod
-    def get_path(self, image_name: str, thumbnail: bool = False) -> str:
+    def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
         """Gets the internal path to an image or thumbnail."""
         pass
 
@@ -34,8 +34,8 @@ class ImageRecordStorageBase(ABC):
     @abstractmethod
     def get_many(
         self,
-        offset: Optional[int] = None,
-        limit: Optional[int] = None,
+        offset: int = 0,
+        limit: int = 10,
         image_origin: Optional[ResourceOrigin] = None,
         categories: Optional[list[ImageCategory]] = None,
         is_intermediate: Optional[bool] = None,
@@ -69,11 +69,11 @@ class ImageRecordStorageBase(ABC):
         image_category: ImageCategory,
         width: int,
         height: int,
-        session_id: Optional[str],
-        node_id: Optional[str],
-        metadata: Optional[dict],
-        is_intermediate: bool = False,
-        starred: bool = False,
+        is_intermediate: Optional[bool] = False,
+        starred: Optional[bool] = False,
+        session_id: Optional[str] = None,
+        node_id: Optional[str] = None,
+        metadata: Optional[dict] = None,
     ) -> datetime:
         """Saves an image record."""
         pass
@@ -3,7 +3,7 @@ import datetime
 from enum import Enum
 from typing import Optional, Union
 
-from pydantic import Extra, Field, StrictBool, StrictStr
+from pydantic import Field, StrictBool, StrictStr
 
 from invokeai.app.util.metaenum import MetaEnum
 from invokeai.app.util.misc import get_iso_timestamp
@@ -129,7 +129,9 @@ class ImageRecord(BaseModelExcludeNull):
     """The created timestamp of the image."""
     updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the image.")
     """The updated timestamp of the image."""
-    deleted_at: Union[datetime.datetime, str, None] = Field(description="The deleted timestamp of the image.")
+    deleted_at: Optional[Union[datetime.datetime, str]] = Field(
+        default=None, description="The deleted timestamp of the image."
+    )
     """The deleted timestamp of the image."""
     is_intermediate: bool = Field(description="Whether this is an intermediate image.")
     """Whether this is an intermediate image."""
@@ -147,7 +149,7 @@ class ImageRecord(BaseModelExcludeNull):
     """Whether this image is starred."""
 
 
-class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
+class ImageRecordChanges(BaseModelExcludeNull, extra="allow"):
     """A set of changes to apply to an image record.
 
     Only limited changes are valid:
@@ -2,7 +2,7 @@ import json
 import sqlite3
 import threading
 from datetime import datetime
-from typing import Optional, cast
+from typing import Optional, Union, cast
 
 from invokeai.app.services.shared.pagination import OffsetPaginatedResults
 from invokeai.app.services.shared.sqlite import SqliteDatabase
@@ -117,7 +117,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
             """
         )
 
-    def get(self, image_name: str) -> Optional[ImageRecord]:
+    def get(self, image_name: str) -> ImageRecord:
         try:
             self._lock.acquire()
 
@@ -223,8 +223,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
 
     def get_many(
         self,
-        offset: Optional[int] = None,
-        limit: Optional[int] = None,
+        offset: int = 0,
+        limit: int = 10,
         image_origin: Optional[ResourceOrigin] = None,
         categories: Optional[list[ImageCategory]] = None,
         is_intermediate: Optional[bool] = None,
@@ -249,7 +249,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
         """
 
         query_conditions = ""
-        query_params = []
+        query_params: list[Union[int, str, bool]] = []
 
         if image_origin is not None:
             query_conditions += """--sql
@@ -387,13 +387,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
         image_name: str,
         image_origin: ResourceOrigin,
         image_category: ImageCategory,
-        session_id: Optional[str],
         width: int,
         height: int,
-        node_id: Optional[str],
-        metadata: Optional[dict],
-        is_intermediate: bool = False,
-        starred: bool = False,
+        is_intermediate: Optional[bool] = False,
+        starred: Optional[bool] = False,
+        session_id: Optional[str] = None,
+        node_id: Optional[str] = None,
+        metadata: Optional[dict] = None,
     ) -> datetime:
         try:
             metadata_json = None if metadata is None else json.dumps(metadata)
@@ -49,7 +49,7 @@ class ImageServiceABC(ABC):
         node_id: Optional[str] = None,
         session_id: Optional[str] = None,
         board_id: Optional[str] = None,
-        is_intermediate: bool = False,
+        is_intermediate: Optional[bool] = False,
         metadata: Optional[dict] = None,
         workflow: Optional[str] = None,
     ) -> ImageDTO:
@@ -20,7 +20,9 @@ class ImageUrlsDTO(BaseModelExcludeNull):
 class ImageDTO(ImageRecord, ImageUrlsDTO):
     """Deserialized image record, enriched for the frontend."""
 
-    board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
+    board_id: Optional[str] = Field(
+        default=None, description="The id of the board the image belongs to, if one exists."
+    )
     """The id of the board the image belongs to, if one exists."""
 
     pass
@@ -34,7 +36,7 @@ def image_record_to_dto(
 ) -> ImageDTO:
     """Converts an image record to an image DTO."""
     return ImageDTO(
         **image_record.model_dump(),
-        **image_record.dict(),
         image_url=image_url,
         thumbnail_url=thumbnail_url,
         board_id=board_id,
@@ -41,7 +41,7 @@ class ImageService(ImageServiceABC):
         node_id: Optional[str] = None,
         session_id: Optional[str] = None,
         board_id: Optional[str] = None,
-        is_intermediate: bool = False,
+        is_intermediate: Optional[bool] = False,
         metadata: Optional[dict] = None,
         workflow: Optional[str] = None,
     ) -> ImageDTO:
@@ -146,7 +146,7 @@ class ImageService(ImageServiceABC):
             self.__invoker.services.logger.error("Problem getting image DTO")
             raise e
 
-    def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
+    def get_metadata(self, image_name: str) -> ImageMetadata:
         try:
             image_record = self.__invoker.services.image_records.get(image_name)
             metadata = self.__invoker.services.image_records.get_metadata(image_name)
@@ -174,7 +174,7 @@ class ImageService(ImageServiceABC):
 
     def get_path(self, image_name: str, thumbnail: bool = False) -> str:
         try:
-            return self.__invoker.services.image_files.get_path(image_name, thumbnail)
+            return str(self.__invoker.services.image_files.get_path(image_name, thumbnail))
         except Exception as e:
             self.__invoker.services.logger.error("Problem getting image path")
             raise e
@@ -58,7 +58,12 @@ class MemoryInvocationCache(InvocationCacheBase):
             # If the cache is full, we need to remove the least used
             number_to_delete = len(self._cache) + 1 - self._max_cache_size
             self._delete_oldest_access(number_to_delete)
-        self._cache[key] = CachedItem(invocation_output, invocation_output.json())
+        self._cache[key] = CachedItem(
+            invocation_output,
+            invocation_output.model_dump_json(
+                warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
+            ),
+        )
 
     def _delete_oldest_access(self, number_to_delete: int) -> None:
         number_to_delete = min(number_to_delete, len(self._cache))
@@ -85,7 +90,7 @@ class MemoryInvocationCache(InvocationCacheBase):
 
     @staticmethod
     def create_key(invocation: BaseInvocation) -> int:
-        return hash(invocation.json(exclude={"id"}))
+        return hash(invocation.model_dump_json(exclude={"id"}, warnings=False))
 
     def disable(self) -> None:
         with self._lock:
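The cache changes above lean on `model_dump_json()` options that v1's `.json()` did not expose. A sketch of deriving a stable cache key, with a hypothetical model (the flag choice mirrors the hunk above):

```py
from pydantic import BaseModel

class Node(BaseModel):
    id: str = ""
    strength: float = 1.0

node = Node(id="n1")
# exclude the identity field and suppress serializer warnings, as above
key = hash(node.model_dump_json(exclude={"id"}, warnings=False))
```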
@@ -89,7 +89,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 queue_item_id=queue_item.session_queue_item_id,
                 queue_id=queue_item.session_queue_id,
                 graph_execution_state_id=graph_execution_state.id,
-                node=invocation.dict(),
+                node=invocation.model_dump(),
                 source_node_id=source_node_id,
             )
 
@@ -127,9 +127,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 queue_item_id=queue_item.session_queue_item_id,
                 queue_id=queue_item.session_queue_id,
                 graph_execution_state_id=graph_execution_state.id,
-                node=invocation.dict(),
+                node=invocation.model_dump(),
                 source_node_id=source_node_id,
-                result=outputs.dict(),
+                result=outputs.model_dump(),
             )
             self.__invoker.services.performance_statistics.log_stats()
 
@@ -157,7 +157,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 queue_item_id=queue_item.session_queue_item_id,
                 queue_id=queue_item.session_queue_id,
                 graph_execution_state_id=graph_execution_state.id,
-                node=invocation.dict(),
+                node=invocation.model_dump(),
                 source_node_id=source_node_id,
                 error_type=e.__class__.__name__,
                 error=error,
@@ -187,7 +187,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 queue_item_id=queue_item.session_queue_item_id,
                 queue_id=queue_item.session_queue_id,
                 graph_execution_state_id=graph_execution_state.id,
-                node=invocation.dict(),
+                node=invocation.model_dump(),
                 source_node_id=source_node_id,
                 error_type=e.__class__.__name__,
                 error=traceback.format_exc(),
@@ -72,7 +72,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
             )
             self.collector.update_invocation_stats(
                 graph_id=self.graph_id,
-                invocation_type=self.invocation.type,  # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
+                invocation_type=self.invocation.type,  # type: ignore # `type` is not on the `BaseInvocation` model, but *is* on all invocations
                 time_used=time.time() - self.start_time,
                 vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
             )
@@ -2,7 +2,7 @@ import sqlite3
 import threading
 from typing import Generic, Optional, TypeVar, get_args
 
-from pydantic import BaseModel, parse_raw_as
+from pydantic import BaseModel, TypeAdapter
 
 from invokeai.app.services.shared.pagination import PaginatedResults
 from invokeai.app.services.shared.sqlite import SqliteDatabase
@@ -18,6 +18,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
     _cursor: sqlite3.Cursor
     _id_field: str
     _lock: threading.RLock
+    _adapter: Optional[TypeAdapter[T]]
 
     def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
         super().__init__()
@@ -27,6 +28,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
         self._table_name = table_name
         self._id_field = id_field  # TODO: validate that T has this field
         self._cursor = self._conn.cursor()
+        self._adapter: Optional[TypeAdapter[T]] = None
 
         self._create_table()
 
@@ -45,16 +47,21 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
             self._lock.release()
 
     def _parse_item(self, item: str) -> T:
-        # __orig_class__ is technically an implementation detail of the typing module, not a supported API
-        item_type = get_args(self.__orig_class__)[0]  # type: ignore
-        return parse_raw_as(item_type, item)
+        if self._adapter is None:
+            """
+            We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
+            we can create it when it is first needed instead.
+            __orig_class__ is technically an implementation detail of the typing module, not a supported API
+            """
+            self._adapter = TypeAdapter(get_args(self.__orig_class__)[0])  # type: ignore [attr-defined]
+        return self._adapter.validate_json(item)
 
     def set(self, item: T):
         try:
             self._lock.acquire()
             self._cursor.execute(
                 f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
-                (item.json(),),
+                (item.model_dump_json(warnings=False, exclude_none=True),),
             )
             self._conn.commit()
         finally:
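The `_parse_item()` rewrite above caches its `TypeAdapter` because constructing one is not free in v2 (it builds a pydantic-core validator). A minimal sketch of the reuse pattern, with hypothetical names:

```py
from pydantic import BaseModel, TypeAdapter

class Item(BaseModel):
    id: str

# build once at module (or instance) scope, then reuse for every row
items_adapter = TypeAdapter(list[Item])

def parse_rows(raw_json: str) -> list[Item]:
    return items_adapter.validate_json(raw_json)
```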
@@ -231,7 +231,7 @@ class ModelManagerServiceBase(ABC):
     def merge_models(
         self,
         model_names: List[str] = Field(
-            default=None, min_items=2, max_items=3, description="List of model names to merge"
+            default=None, min_length=2, max_length=3, description="List of model names to merge"
         ),
         base_model: Union[BaseModelType, str] = Field(
             default=None, description="Base model shared by all models to be merged"
@@ -327,7 +327,7 @@ class ModelManagerService(ModelManagerServiceBase):
     def merge_models(
         self,
         model_names: List[str] = Field(
-            default=None, min_items=2, max_items=3, description="List of model names to merge"
+            default=None, min_length=2, max_length=3, description="List of model names to merge"
         ),
         base_model: Union[BaseModelType, str] = Field(
             default=None, description="Base model shared by all models to be merged"
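The `merge_models` signatures above pick up another rename: the v1 collection constraints `min_items`/`max_items` are `min_length`/`max_length` in v2. A sketch with a hypothetical model:

```py
from typing import List
from pydantic import BaseModel, Field

class MergeRequest(BaseModel):
    # v1: Field(min_items=2, max_items=3); v2 uses min_length/max_length
    model_names: List[str] = Field(min_length=2, max_length=3)
```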
@@ -3,8 +3,8 @@ import json
 from itertools import chain, product
 from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
 
-from pydantic import BaseModel, Field, StrictStr, parse_raw_as, root_validator, validator
-from pydantic.json import pydantic_encoder
+from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator
+from pydantic_core import to_jsonable_python
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation
 from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
@@ -17,7 +17,7 @@ class BatchZippedLengthError(ValueError):
     """Raise when a batch has items of different lengths."""
 
 
-class BatchItemsTypeError(TypeError):
+class BatchItemsTypeError(ValueError):  # this cannot be a TypeError in pydantic v2
     """Raise when a batch has items of different types."""
 
 
@@ -70,7 +70,7 @@ class Batch(BaseModel):
         default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
     )
 
-    @validator("data")
+    @field_validator("data")
     def validate_lengths(cls, v: Optional[BatchDataCollection]):
         if v is None:
             return v
@@ -81,7 +81,7 @@ class Batch(BaseModel):
             raise BatchZippedLengthError("Zipped batch items must all have the same length")
         return v
 
-    @validator("data")
+    @field_validator("data")
     def validate_types(cls, v: Optional[BatchDataCollection]):
         if v is None:
             return v
@@ -94,7 +94,7 @@ class Batch(BaseModel):
             raise BatchItemsTypeError("All items in a batch must have the same type")
         return v
 
-    @validator("data")
+    @field_validator("data")
     def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
         if v is None:
             return v
@@ -107,34 +107,35 @@ class Batch(BaseModel):
             paths.add(pair)
         return v
 
-    @root_validator(skip_on_failure=True)
+    @model_validator(mode="after")
     def validate_batch_nodes_and_edges(cls, values):
-        batch_data_collection = cast(Optional[BatchDataCollection], values["data"])
+        batch_data_collection = cast(Optional[BatchDataCollection], values.data)
         if batch_data_collection is None:
             return values
-        graph = cast(Graph, values["graph"])
+        graph = cast(Graph, values.graph)
         for batch_data_list in batch_data_collection:
             for batch_data in batch_data_list:
                 try:
                     node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
                 except NodeNotFoundError:
                     raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
-                if batch_data.field_name not in node.__fields__:
+                if batch_data.field_name not in node.model_fields:
                     raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
         return values
 
-    @validator("graph")
+    @field_validator("graph")
     def validate_graph(cls, v: Graph):
         v.validate_self()
         return v
 
-    class Config:
-        schema_extra = {
-            "required": [
+    model_config = ConfigDict(
+        json_schema_extra=dict(
+            required=[
                 "graph",
                 "runs",
             ]
-        }
+        )
+    )
 
 
 # endregion Batch
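The `Batch` changes above replace `@root_validator` with `@model_validator(mode="after")`. In "after" mode the validator receives the constructed model instance rather than a values dict, which is why lookups like `values["data"]` become attribute access. A minimal sketch with a hypothetical model:

```py
from pydantic import BaseModel, model_validator

class Range(BaseModel):
    low: int
    high: int

    @model_validator(mode="after")
    def check_order(self):
        # "after" validators get the model instance and must return it
        if self.low > self.high:
            raise ValueError("low must not exceed high")
        return self
```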
@@ -146,15 +147,21 @@ DEFAULT_QUEUE_ID = "default"
 
 QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
 
+adapter_NodeFieldValue = TypeAdapter(list[NodeFieldValue])
+
+
 def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
     field_values_raw = queue_item_dict.get("field_values", None)
-    return parse_raw_as(list[NodeFieldValue], field_values_raw) if field_values_raw is not None else None
+    return adapter_NodeFieldValue.validate_json(field_values_raw) if field_values_raw is not None else None
 
 
+adapter_GraphExecutionState = TypeAdapter(GraphExecutionState)
+
+
 def get_session(queue_item_dict: dict) -> GraphExecutionState:
     session_raw = queue_item_dict.get("session", "{}")
-    return parse_raw_as(GraphExecutionState, session_raw)
+    session = adapter_GraphExecutionState.validate_json(session_raw, strict=False)
+    return session
 
 
 class SessionQueueItemWithoutGraph(BaseModel):
@@ -178,14 +185,14 @@ class SessionQueueItemWithoutGraph(BaseModel):
     )
 
     @classmethod
-    def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
+    def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
         # must parse these manually
         queue_item_dict["field_values"] = get_field_values(queue_item_dict)
         return SessionQueueItemDTO(**queue_item_dict)
 
-    class Config:
-        schema_extra = {
-            "required": [
+    model_config = ConfigDict(
+        json_schema_extra=dict(
+            required=[
                 "item_id",
                 "status",
                 "batch_id",
@@ -196,7 +203,8 @@ class SessionQueueItemWithoutGraph(BaseModel):
                 "created_at",
                 "updated_at",
             ]
-        }
+        )
+    )
 
 
 class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
@@ -207,15 +215,15 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
     session: GraphExecutionState = Field(description="The fully-populated session to be executed")
 
     @classmethod
-    def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
+    def queue_item_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
         # must parse these manually
         queue_item_dict["field_values"] = get_field_values(queue_item_dict)
         queue_item_dict["session"] = get_session(queue_item_dict)
         return SessionQueueItem(**queue_item_dict)
 
-    class Config:
-        schema_extra = {
-            "required": [
+    model_config = ConfigDict(
+        json_schema_extra=dict(
+            required=[
                 "item_id",
                 "status",
                 "batch_id",
@@ -227,7 +235,8 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
                 "created_at",
                 "updated_at",
             ]
-        }
+        )
+    )
 
 
 # endregion Queue Items
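Both queue-item models above convert `class Config: schema_extra = {...}` into `model_config = ConfigDict(json_schema_extra=...)`. A sketch with a hypothetical model:

```py
from pydantic import BaseModel, ConfigDict

class QueueItem(BaseModel):
    model_config = ConfigDict(json_schema_extra={"required": ["item_id", "status"]})

    item_id: int = 0
    status: str = "pending"

# the extra keys are merged into the generated JSON schema
print(QueueItem.model_json_schema()["required"])  # -> ["item_id", "status"]
```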
@@ -321,7 +330,7 @@ def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) ->
    """
     Populates the given graph with the given batch data items.
     """
-    graph_clone = graph.copy(deep=True)
+    graph_clone = graph.model_copy(deep=True)
     for item in node_field_values:
         node = graph_clone.get_node(item.node_path)
         if node is None:
@@ -354,7 +363,7 @@ def create_session_nfv_tuples(
                 for item in batch_datum.items
             ]
             node_field_values_to_zip.append(node_field_values)
-        data.append(list(zip(*node_field_values_to_zip)))
+        data.append(list(zip(*node_field_values_to_zip)))  # type: ignore [arg-type]
 
     # create generator to yield session,nfv tuples
     count = 0
@@ -409,11 +418,11 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
         values_to_insert.append(
             SessionQueueValueToInsert(
                 queue_id,  # queue_id
-                session.json(),  # session (json)
+                session.model_dump_json(warnings=False, exclude_none=True),  # session (json)
                 session.id,  # session_id
                 batch.batch_id,  # batch_id
                 # must use pydantic_encoder bc field_values is a list of models
-                json.dumps(field_values, default=pydantic_encoder) if field_values else None,  # field_values (json)
+                json.dumps(field_values, default=to_jsonable_python) if field_values else None,  # field_values (json)
                 priority,  # priority
             )
         )
@@ -421,3 +430,6 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
 
 
 # endregion Util
+
+Batch.model_rebuild(force=True)
+SessionQueueItem.model_rebuild(force=True)
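The utility hunks above touch two more v2 renames: `.copy(deep=True)` becomes `.model_copy(deep=True)`, and the explicit `model_rebuild(force=True)` calls at the bottom of the module re-resolve forward references once all referenced types are importable. A sketch of both, with a hypothetical model:

```py
from pydantic import BaseModel

class Graphish(BaseModel):
    nodes: dict = {}

g = Graphish(nodes={"a": 1})
clone = g.model_copy(deep=True)  # replaces v1 .copy(deep=True)
clone.nodes["a"] = 2
assert g.nodes["a"] == 1         # deep copy: the original is untouched

# after all referenced types are defined/imported:
Graphish.model_rebuild(force=True)  # replaces v1 update_forward_refs()
```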
@@ -277,8 +277,8 @@ class SqliteSessionQueue(SessionQueueBase):
         if result is None:
             raise SessionQueueItemNotFoundError(f"No queue item with batch id {enqueue_result.batch.batch_id}")
         return EnqueueGraphResult(
-            **enqueue_result.dict(),
-            queue_item=SessionQueueItemDTO.from_dict(dict(result)),
+            **enqueue_result.model_dump(),
+            queue_item=SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)),
         )
 
     def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
@@ -351,7 +351,7 @@ class SqliteSessionQueue(SessionQueueBase):
             self.__lock.release()
         if result is None:
             return None
-        queue_item = SessionQueueItem.from_dict(dict(result))
+        queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
         queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
         return queue_item
 
@@ -380,7 +380,7 @@ class SqliteSessionQueue(SessionQueueBase):
             self.__lock.release()
         if result is None:
             return None
-        return SessionQueueItem.from_dict(dict(result))
+        return SessionQueueItem.queue_item_from_dict(dict(result))
 
     def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
         try:
@@ -404,7 +404,7 @@ class SqliteSessionQueue(SessionQueueBase):
             self.__lock.release()
         if result is None:
             return None
-        return SessionQueueItem.from_dict(dict(result))
+        return SessionQueueItem.queue_item_from_dict(dict(result))
 
     def _set_queue_item_status(
         self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
@@ -564,7 +564,7 @@ class SqliteSessionQueue(SessionQueueBase):
         queue_item = self.get_queue_item(item_id)
         if queue_item.status not in ["canceled", "failed", "completed"]:
             status = "failed" if error is not None else "canceled"
-            queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error)
+            queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error)  # type: ignore [arg-type] # mypy seems to not narrow the Literals here
             self.__invoker.services.queue.cancel(queue_item.session_id)
             self.__invoker.services.events.emit_session_canceled(
                 queue_item_id=queue_item.item_id,
@@ -699,7 +699,7 @@ class SqliteSessionQueue(SessionQueueBase):
             self.__lock.release()
         if result is None:
             raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
-        return SessionQueueItem.from_dict(dict(result))
+        return SessionQueueItem.queue_item_from_dict(dict(result))
 
     def list_queue_items(
         self,
@@ -751,7 +751,7 @@ class SqliteSessionQueue(SessionQueueBase):
         params.append(limit + 1)
         self.__cursor.execute(query, params)
         results = cast(list[sqlite3.Row], self.__cursor.fetchall())
-        items = [SessionQueueItemDTO.from_dict(dict(result)) for result in results]
+        items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
         has_more = False
         if len(items) > limit:
             # remove the extra item
@@ -80,10 +80,10 @@ def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[Li
     # TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
     graphs: list[LibraryGraph] = list()
 
-    # text_to_image = graph_library.get(default_text_to_image_graph_id)
+    text_to_image = graph_library.get(default_text_to_image_graph_id)
 
-    # # TODO: Check if the graph is the same as the default one, and if not, update it
-    # #if text_to_image is None:
+    # TODO: Check if the graph is the same as the default one, and if not, update it
+    # if text_to_image is None:
     text_to_image = create_text_to_image()
     graph_library.set(text_to_image)
 
@@ -5,7 +5,7 @@ import itertools
 from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints

 import networkx as nx
-from pydantic import BaseModel, root_validator, validator
+from pydantic import BaseModel, ConfigDict, field_validator, model_validator
 from pydantic.fields import Field

 # Importing * is bad karma but needed here for node detection
@@ -235,7 +235,8 @@ class CollectInvocationOutput(BaseInvocationOutput):
 class CollectInvocation(BaseInvocation):
     """Collects values into a collection"""

-    item: Any = InputField(
+    item: Optional[Any] = InputField(
+        default=None,
         description="The item to collect (all inputs must be of the same type)",
         ui_type=UIType.CollectionItem,
         title="Collection Item",
@@ -250,8 +251,8 @@ class CollectInvocation(BaseInvocation):
         return CollectInvocationOutput(collection=copy.copy(self.collection))


-InvocationsUnion = Union[BaseInvocation.get_invocations()]  # type: ignore
-InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]  # type: ignore
+InvocationsUnion: Any = BaseInvocation.get_invocations_union()
+InvocationOutputsUnion: Any = BaseInvocationOutput.get_outputs_union()


 class Graph(BaseModel):
@@ -378,13 +379,13 @@ class Graph(BaseModel):
             raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph")

         # output fields are not on the node object directly, they are on the output type
-        if edge.source.field not in source_node.get_output_type().__fields__:
+        if edge.source.field not in source_node.get_output_type().model_fields:
             raise NodeFieldNotFoundError(
                 f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
             )

         # input fields are on the node
-        if edge.destination.field not in destination_node.__fields__:
+        if edge.destination.field not in destination_node.model_fields:
             raise NodeFieldNotFoundError(
                 f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
             )
@@ -395,24 +396,24 @@ class Graph(BaseModel):
             raise CyclicalGraphError("Graph contains cycles")

         # Validate all edge connections are valid
-        for e in self.edges:
+        for edge in self.edges:
             if not are_connections_compatible(
-                self.get_node(e.source.node_id),
-                e.source.field,
-                self.get_node(e.destination.node_id),
-                e.destination.field,
+                self.get_node(edge.source.node_id),
+                edge.source.field,
+                self.get_node(edge.destination.node_id),
+                edge.destination.field,
             ):
                 raise InvalidEdgeError(
-                    f"Invalid edge from {e.source.node_id}.{e.source.field} to {e.destination.node_id}.{e.destination.field}"
+                    f"Invalid edge from {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
                 )

         # Validate all iterators & collectors
         # TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available
-        for n in self.nodes.values():
-            if isinstance(n, IterateInvocation) and not self._is_iterator_connection_valid(n.id):
-                raise InvalidEdgeError(f"Invalid iterator node {n.id}")
-            if isinstance(n, CollectInvocation) and not self._is_collector_connection_valid(n.id):
-                raise InvalidEdgeError(f"Invalid collector node {n.id}")
+        for node in self.nodes.values():
+            if isinstance(node, IterateInvocation) and not self._is_iterator_connection_valid(node.id):
+                raise InvalidEdgeError(f"Invalid iterator node {node.id}")
+            if isinstance(node, CollectInvocation) and not self._is_collector_connection_valid(node.id):
+                raise InvalidEdgeError(f"Invalid collector node {node.id}")

         return None

@@ -594,7 +595,7 @@ class Graph(BaseModel):

     def _get_input_edges_and_graphs(
         self, node_path: str, prefix: Optional[str] = None
-    ) -> list[tuple["Graph", str, Edge]]:
+    ) -> list[tuple["Graph", Union[str, None], Edge]]:
         """Gets all input edges for a node along with the graph they are in and the graph's path"""
         edges = list()

@@ -636,7 +637,7 @@ class Graph(BaseModel):

     def _get_output_edges_and_graphs(
         self, node_path: str, prefix: Optional[str] = None
-    ) -> list[tuple["Graph", str, Edge]]:
+    ) -> list[tuple["Graph", Union[str, None], Edge]]:
         """Gets all output edges for a node along with the graph they are in and the graph's path"""
         edges = list()

@@ -817,15 +818,15 @@ class GraphExecutionState(BaseModel):
         default_factory=dict,
     )

-    @validator("graph")
+    @field_validator("graph")
     def graph_is_valid(cls, v: Graph):
         """Validates that the graph is valid"""
         v.validate_self()
         return v

-    class Config:
-        schema_extra = {
-            "required": [
+    model_config = ConfigDict(
+        json_schema_extra=dict(
+            required=[
                 "id",
                 "graph",
                 "execution_graph",
@@ -836,7 +837,8 @@ class GraphExecutionState(BaseModel):
                 "prepared_source_mapping",
                 "source_prepared_mapping",
             ]
-        }
+        )
+    )

     def next(self) -> Optional[BaseInvocation]:
         """Gets the next node ready to execute."""
@@ -910,7 +912,7 @@ class GraphExecutionState(BaseModel):
             input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
             self_iteration_count = len(input_collection)

-        new_nodes = list()
+        new_nodes: list[str] = list()
         if self_iteration_count == 0:
             # TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
             return new_nodes
@@ -920,7 +922,7 @@ class GraphExecutionState(BaseModel):

         # Create new edges for this iteration
         # For collect nodes, this may contain multiple inputs to the same field
-        new_edges = list()
+        new_edges: list[Edge] = list()
         for edge in input_edges:
             for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
                 new_edge = Edge(
@@ -1179,18 +1181,18 @@ class LibraryGraph(BaseModel):
         description="The outputs exposed by this graph", default_factory=list
     )

-    @validator("exposed_inputs", "exposed_outputs")
-    def validate_exposed_aliases(cls, v):
+    @field_validator("exposed_inputs", "exposed_outputs")
+    def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
         if len(v) != len(set(i.alias for i in v)):
             raise ValueError("Duplicate exposed alias")
         return v

-    @root_validator
+    @model_validator(mode="after")
     def validate_exposed_nodes(cls, values):
-        graph = values["graph"]
+        graph = values.graph

         # Validate exposed inputs
-        for exposed_input in values["exposed_inputs"]:
+        for exposed_input in values.exposed_inputs:
             if not graph.has_node(exposed_input.node_path):
                 raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
             node = graph.get_node(exposed_input.node_path)
@@ -1200,7 +1202,7 @@ class LibraryGraph(BaseModel):
             )

         # Validate exposed outputs
-        for exposed_output in values["exposed_outputs"]:
+        for exposed_output in values.exposed_outputs:
             if not graph.has_node(exposed_output.node_path):
                 raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
             node = graph.get_node(exposed_output.node_path)
@@ -1212,4 +1214,6 @@ class LibraryGraph(BaseModel):
         return values


-GraphInvocation.update_forward_refs()
+GraphInvocation.model_rebuild(force=True)
+Graph.model_rebuild(force=True)
+GraphExecutionState.model_rebuild(force=True)
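**Note: migrating validators.** The `Graph`/`LibraryGraph` hunks above follow the standard pydantic v2 validator migration: `@validator` becomes `@field_validator`, and `@root_validator` becomes `@model_validator(mode="after")`, which receives the constructed model instance rather than a dict of values (hence `values.graph` instead of `values["graph"]`). A minimal sketch of the pattern, using a hypothetical model rather than the real `LibraryGraph`:

```py
from pydantic import BaseModel, field_validator, model_validator


class ExposedAliases(BaseModel):  # hypothetical stand-in for LibraryGraph
    exposed_inputs: list[str] = []
    exposed_outputs: list[str] = []

    @field_validator("exposed_inputs", "exposed_outputs")
    def no_duplicates(cls, v: list[str]):
        # v1's @validator, renamed; still receives the field value
        if len(v) != len(set(v)):
            raise ValueError("Duplicate exposed alias")
        return v

    @model_validator(mode="after")
    def cross_check(self):
        # v1's @root_validator got a dict; in mode="after" we get the
        # validated model itself, so fields are read as attributes
        if set(self.exposed_inputs) & set(self.exposed_outputs):
            raise ValueError("Alias exposed as both input and output")
        return self
```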
@@ -1,12 +1,11 @@
 from typing import Generic, TypeVar

 from pydantic import BaseModel, Field
-from pydantic.generics import GenericModel

 GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)


-class CursorPaginatedResults(GenericModel, Generic[GenericBaseModel]):
+class CursorPaginatedResults(BaseModel, Generic[GenericBaseModel]):
     """
     Cursor-paginated results
     Generic must be a Pydantic model
@@ -17,7 +16,7 @@ class CursorPaginatedResults(BaseModel, Generic[GenericBaseModel]):
     items: list[GenericBaseModel] = Field(..., description="Items")


-class OffsetPaginatedResults(GenericModel, Generic[GenericBaseModel]):
+class OffsetPaginatedResults(BaseModel, Generic[GenericBaseModel]):
     """
     Offset-paginated results
     Generic must be a Pydantic model
@@ -29,7 +28,7 @@ class OffsetPaginatedResults(BaseModel, Generic[GenericBaseModel]):
     items: list[GenericBaseModel] = Field(description="Items")


-class PaginatedResults(GenericModel, Generic[GenericBaseModel]):
+class PaginatedResults(BaseModel, Generic[GenericBaseModel]):
     """
     Paginated results
     Generic must be a Pydantic model
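**Note: generic models.** `pydantic.generics.GenericModel` is gone in v2; a generic model now subclasses `BaseModel` and `Generic[...]` directly, which is exactly the one-line change applied to each paginated-results class above. A minimal sketch:

```py
from typing import Generic, TypeVar

from pydantic import BaseModel, Field

T = TypeVar("T", bound=BaseModel)


class Page(BaseModel, Generic[T]):  # v1 required GenericModel as the base
    items: list[T] = Field(description="Items")
    total: int


class ImageDTO(BaseModel):
    name: str


page = Page[ImageDTO](items=[ImageDTO(name="img.png")], total=1)
```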
@@ -265,7 +265,7 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:


 def prepare_control_image(
-    image: Image,
+    image: Image.Image,
     width: int,
     height: int,
     num_channels: int = 3,
@@ -1,4 +1,5 @@
 import datetime
+import typing
 import uuid

 import numpy as np
@@ -27,3 +28,8 @@ def get_random_seed():
 def uuid_string():
     res = uuid.uuid4()
     return str(res)
+
+
+def is_optional(value: typing.Any):
+    """Checks if a value is typed as Optional. Note that Optional is sugar for Union[x, None]."""
+    return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value)
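**Note: `is_optional`.** The new helper relies on `Optional[X]` being sugar for `Union[X, None]`, so checking for a `Union` origin plus a `NoneType` argument is sufficient. Expected behavior, assuming the helper above is importable:

```py
from typing import Optional, Union

assert is_optional(Optional[int])      # Optional[int] is Union[int, None]
assert is_optional(Union[str, None])
assert not is_optional(int)            # a plain type has no Union origin
```

One caveat: PEP 604 unions written as `int | None` have a different origin (`types.UnionType`), so they would not be detected by this check.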
@@ -13,11 +13,11 @@ From https://github.com/tiangolo/fastapi/discussions/8882#discussioncomment-5154

 class BaseModelExcludeNull(BaseModel):
-    def dict(self, *args, **kwargs) -> dict[str, Any]:
+    def model_dump(self, *args, **kwargs) -> dict[str, Any]:
         """
         Override the default dict method to exclude None values in the response
         """
         kwargs.pop("exclude_none", None)
-        return super().dict(*args, exclude_none=True, **kwargs)
+        return super().model_dump(*args, exclude_none=True, **kwargs)

     pass
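**Note: `BaseModelExcludeNull`.** The override keeps its v1 behavior under the v2 method name: every dump strips `None` fields, so responses built from these models should omit nulls. Expected usage, with a hypothetical subclass of the class defined above:

```py
from typing import Optional


class ImageUrls(BaseModelExcludeNull):  # hypothetical subclass
    image_url: str
    thumbnail_url: Optional[str] = None


# thumbnail_url is None, so it is dropped from the dump entirely
assert ImageUrls(image_url="a.png").model_dump() == {"image_url": "a.png"}
```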
invokeai/assets/__init__.py — new empty file (0 lines)
@@ -41,18 +41,18 @@ config = InvokeAIAppConfig.get_config()


 class SegmentedGrayscale(object):
-    def __init__(self, image: Image, heatmap: torch.Tensor):
+    def __init__(self, image: Image.Image, heatmap: torch.Tensor):
         self.heatmap = heatmap
         self.image = image

-    def to_grayscale(self, invert: bool = False) -> Image:
+    def to_grayscale(self, invert: bool = False) -> Image.Image:
         return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))

-    def to_mask(self, threshold: float = 0.5) -> Image:
+    def to_mask(self, threshold: float = 0.5) -> Image.Image:
         discrete_heatmap = self.heatmap.lt(threshold).int()
         return self._rescale(Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L"))

-    def to_transparent(self, invert: bool = False) -> Image:
+    def to_transparent(self, invert: bool = False) -> Image.Image:
         transparent_image = self.image.copy()
         # For img2img, we want the selected regions to be transparent,
         # but to_grayscale() returns the opposite. Thus invert.
@@ -61,7 +61,7 @@ class SegmentedGrayscale(object):
         return transparent_image

     # unscales and uncrops the 352x352 heatmap so that it matches the image again
-    def _rescale(self, heatmap: Image) -> Image:
+    def _rescale(self, heatmap: Image.Image) -> Image.Image:
         size = self.image.width if (self.image.width > self.image.height) else self.image.height
         resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
         return resized_image.crop((0, 0, self.image.width, self.image.height))
@@ -82,7 +82,7 @@ class Txt2Mask(object):
         self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)

     @torch.no_grad()
-    def segment(self, image, prompt: str) -> SegmentedGrayscale:
+    def segment(self, image: Image.Image, prompt: str) -> SegmentedGrayscale:
         """
         Given a prompt string such as "a bagel", tries to identify the object in the
         provided image and returns a SegmentedGrayscale object in which the brighter
@@ -99,7 +99,7 @@ class Txt2Mask(object):
         heatmap = torch.sigmoid(outputs.logits)
         return SegmentedGrayscale(image, heatmap)

-    def _scale_and_crop(self, image: Image) -> Image:
+    def _scale_and_crop(self, image: Image.Image) -> Image.Image:
         scaled_image = Image.new("RGB", (CLIPSEG_SIZE, CLIPSEG_SIZE))
         if image.width > image.height:  # width is constraint
             scale = CLIPSEG_SIZE / image.width
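**Note: `Image` vs `Image.Image`.** Several hunks in this diff tighten PIL annotations: after `from PIL import Image`, the name `Image` is the *module* (home of `Image.open`, `Image.new`, ...), while `Image.Image` is the image *class*, so only the latter is a meaningful type annotation. A minimal sketch with a hypothetical helper:

```py
from PIL import Image


def to_thumbnail(img: Image.Image, size: int = 128) -> Image.Image:
    # Image.Image is the class; bare Image would annotate with a module
    copy = img.copy()
    copy.thumbnail((size, size))  # in-place resize, preserving aspect ratio
    return copy
```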
@@ -9,7 +9,7 @@ class InitImageResizer:
     def __init__(self, Image):
         self.image = Image

-    def resize(self, width=None, height=None) -> Image:
+    def resize(self, width=None, height=None) -> Image.Image:
         """
         Return a copy of the image resized to fit within
         a box width x height. The aspect ratio is
@@ -793,7 +793,11 @@ def migrate_init_file(legacy_format: Path):
     old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
     new = InvokeAIAppConfig.get_config()

-    fields = [x for x, y in InvokeAIAppConfig.__fields__.items() if y.field_info.extra.get("category") != "DEPRECATED"]
+    fields = [
+        x
+        for x, y in InvokeAIAppConfig.model_fields.items()
+        if (y.json_schema_extra.get("category", None) if y.json_schema_extra else None) != "DEPRECATED"
+    ]
     for attr in fields:
         if hasattr(old, attr):
             try:
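**Note: field metadata access.** In v1, arbitrary `Field(...)` kwargs landed in `field_info.extra`; in v2 they must be passed via `json_schema_extra`, and `__fields__` is read through `model_fields`. A minimal sketch of the deprecated-field filter above, with a hypothetical config model standing in for `InvokeAIAppConfig`:

```py
from pydantic import BaseModel, Field


class AppConfig(BaseModel):  # hypothetical stand-in for InvokeAIAppConfig
    old_setting: int = Field(0, json_schema_extra={"category": "DEPRECATED"})
    new_setting: int = Field(0, json_schema_extra={"category": "General"})


active = [
    name
    for name, field in AppConfig.model_fields.items()
    if (field.json_schema_extra or {}).get("category") != "DEPRECATED"
]
assert active == ["new_setting"]
```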
@@ -236,13 +236,13 @@ import types
 from dataclasses import dataclass
 from pathlib import Path
 from shutil import move, rmtree
-from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union
+from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union, cast

 import torch
 import yaml
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field

 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig
@@ -294,6 +294,8 @@ class AddModelResult(BaseModel):
     base_model: BaseModelType = Field(description="The base model")
     config: ModelConfigBase = Field(description="The configuration of the model")

+    model_config = ConfigDict(protected_namespaces=())


 MAX_CACHE_SIZE = 6.0  # GB

@@ -576,7 +578,7 @@ class ModelManager(object):
         """
         model_key = self.create_key(model_name, base_model, model_type)
         if model_key in self.models:
-            return self.models[model_key].dict(exclude_defaults=True)
+            return self.models[model_key].model_dump(exclude_defaults=True)
         else:
             return None  # TODO: None or empty dict on not found

@@ -632,7 +634,7 @@ class ModelManager(object):
                 continue

             model_dict = dict(
-                **model_config.dict(exclude_defaults=True),
+                **model_config.model_dump(exclude_defaults=True),
                 # OpenAPIModelInfoBase
                 model_name=cur_model_name,
                 base_model=cur_base_model,
@@ -900,14 +902,16 @@ class ModelManager(object):
         Write current configuration out to the indicated file.
         """
         data_to_save = dict()
-        data_to_save["__metadata__"] = self.config_meta.dict()
+        data_to_save["__metadata__"] = self.config_meta.model_dump()

         for model_key, model_config in self.models.items():
             model_name, base_model, model_type = self.parse_key(model_key)
             model_class = self._get_implementation(base_model, model_type)
             if model_class.save_to_config:
                 # TODO: or exclude_unset better fits here?
-                data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
+                data_to_save[model_key] = cast(BaseModel, model_config).model_dump(
+                    exclude_defaults=True, exclude={"error"}, mode="json"
+                )
                 # alias for config file
                 data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")

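**Note: `model_dump(mode="json")`.** The config-writing hunk switches to `mode="json"`, which renders JSON-compatible primitives (enum values, stringified paths, ISO dates) instead of live Python objects — useful right before dumping to YAML or JSON. A minimal sketch:

```py
from enum import Enum
from typing import Optional

from pydantic import BaseModel


class ModelError(str, Enum):
    not_found = "not_found"


class Record(BaseModel):
    error: Optional[ModelError] = None


r = Record(error=ModelError.not_found)
print(r.model_dump())             # {'error': <ModelError.not_found: 'not_found'>}
print(r.model_dump(mode="json"))  # {'error': 'not_found'} – safe for YAML/JSON
```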
@@ -2,7 +2,7 @@ import inspect
 from enum import Enum
 from typing import Literal, get_origin

-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict, create_model

 from .base import (  # noqa: F401
     BaseModelType,
@@ -106,6 +106,8 @@ class OpenAPIModelInfoBase(BaseModel):
     base_model: BaseModelType
     model_type: ModelType

+    model_config = ConfigDict(protected_namespaces=())


 for base_model, models in MODEL_CLASSES.items():
     for model_type, model_class in models.items():
@@ -121,17 +123,11 @@ for base_model, models in MODEL_CLASSES.items():
         if openapi_cfg_name in vars():
             continue

-        api_wrapper = type(
+        api_wrapper = create_model(
             openapi_cfg_name,
-            (cfg, OpenAPIModelInfoBase),
-            dict(
-                __annotations__=dict(
-                    model_type=Literal[model_type.value],
-                ),
-            ),
+            __base__=(cfg, OpenAPIModelInfoBase),
+            model_type=(Literal[model_type], model_type),  # type: ignore
         )

         # globals()[openapi_cfg_name] = api_wrapper
         vars()[openapi_cfg_name] = api_wrapper
         OPENAPI_MODEL_CONFIGS.append(api_wrapper)
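**Note: `create_model` instead of `type()`.** Building pydantic classes with raw `type()` worked in v1 but bypasses v2's validator setup, so the OpenAPI wrapper classes are now built with the imperative `create_model()` API: fields are `(annotation, default)` tuples and base classes go in `__base__`. A minimal sketch with hypothetical names:

```py
from typing import Literal

from pydantic import BaseModel, create_model


class Base(BaseModel):
    name: str


# hypothetical example: attach a literal discriminator to a dynamic subclass
Wrapped = create_model(
    "Wrapped",
    __base__=Base,
    kind=(Literal["demo"], "demo"),  # (annotation, default)
)

assert Wrapped(name="x").kind == "demo"
```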
||||
|
@ -19,7 +19,7 @@ from diffusers import logging as diffusers_logging
|
||||
from onnx import numpy_helper
|
||||
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
||||
from picklescan.scanner import scan_file_path
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
@ -86,14 +86,21 @@ class ModelError(str, Enum):
|
||||
NotFound = "not_found"
|
||||
|
||||
|
||||
def model_config_json_schema_extra(schema: dict[str, Any]) -> None:
|
||||
if "required" not in schema:
|
||||
schema["required"] = []
|
||||
schema["required"].append("model_type")
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
path: str # or Path
|
||||
description: Optional[str] = Field(None)
|
||||
model_format: Optional[str] = Field(None)
|
||||
error: Optional[ModelError] = Field(None)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
model_config = ConfigDict(
|
||||
use_enum_values=True, protected_namespaces=(), json_schema_extra=model_config_json_schema_extra
|
||||
)
|
||||
|
||||
|
||||
class EmptyConfigLoader(ConfigMixin):
|
||||
|
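**Note: callable `json_schema_extra`.** v1's `schema_extra` hook on `class Config` is replaced here by passing a callable to `ConfigDict(json_schema_extra=...)`; the callable mutates the generated schema in place, which is how `model_type` is forced into `required` above. A minimal sketch:

```py
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict


def force_required(schema: dict[str, Any]) -> None:
    # mutate the generated JSON schema in place
    schema.setdefault("required", []).append("model_format")


class Cfg(BaseModel):
    model_format: Optional[str] = None
    model_config = ConfigDict(protected_namespaces=(), json_schema_extra=force_required)


assert Cfg.model_json_schema()["required"] == ["model_format"]
```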
@@ -58,14 +58,16 @@ class IPAdapterModel(ModelBase):

     def get_model(
         self,
-        torch_dtype: Optional[torch.dtype],
+        torch_dtype: torch.dtype,
         child_type: Optional[SubModelType] = None,
     ) -> typing.Union[IPAdapter, IPAdapterPlus]:
         if child_type is not None:
             raise ValueError("There are no child models in an IP-Adapter model.")

         model = build_ip_adapter(
-            ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
+            ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"),
+            device=torch.device("cpu"),
+            dtype=torch_dtype,
         )

         self.model_size = model.calc_size()
@@ -96,7 +96,7 @@ def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes):
     finally:
         for module, orig_conv_forward in to_restore:
             module._conv_forward = orig_conv_forward
-            if hasattr(m, "asymmetric_padding_mode"):
-                del m.asymmetric_padding_mode
-            if hasattr(m, "asymmetric_padding"):
-                del m.asymmetric_padding
+            if hasattr(module, "asymmetric_padding_mode"):
+                del module.asymmetric_padding_mode
+            if hasattr(module, "asymmetric_padding"):
+                del module.asymmetric_padding
@@ -1,7 +1,8 @@
 import math
+from typing import Optional

-import PIL
 import torch
+from PIL import Image
 from torchvision.transforms.functional import InterpolationMode
 from torchvision.transforms.functional import resize as tv_resize

@@ -11,7 +12,7 @@ class AttentionMapSaver:
         self.token_ids = token_ids
         self.latents_shape = latents_shape
         # self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
-        self.collated_maps = {}
+        self.collated_maps: dict[str, torch.Tensor] = {}

     def clear_maps(self):
         self.collated_maps = {}
@@ -38,9 +39,10 @@ class AttentionMapSaver:

     def write_maps_to_disk(self, path: str):
         pil_image = self.get_stacked_maps_image()
-        pil_image.save(path, "PNG")
+        if pil_image is not None:
+            pil_image.save(path, "PNG")

-    def get_stacked_maps_image(self) -> PIL.Image:
+    def get_stacked_maps_image(self) -> Optional[Image.Image]:
         """
         Scale all collected attention maps to the same size, blend them together and return as an image.
         :return: An image containing a vertical stack of blended attention maps, one for each requested token.
@@ -95,4 +97,4 @@ class AttentionMapSaver:
             return None

         merged_bytes = merged.mul(0xFF).byte()
-        return PIL.Image.fromarray(merged_bytes.numpy(), mode="L")
+        return Image.fromarray(merged_bytes.numpy(), mode="L")
@@ -151,7 +151,9 @@ export const addRequestedSingleImageDeletionListener = () => {

       if (wasImageDeleted) {
         dispatch(
-          api.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id }])
+          api.util.invalidateTags([
+            { type: 'Board', id: imageDTO.board_id ?? 'none' },
+          ])
         );
       }
     },
@@ -6,7 +6,7 @@ import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
 import { KeyboardEvent, RefObject, memo, useCallback } from 'react';

 type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
-  tooltip?: string;
+  tooltip?: string | null;
   inputRef?: RefObject<HTMLInputElement>;
   label?: string;
 };

@@ -12,7 +12,7 @@ export type IAISelectDataType = {
 };

 type IAISelectProps = Omit<SelectProps, 'label'> & {
-  tooltip?: string;
+  tooltip?: string | null;
   label?: string;
   inputRef?: RefObject<HTMLInputElement>;
 };

@@ -10,7 +10,7 @@ export type IAISelectDataType = {
 };

 export type IAISelectProps = Omit<SelectProps, 'label'> & {
-  tooltip?: string;
+  tooltip?: string | null;
   inputRef?: RefObject<HTMLInputElement>;
   label?: string;
 };
@@ -39,7 +39,10 @@ export const dynamicPromptsSlice = createSlice({
     promptsChanged: (state, action: PayloadAction<string[]>) => {
       state.prompts = action.payload;
     },
-    parsingErrorChanged: (state, action: PayloadAction<string | undefined>) => {
+    parsingErrorChanged: (
+      state,
+      action: PayloadAction<string | null | undefined>
+    ) => {
      state.parsingError = action.payload;
     },
     isErrorChanged: (state, action: PayloadAction<boolean>) => {
@@ -10,7 +10,7 @@ import {
 } from 'features/parameters/types/parameterSchemas';
 import i18n from 'i18next';
 import { has, keyBy } from 'lodash-es';
-import { OpenAPIV3 } from 'openapi-types';
+import { OpenAPIV3_1 } from 'openapi-types';
 import { RgbaColor } from 'react-colorful';
 import { Node } from 'reactflow';
 import { Graph, _InputField, _OutputField } from 'services/api/types';
@@ -791,9 +791,9 @@ export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
   default: number;
   multipleOf?: number;
   maximum?: number;
-  exclusiveMaximum?: boolean;
+  exclusiveMaximum?: number;
   minimum?: number;
-  exclusiveMinimum?: boolean;
+  exclusiveMinimum?: number;
 };

 export type IntegerCollectionInputFieldTemplate = InputFieldTemplateBase & {
@@ -814,9 +814,9 @@ export type FloatInputFieldTemplate = InputFieldTemplateBase & {
   default: number;
   multipleOf?: number;
   maximum?: number;
-  exclusiveMaximum?: boolean;
+  exclusiveMaximum?: number;
   minimum?: number;
-  exclusiveMinimum?: boolean;
+  exclusiveMinimum?: number;
 };

 export type FloatCollectionInputFieldTemplate = InputFieldTemplateBase & {
@@ -1163,20 +1163,20 @@ export type TypeHints = {
 };

 export type InvocationSchemaExtra = {
-  output: OpenAPIV3.ReferenceObject; // the output of the invocation
+  output: OpenAPIV3_1.ReferenceObject; // the output of the invocation
   title: string;
   category?: string;
   tags?: string[];
   version?: string;
   properties: Omit<
-    NonNullable<OpenAPIV3.SchemaObject['properties']> &
+    NonNullable<OpenAPIV3_1.SchemaObject['properties']> &
       (_InputField | _OutputField),
    'type'
  > & {
-    type: Omit<OpenAPIV3.SchemaObject, 'default'> & {
+    type: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
       default: AnyInvocationType;
     };
-    use_cache: Omit<OpenAPIV3.SchemaObject, 'default'> & {
+    use_cache: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
       default: boolean;
     };
   };
@@ -1187,17 +1187,17 @@ export type InvocationSchemaType = {
 };

 export type InvocationBaseSchemaObject = Omit<
-  OpenAPIV3.BaseSchemaObject,
+  OpenAPIV3_1.BaseSchemaObject,
   'title' | 'type' | 'properties'
 > &
   InvocationSchemaExtra;

 export type InvocationOutputSchemaObject = Omit<
-  OpenAPIV3.SchemaObject,
+  OpenAPIV3_1.SchemaObject,
   'properties'
 > & {
-  properties: OpenAPIV3.SchemaObject['properties'] & {
-    type: Omit<OpenAPIV3.SchemaObject, 'default'> & {
+  properties: OpenAPIV3_1.SchemaObject['properties'] & {
+    type: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
       default: string;
     };
   } & {
@@ -1205,14 +1205,18 @@ export type InvocationOutputSchemaObject = Omit<
   };
 };

-export type InvocationFieldSchema = OpenAPIV3.SchemaObject & _InputField;
+export type InvocationFieldSchema = OpenAPIV3_1.SchemaObject & _InputField;
+
+export type OpenAPIV3_1SchemaOrRef =
+  | OpenAPIV3_1.ReferenceObject
+  | OpenAPIV3_1.SchemaObject;

 export interface ArraySchemaObject extends InvocationBaseSchemaObject {
-  type: OpenAPIV3.ArraySchemaObjectType;
-  items: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject;
+  type: OpenAPIV3_1.ArraySchemaObjectType;
+  items: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject;
 }
 export interface NonArraySchemaObject extends InvocationBaseSchemaObject {
-  type?: OpenAPIV3.NonArraySchemaObjectType;
+  type?: OpenAPIV3_1.NonArraySchemaObjectType;
 }

 export type InvocationSchemaObject = (
@@ -1221,41 +1225,41 @@ export type InvocationSchemaObject = (
 ) & { class: 'invocation' };

 export const isSchemaObject = (
-  obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
-): obj is OpenAPIV3.SchemaObject => Boolean(obj && !('$ref' in obj));
+  obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
+): obj is OpenAPIV3_1.SchemaObject => Boolean(obj && !('$ref' in obj));

 export const isArraySchemaObject = (
-  obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
-): obj is OpenAPIV3.ArraySchemaObject =>
+  obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
+): obj is OpenAPIV3_1.ArraySchemaObject =>
   Boolean(obj && !('$ref' in obj) && obj.type === 'array');

 export const isNonArraySchemaObject = (
-  obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
-): obj is OpenAPIV3.NonArraySchemaObject =>
+  obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
+): obj is OpenAPIV3_1.NonArraySchemaObject =>
   Boolean(obj && !('$ref' in obj) && obj.type !== 'array');

 export const isRefObject = (
-  obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
-): obj is OpenAPIV3.ReferenceObject => Boolean(obj && '$ref' in obj);
+  obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
+): obj is OpenAPIV3_1.ReferenceObject => Boolean(obj && '$ref' in obj);

 export const isInvocationSchemaObject = (
   obj:
-    | OpenAPIV3.ReferenceObject
-    | OpenAPIV3.SchemaObject
+    | OpenAPIV3_1.ReferenceObject
+    | OpenAPIV3_1.SchemaObject
     | InvocationSchemaObject
 ): obj is InvocationSchemaObject =>
   'class' in obj && obj.class === 'invocation';

 export const isInvocationOutputSchemaObject = (
   obj:
-    | OpenAPIV3.ReferenceObject
-    | OpenAPIV3.SchemaObject
+    | OpenAPIV3_1.ReferenceObject
+    | OpenAPIV3_1.SchemaObject
     | InvocationOutputSchemaObject
 ): obj is InvocationOutputSchemaObject =>
   'class' in obj && obj.class === 'output';

 export const isInvocationFieldSchema = (
-  obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
+  obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject
 ): obj is InvocationFieldSchema => !('$ref' in obj);

 export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
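**Note: numeric exclusive bounds.** Under OpenAPI 3.1 (JSON Schema 2020-12), `exclusiveMaximum`/`exclusiveMinimum` are numbers rather than the 3.0-era booleans — hence the `boolean` → `number` template changes above and the `isNumber` guards in the field-template builders that follow. Pydantic v2 emits the numeric form directly; a minimal sketch:

```py
from pydantic import BaseModel, Field


class Params(BaseModel):
    steps: int = Field(1, gt=0, lt=1000)


schema = Params.model_json_schema()["properties"]["steps"]
# includes {'exclusiveMinimum': 0, 'exclusiveMaximum': 1000} – numbers, not booleans
print(schema)
```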
@@ -1,5 +1,12 @@
-import { isBoolean, isInteger, isNumber, isString } from 'lodash-es';
-import { OpenAPIV3 } from 'openapi-types';
+import {
+  isArray,
+  isBoolean,
+  isInteger,
+  isNumber,
+  isString,
+  startCase,
+} from 'lodash-es';
+import { OpenAPIV3_1 } from 'openapi-types';
 import {
   COLLECTION_MAP,
   POLYMORPHIC_TYPES,
@@ -72,6 +79,7 @@ import {
   T2IAdapterCollectionInputFieldTemplate,
   BoardInputFieldTemplate,
   InputFieldTemplate,
+  OpenAPIV3_1SchemaOrRef,
 } from '../types/types';
 import { ControlField } from 'services/api/types';

@@ -90,7 +98,7 @@ export type BuildInputFieldArg = {
 * @example
 * refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
 */
-export const refObjectToSchemaName = (refObject: OpenAPIV3.ReferenceObject) =>
+export const refObjectToSchemaName = (refObject: OpenAPIV3_1.ReferenceObject) =>
   refObject.$ref.split('/').slice(-1)[0];

 const buildIntegerInputFieldTemplate = ({
@@ -111,7 +119,10 @@ const buildIntegerInputFieldTemplate = ({
     template.maximum = schemaObject.maximum;
   }

-  if (schemaObject.exclusiveMaximum !== undefined) {
+  if (
+    schemaObject.exclusiveMaximum !== undefined &&
+    isNumber(schemaObject.exclusiveMaximum)
+  ) {
     template.exclusiveMaximum = schemaObject.exclusiveMaximum;
   }

@@ -119,7 +130,10 @@ const buildIntegerInputFieldTemplate = ({
     template.minimum = schemaObject.minimum;
   }

-  if (schemaObject.exclusiveMinimum !== undefined) {
+  if (
+    schemaObject.exclusiveMinimum !== undefined &&
+    isNumber(schemaObject.exclusiveMinimum)
+  ) {
     template.exclusiveMinimum = schemaObject.exclusiveMinimum;
   }

@@ -144,7 +158,10 @@ const buildIntegerPolymorphicInputFieldTemplate = ({
     template.maximum = schemaObject.maximum;
   }

-  if (schemaObject.exclusiveMaximum !== undefined) {
+  if (
+    schemaObject.exclusiveMaximum !== undefined &&
+    isNumber(schemaObject.exclusiveMaximum)
+  ) {
     template.exclusiveMaximum = schemaObject.exclusiveMaximum;
   }

@@ -152,7 +169,10 @@ const buildIntegerPolymorphicInputFieldTemplate = ({
     template.minimum = schemaObject.minimum;
   }

-  if (schemaObject.exclusiveMinimum !== undefined) {
+  if (
+    schemaObject.exclusiveMinimum !== undefined &&
+    isNumber(schemaObject.exclusiveMinimum)
+  ) {
     template.exclusiveMinimum = schemaObject.exclusiveMinimum;
   }

@@ -195,7 +215,10 @@ const buildFloatInputFieldTemplate = ({
     template.maximum = schemaObject.maximum;
   }

-  if (schemaObject.exclusiveMaximum !== undefined) {
+  if (
+    schemaObject.exclusiveMaximum !== undefined &&
+    isNumber(schemaObject.exclusiveMaximum)
+  ) {
     template.exclusiveMaximum = schemaObject.exclusiveMaximum;
   }

@@ -203,7 +226,10 @@ const buildFloatInputFieldTemplate = ({
     template.minimum = schemaObject.minimum;
   }

-  if (schemaObject.exclusiveMinimum !== undefined) {
+  if (
+    schemaObject.exclusiveMinimum !== undefined &&
+    isNumber(schemaObject.exclusiveMinimum)
+  ) {
     template.exclusiveMinimum = schemaObject.exclusiveMinimum;
   }

@@ -227,7 +253,10 @@ const buildFloatPolymorphicInputFieldTemplate = ({
     template.maximum = schemaObject.maximum;
   }

-  if (schemaObject.exclusiveMaximum !== undefined) {
+  if (
+    schemaObject.exclusiveMaximum !== undefined &&
+    isNumber(schemaObject.exclusiveMaximum)
+  ) {
     template.exclusiveMaximum = schemaObject.exclusiveMaximum;
   }

@@ -235,7 +264,10 @@ const buildFloatPolymorphicInputFieldTemplate = ({
     template.minimum = schemaObject.minimum;
   }

-  if (schemaObject.exclusiveMinimum !== undefined) {
+  if (
+    schemaObject.exclusiveMinimum !== undefined &&
+    isNumber(schemaObject.exclusiveMinimum)
+  ) {
     template.exclusiveMinimum = schemaObject.exclusiveMinimum;
   }
   return template;
@@ -872,84 +904,106 @@ const buildSchedulerInputFieldTemplate = ({
 };

 export const getFieldType = (
-  schemaObject: InvocationFieldSchema
+  schemaObject: OpenAPIV3_1SchemaOrRef
 ): string | undefined => {
-  if (schemaObject?.ui_type) {
-    return schemaObject.ui_type;
-  } else if (!schemaObject.type) {
-    // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
+  if (isSchemaObject(schemaObject)) {
+    if (!schemaObject.type) {
+      // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf

-    if (schemaObject.allOf) {
-      const allOf = schemaObject.allOf;
-      if (allOf && allOf[0] && isRefObject(allOf[0])) {
-        return refObjectToSchemaName(allOf[0]);
-      }
-    } else if (schemaObject.anyOf) {
-      const anyOf = schemaObject.anyOf;
-      /**
-       * Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is:
-       * - an `anyOf` with two items
-       * - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items`
-       * - the other is a `SchemaObject` or `ReferenceObject` of type T
-       *
-       * Any other cases we ignore.
-       */
-
-      let firstType: string | undefined;
-      let secondType: string | undefined;
-
-      if (isArraySchemaObject(anyOf[0])) {
-        // first is array, second is not
-        const first = anyOf[0].items;
-        const second = anyOf[1];
-        if (isRefObject(first) && isRefObject(second)) {
-          firstType = refObjectToSchemaName(first);
-          secondType = refObjectToSchemaName(second);
-        } else if (
-          isNonArraySchemaObject(first) &&
-          isNonArraySchemaObject(second)
-        ) {
-          firstType = first.type;
-          secondType = second.type;
+      if (schemaObject.allOf) {
+        const allOf = schemaObject.allOf;
+        if (allOf && allOf[0] && isRefObject(allOf[0])) {
+          return refObjectToSchemaName(allOf[0]);
         }
-      } else if (isArraySchemaObject(anyOf[1])) {
-        // first is not array, second is
-        const first = anyOf[0];
-        const second = anyOf[1].items;
-        if (isRefObject(first) && isRefObject(second)) {
-          firstType = refObjectToSchemaName(first);
-          secondType = refObjectToSchemaName(second);
-        } else if (
-          isNonArraySchemaObject(first) &&
-          isNonArraySchemaObject(second)
-        ) {
-          firstType = first.type;
-          secondType = second.type;
+      } else if (schemaObject.anyOf) {
+        // ignore null types
+        const anyOf = schemaObject.anyOf.filter((i) => {
+          if (isSchemaObject(i)) {
+            if (i.type === 'null') {
+              return false;
+            }
+          }
+          return true;
+        });
+        if (anyOf.length === 1) {
+          if (isRefObject(anyOf[0])) {
+            return refObjectToSchemaName(anyOf[0]);
+          } else if (isSchemaObject(anyOf[0])) {
+            return getFieldType(anyOf[0]);
+          }
         }
+        /**
+         * Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is:
+         * - an `anyOf` with two items
+         * - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items`
+         * - the other is a `SchemaObject` or `ReferenceObject` of type T
+         *
+         * Any other cases we ignore.
+         */
+
+        let firstType: string | undefined;
+        let secondType: string | undefined;
+
+        if (isArraySchemaObject(anyOf[0])) {
+          // first is array, second is not
+          const first = anyOf[0].items;
+          const second = anyOf[1];
+          if (isRefObject(first) && isRefObject(second)) {
+            firstType = refObjectToSchemaName(first);
+            secondType = refObjectToSchemaName(second);
+          } else if (
+            isNonArraySchemaObject(first) &&
+            isNonArraySchemaObject(second)
+          ) {
+            firstType = first.type;
+            secondType = second.type;
+          }
+        } else if (isArraySchemaObject(anyOf[1])) {
+          // first is not array, second is
+          const first = anyOf[0];
+          const second = anyOf[1].items;
+          if (isRefObject(first) && isRefObject(second)) {
+            firstType = refObjectToSchemaName(first);
+            secondType = refObjectToSchemaName(second);
+          } else if (
+            isNonArraySchemaObject(first) &&
+            isNonArraySchemaObject(second)
+          ) {
+            firstType = first.type;
+            secondType = second.type;
+          }
+        }
+        if (firstType === secondType && isPolymorphicItemType(firstType)) {
+          return SINGLE_TO_POLYMORPHIC_MAP[firstType];
+        }
       }
-      if (firstType === secondType && isPolymorphicItemType(firstType)) {
-        return SINGLE_TO_POLYMORPHIC_MAP[firstType];
+    } else if (schemaObject.enum) {
+      return 'enum';
+    } else if (schemaObject.type) {
+      if (schemaObject.type === 'number') {
+        // floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them
+        return 'float';
+      } else if (schemaObject.type === 'array') {
+        const itemType = isSchemaObject(schemaObject.items)
+          ? schemaObject.items.type
+          : refObjectToSchemaName(schemaObject.items);

+        if (isArray(itemType)) {
+          // This is a nested array, which we don't support
+          return;
+        }
+
+        if (isCollectionItemType(itemType)) {
+          return COLLECTION_MAP[itemType];
+        }
+
+        return;
+      } else if (!isArray(schemaObject.type)) {
+        return schemaObject.type;
       }
     }
-  } else if (schemaObject.enum) {
-    return 'enum';
-  } else if (schemaObject.type) {
-    if (schemaObject.type === 'number') {
-      // floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them
-      return 'float';
-    } else if (schemaObject.type === 'array') {
-      const itemType = isSchemaObject(schemaObject.items)
-        ? schemaObject.items.type
-        : refObjectToSchemaName(schemaObject.items);
-
-      if (isCollectionItemType(itemType)) {
-        return COLLECTION_MAP[itemType];
-      }
-
-      return;
-    } else {
-      return schemaObject.type;
-    }
+  } else if (isRefObject(schemaObject)) {
+    return refObjectToSchemaName(schemaObject);
   }
   return;
 };
@@ -1025,7 +1079,15 @@ export const buildInputFieldTemplate = (
   name: string,
   fieldType: FieldType
 ) => {
-  const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;
+  const {
+    input,
+    ui_hidden,
+    ui_component,
+    ui_type,
+    ui_order,
+    ui_choice_labels,
+    item_default,
+  } = fieldSchema;

   const extra = {
     // TODO: Can we support polymorphic inputs in the UI?
@@ -1035,11 +1097,13 @@ export const buildInputFieldTemplate = (
     ui_type,
     required: nodeSchema.required?.includes(name) ?? false,
     ui_order,
+    ui_choice_labels,
+    item_default,
   };

   const baseField = {
     name,
-    title: fieldSchema.title ?? '',
+    title: fieldSchema.title ?? (name ? startCase(name) : ''),
     description: fieldSchema.description ?? '',
     fieldKind: 'input' as const,
     ...extra,
@@ -1,7 +1,7 @@
 import { logger } from 'app/logging/logger';
 import { parseify } from 'common/util/serialize';
-import { reduce } from 'lodash-es';
-import { OpenAPIV3 } from 'openapi-types';
+import { reduce, startCase } from 'lodash-es';
+import { OpenAPIV3_1 } from 'openapi-types';
 import { AnyInvocationType } from 'services/events/types';
 import {
   FieldType,
@@ -60,7 +60,7 @@ const isNotInDenylist = (schema: InvocationSchemaObject) =>
   !invocationDenylist.includes(schema.properties.type.default);

 export const parseSchema = (
-  openAPI: OpenAPIV3.Document,
+  openAPI: OpenAPIV3_1.Document,
   nodesAllowlistExtra: string[] | undefined = undefined,
   nodesDenylistExtra: string[] | undefined = undefined
 ): Record<string, InvocationTemplate> => {
@@ -110,7 +110,7 @@ export const parseSchema = (
           return inputsAccumulator;
         }

-        const fieldType = getFieldType(property);
+        const fieldType = property.ui_type ?? getFieldType(property);

         if (!isFieldType(fieldType)) {
           logger('nodes').warn(
@@ -209,7 +209,7 @@ export const parseSchema = (
           return outputsAccumulator;
         }

-        const fieldType = getFieldType(property);
+        const fieldType = property.ui_type ?? getFieldType(property);

         if (!isFieldType(fieldType)) {
           logger('nodes').warn(
@@ -222,7 +222,8 @@ export const parseSchema = (
         outputsAccumulator[propertyName] = {
           fieldKind: 'output',
           name: propertyName,
-          title: property.title ?? '',
+          title:
+            property.title ?? (propertyName ? startCase(propertyName) : ''),
           description: property.description ?? '',
           type: fieldType,
           ui_hidden: property.ui_hidden ?? false,
@@ -7,7 +7,7 @@ const QueueItemCard = ({
   session_queue_item,
   label,
 }: {
-  session_queue_item?: components['schemas']['SessionQueueItem'];
+  session_queue_item?: components['schemas']['SessionQueueItem'] | null;
   label: string;
 }) => {
   return (
@@ -112,7 +112,7 @@ export default function MergeModelsPanel() {
     }
   });

-  const mergeModelsInfo: MergeModelConfig = {
+  const mergeModelsInfo: MergeModelConfig['body'] = {
     model_names: models_names,
     merged_model_name:
       mergedModelName !== '' ? mergedModelName : models_names.join('-'),
@@ -125,7 +125,7 @@ export default function MergeModelsPanel() {

   mergeModels({
     base_model: baseModel,
-    body: mergeModelsInfo,
+    body: { body: mergeModelsInfo },
   })
     .unwrap()
     .then((_) => {
@@ -520,7 +520,7 @@ export const imagesApi = api.injectEndpoints({
       // assume all images are on the same board/category
       if (images[0]) {
         const categories = getCategories(images[0]);
-        const boardId = images[0].board_id;
+        const boardId = images[0].board_id ?? undefined;

         return [
           {
@@ -637,7 +637,7 @@ export const imagesApi = api.injectEndpoints({
       // assume all images are on the same board/category
       if (images[0]) {
         const categories = getCategories(images[0]);
-        const boardId = images[0].board_id;
+        const boardId = images[0].board_id ?? undefined;
         return [
           {
             type: 'ImageList',
invokeai/frontend/web/src/services/api/schema.d.ts (vendored, 3486 lines changed) — file diff suppressed because one or more lines are too long

pyproject.toml (110 lines changed)
@@ -35,10 +35,10 @@ dependencies = [
   "accelerate~=0.23.0",
   "albumentations",
   "click",
-  "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
+  "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
   "compel~=2.0.2",
   "controlnet-aux>=0.0.6",
-  "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
+  "timm==0.6.13",  # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
   "datasets",
   # When bumping diffusers beyond 0.21, make sure to address this:
   # https://github.com/invoke-ai/InvokeAI/blob/fc09ab7e13cb7ca5389100d149b6422ace7b8ed3/invokeai/app/invocations/latent.py#L513
@@ -48,19 +48,20 @@ dependencies = [
   "easing-functions",
   "einops",
   "facexlib",
-  "fastapi==0.88.0",
-  "fastapi-events==0.8.0",
+  "fastapi~=0.103.2",
+  "fastapi-events~=0.9.1",
   "huggingface-hub~=0.16.4",
-  "invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
-  "matplotlib", # needed for plotting of Penner easing functions
-  "mediapipe", # needed for "mediapipeface" controlnet model
+  "invisible-watermark~=0.2.0",  # needed to install SDXL base and refiner using their repo_ids
+  "matplotlib",  # needed for plotting of Penner easing functions
+  "mediapipe",  # needed for "mediapipeface" controlnet model
   "numpy",
   "npyscreen",
   "omegaconf",
   "onnx",
   "onnxruntime",
   "opencv-python",
-  "pydantic==1.*",
+  "pydantic~=2.4.2",
+  "pydantic-settings~=2.0.3",
   "picklescan",
   "pillow",
   "prompt-toolkit",
@@ -95,33 +96,25 @@ dependencies = [
   "mkdocs-git-revision-date-localized-plugin",
   "mkdocs-redirects==1.2.0",
 ]
-"dev" = [
-  "jurigged",
-  "pudb",
-]
+"dev" = ["jurigged", "pudb"]
 "test" = [
   "black",
   "flake8",
   "Flake8-pyproject",
   "isort",
   "mypy",
   "pre-commit",
   "pytest>6.0.0",
   "pytest-cov",
   "pytest-datadir",
 ]
 "xformers" = [
-  "xformers~=0.0.19; sys_platform!='darwin'",
-  "triton; sys_platform=='linux'",
-]
-"onnx" = [
-  "onnxruntime",
-]
-"onnx-cuda" = [
-  "onnxruntime-gpu",
-]
-"onnx-directml" = [
-  "onnxruntime-directml",
+  "xformers~=0.0.19; sys_platform!='darwin'",
+  "triton; sys_platform=='linux'",
 ]
+"onnx" = ["onnxruntime"]
+"onnx-cuda" = ["onnxruntime-gpu"]
+"onnx-directml" = ["onnxruntime-directml"]

 [project.scripts]

@@ -163,12 +156,15 @@ version = { attr = "invokeai.version.__version__" }
 [tool.setuptools.packages.find]
 "where" = ["."]
 "include" = [
-    "invokeai.assets.fonts*","invokeai.version*",
-    "invokeai.generator*","invokeai.backend*",
-    "invokeai.frontend*", "invokeai.frontend.web.dist*",
-    "invokeai.frontend.web.static*",
-    "invokeai.configs*",
-    "invokeai.app*",
+    "invokeai.assets.fonts*",
+    "invokeai.version*",
+    "invokeai.generator*",
+    "invokeai.backend*",
+    "invokeai.frontend*",
+    "invokeai.frontend.web.dist*",
+    "invokeai.frontend.web.static*",
+    "invokeai.configs*",
+    "invokeai.app*",
 ]

 [tool.setuptools.package-data]
@@ -182,7 +178,7 @@ version = { attr = "invokeai.version.__version__" }
 [tool.pytest.ini_options]
 addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
 markers = [
-    "slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\"."
+    "slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".",
 ]
 [tool.coverage.run]
 branch = true
@@ -190,7 +186,7 @@ source = ["invokeai"]
 omit = ["*tests*", "*migrations*", ".venv/*", "*.env"]
 [tool.coverage.report]
 show_missing = true
-fail_under = 85 # let's set something sensible on Day 1 ...
+fail_under = 85  # let's set something sensible on Day 1 ...
 [tool.coverage.json]
 output = "coverage/coverage.json"
 pretty_print = true
@@ -209,7 +205,7 @@ exclude = [
     "__pycache__",
     "build",
     "dist",
-    "invokeai/frontend/web/node_modules/"
+    "invokeai/frontend/web/node_modules/",
 ]

 [tool.black]
@@ -218,3 +214,53 @@ line-length = 120
 [tool.isort]
 profile = "black"
 line_length = 120
+
+[tool.mypy]
+ignore_missing_imports = true  # ignores missing types in third-party libraries
+
+[[tool.mypy.overrides]]
+follow_imports = "skip"
+module = [
+    "invokeai.app.api.routers.models",
+    "invokeai.app.invocations.compel",
+    "invokeai.app.invocations.latent",
+    "invokeai.app.services.config.config_base",
+    "invokeai.app.services.config.config_default",
+    "invokeai.app.services.invocation_stats.invocation_stats_default",
+    "invokeai.app.services.model_manager.model_manager_base",
+    "invokeai.app.services.model_manager.model_manager_default",
+    "invokeai.app.util.controlnet_utils",
+    "invokeai.backend.image_util.txt2mask",
+    "invokeai.backend.image_util.safety_checker",
+    "invokeai.backend.image_util.patchmatch",
+    "invokeai.backend.image_util.invisible_watermark",
+    "invokeai.backend.install.model_install_backend",
+    "invokeai.backend.ip_adapter.ip_adapter",
+    "invokeai.backend.ip_adapter.resampler",
+    "invokeai.backend.ip_adapter.unet_patcher",
+    "invokeai.backend.model_management.convert_ckpt_to_diffusers",
+    "invokeai.backend.model_management.lora",
+    "invokeai.backend.model_management.model_cache",
+    "invokeai.backend.model_management.model_manager",
+    "invokeai.backend.model_management.model_merge",
+    "invokeai.backend.model_management.model_probe",
+    "invokeai.backend.model_management.model_search",
+    "invokeai.backend.model_management.models.*",  # this is needed to ignore the module's `__init__.py`
+    "invokeai.backend.model_management.models.base",
+    "invokeai.backend.model_management.models.controlnet",
+    "invokeai.backend.model_management.models.ip_adapter",
+    "invokeai.backend.model_management.models.lora",
+    "invokeai.backend.model_management.models.sdxl",
+    "invokeai.backend.model_management.models.stable_diffusion",
+    "invokeai.backend.model_management.models.vae",
+    "invokeai.backend.model_management.seamless",
+    "invokeai.backend.model_management.util",
+    "invokeai.backend.stable_diffusion.diffusers_pipeline",
+    "invokeai.backend.stable_diffusion.diffusion.cross_attention_control",
+    "invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion",
+    "invokeai.backend.util.hotfixes",
+    "invokeai.backend.util.logging",
+    "invokeai.backend.util.mps_fixes",
+    "invokeai.backend.util.util",
+    "invokeai.frontend.install.model_install",
+]
@@ -1,4 +1,5 @@
 import pytest
+from pydantic import TypeAdapter

 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
@@ -593,20 +594,21 @@ def test_graph_can_serialize():
     g.add_edge(e)

     # Not throwing on this line is sufficient
-    _ = g.json()
+    _ = g.model_dump_json()


 def test_graph_can_deserialize():
     g = Graph()
     n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
-    n2 = ESRGANInvocation(id="2")
+    n2 = ImageToImageTestInvocation(id="2")
     g.add_node(n1)
     g.add_node(n2)
     e = create_edge(n1.id, "image", n2.id, "image")
     g.add_edge(e)

-    json = g.json()
-    g2 = Graph.parse_raw(json)
+    json = g.model_dump_json()
+    adapter_graph = TypeAdapter(Graph)
+    g2 = adapter_graph.validate_json(json)

     assert g2 is not None
     assert g2.nodes["1"] is not None
@@ -619,7 +621,7 @@ def test_graph_can_deserialize():


 def test_invocation_decorator():
-    invocation_type = "test_invocation"
+    invocation_type = "test_invocation_decorator"
     title = "Test Invocation"
     tags = ["first", "second", "third"]
     category = "category"
@@ -630,7 +632,7 @@ def test_invocation_decorator():
     def invoke(self):
         pass

-    schema = TestInvocation.schema()
+    schema = TestInvocation.model_json_schema()

     assert schema.get("title") == title
     assert schema.get("tags") == tags
@@ -640,18 +642,17 @@ def test_invocation_decorator():


 def test_invocation_version_must_be_semver():
-    invocation_type = "test_invocation"
     valid_version = "1.0.0"
     invalid_version = "not_semver"

-    @invocation(invocation_type, version=valid_version)
+    @invocation("test_invocation_version_valid", version=valid_version)
     class ValidVersionInvocation(BaseInvocation):
         def invoke(self):
             pass

     with pytest.raises(InvalidVersionError):

-        @invocation(invocation_type, version=invalid_version)
+        @invocation("test_invocation_version_invalid", version=invalid_version)
         class InvalidVersionInvocation(BaseInvocation):
             def invoke(self):
                 pass
@@ -694,4 +695,4 @@ def test_ints_do_not_accept_floats():
 def test_graph_can_generate_schema():
     # Not throwing on this line is sufficient
     # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
-    _ = Graph.schema_json(indent=2)
+    _ = Graph.model_json_schema()
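**Note: schema generation in tests.** `Model.schema()` / `Model.schema_json()` are gone in v2; `model_json_schema()` returns a plain dict, and any JSON rendering is done by hand. A minimal sketch:

```py
import json

from pydantic import BaseModel


class Node(BaseModel):
    id: str


# v1: Node.schema_json(indent=2)
schema_str = json.dumps(Node.model_json_schema(), indent=2)
```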
@@ -1,5 +1,5 @@
 import pytest
-from pydantic import ValidationError, parse_raw_as
+from pydantic import TypeAdapter, ValidationError

 from invokeai.app.services.session_queue.session_queue_common import (
     Batch,
@@ -150,8 +150,9 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
     values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000)
     assert len(values) == 8

+    session_adapter = TypeAdapter(GraphExecutionState)
     # graph should be serialized
-    ges = parse_raw_as(GraphExecutionState, values[0].session)
+    ges = session_adapter.validate_json(values[0].session)

     # graph values should be populated
     assert ges.graph.get_node("1").prompt == "Banana sushi"
@@ -160,15 +161,16 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
     assert ges.graph.get_node("4").prompt == "Nissan"

     # session ids should match deserialized graph
-    assert [v.session_id for v in values] == [parse_raw_as(GraphExecutionState, v.session).id for v in values]
+    assert [v.session_id for v in values] == [session_adapter.validate_json(v.session).id for v in values]

     # should unique session ids
     sids = [v.session_id for v in values]
     assert len(sids) == len(set(sids))

+    nfv_list_adapter = TypeAdapter(list[NodeFieldValue])
     # should have 3 node field values
     assert type(values[0].field_values) is str
-    assert len(parse_raw_as(list[NodeFieldValue], values[0].field_values)) == 3
+    assert len(nfv_list_adapter.validate_json(values[0].field_values)) == 3

     # should have batch id and priority
     assert all(v.batch_id == b.batch_id for v in values)
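**Note: reuse `TypeAdapter`s.** The test now builds each adapter once and reuses it across assertions; constructing a `TypeAdapter` compiles a validator in pydantic-core, so hoisting it out of loops is cheap insurance in hot paths. A minimal sketch with a simplified stand-in model:

```py
from pydantic import BaseModel, TypeAdapter


class NodeFieldValue(BaseModel):  # simplified stand-in for the real model
    node_path: str
    field_name: str


adapter = TypeAdapter(list[NodeFieldValue])  # build once...
rows = ['[{"node_path": "1", "field_name": "prompt"}]'] * 3
parsed = [adapter.validate_json(row) for row in rows]  # ...reuse per row
assert len(parsed) == 3 and parsed[0][0].field_name == "prompt"
```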
@@ -15,7 +15,8 @@ class TestModel(BaseModel):
 @pytest.fixture
 def db() -> SqliteItemStorage[TestModel]:
     sqlite_db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger())
-    return SqliteItemStorage[TestModel](db=sqlite_db, table_name="test", id_field="id")
+    sqlite_item_storage = SqliteItemStorage[TestModel](db=sqlite_db, table_name="test", id_field="id")
+    return sqlite_item_storage


 def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):