mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest. - pydantic~=2.4.2 - fastapi~=103.2 - fastapi-events~=0.9.1 **Big Changes** There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialized and deserialize models, but there are a few more complex changes. **Invocations** The biggest change relates to invocation creation, instantiation and validation. Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie. Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`. With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problem with concurrent access to the invocation classes and is not a free operation. This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method. In the end, this implementation is cleaner. **Invocation Fields** In pydantic v2, you can no longer directly add or remove fields from a model. Previously, we did this to add the `type` field to invocations. **Invocation Decorators** With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper. A similar technique is used for `invocation_output()`. **Minor Changes** There are a number of minor changes around the pydantic v2 models API. **Protected `model_` Namespace** All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_". Forunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must tell set the protected namespaces to an empty tuple. ```py class IPAdapterModelField(BaseModel): model_name: str = Field(description="Name of the IP-Adapter model") base_model: BaseModelType = Field(description="Base model") model_config = ConfigDict(protected_namespaces=()) ``` **Model Serialization** Pydantic models no longer have `Model.dict()` or `Model.json()`. Instead, we use `Model.model_dump()` or `Model.model_dump_json()`. **Model Deserialization** Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions. Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model. ```py adapter_graph = TypeAdapter(Graph) deserialized_graph_from_json = adapter_graph.validate_json(graph_json) deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict) ``` **Field Customisation** Pydantic `Field`s no longer accept arbitrary args. Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field. **Schema Customisation** FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec. This necessitates two changes: - Our schema customization logic has been revised - Schema parsing to build node templates has been revised The specific aren't important, but this does present additional surface area for bugs. **Performance Improvements** Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node. I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very larges graphs - like with massive iterators - are much, much faster.
This commit is contained in:
parent
19c5435332
commit
c238a7f18b
@ -42,7 +42,7 @@ async def upload_image(
|
|||||||
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
|
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Uploads an image"""
|
"""Uploads an image"""
|
||||||
if not file.content_type.startswith("image"):
|
if not file.content_type or not file.content_type.startswith("image"):
|
||||||
raise HTTPException(status_code=415, detail="Not an image")
|
raise HTTPException(status_code=415, detail="Not an image")
|
||||||
|
|
||||||
contents = await file.read()
|
contents = await file.read()
|
||||||
|
@ -2,11 +2,11 @@
|
|||||||
|
|
||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import Annotated, List, Literal, Optional, Union
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response
|
from fastapi import Body, Path, Query, Response
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic import BaseModel, parse_obj_as
|
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
|
|
||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
@ -23,8 +23,14 @@ from ..dependencies import ApiDependencies
|
|||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
update_models_response_adapter = TypeAdapter(UpdateModelResponse)
|
||||||
|
|
||||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
import_models_response_adapter = TypeAdapter(ImportModelResponse)
|
||||||
|
|
||||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
convert_models_response_adapter = TypeAdapter(ConvertModelResponse)
|
||||||
|
|
||||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
|
||||||
@ -32,6 +38,11 @@ ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
|||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||||
|
|
||||||
|
model_config = ConfigDict(use_enum_values=True)
|
||||||
|
|
||||||
|
|
||||||
|
models_list_adapter = TypeAdapter(ModelsList)
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
"/",
|
"/",
|
||||||
@ -49,7 +60,7 @@ async def list_models(
|
|||||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||||
else:
|
else:
|
||||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
||||||
models = parse_obj_as(ModelsList, {"models": models_raw})
|
models = models_list_adapter.validate_python({"models": models_raw})
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
@ -105,11 +116,14 @@ async def update_model(
|
|||||||
info.path = new_info.get("path")
|
info.path = new_info.get("path")
|
||||||
|
|
||||||
# replace empty string values with None/null to avoid phenomenon of vae: ''
|
# replace empty string values with None/null to avoid phenomenon of vae: ''
|
||||||
info_dict = info.dict()
|
info_dict = info.model_dump()
|
||||||
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
|
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
|
||||||
|
|
||||||
ApiDependencies.invoker.services.model_manager.update_model(
|
ApiDependencies.invoker.services.model_manager.update_model(
|
||||||
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
model_attributes=info_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
@ -117,7 +131,7 @@ async def update_model(
|
|||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
)
|
)
|
||||||
model_response = parse_obj_as(UpdateModelResponse, model_raw)
|
model_response = update_models_response_adapter.validate_python(model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@ -159,7 +173,8 @@ async def import_model(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||||
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
|
items_to_import=items_to_import,
|
||||||
|
prediction_type_helper=lambda x: prediction_types.get(prediction_type),
|
||||||
)
|
)
|
||||||
info = installed_models.get(location)
|
info = installed_models.get(location)
|
||||||
|
|
||||||
@ -171,7 +186,7 @@ async def import_model(
|
|||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
||||||
)
|
)
|
||||||
return parse_obj_as(ImportModelResponse, model_raw)
|
return import_models_response_adapter.validate_python(model_raw)
|
||||||
|
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
@ -205,13 +220,18 @@ async def add_model(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
ApiDependencies.invoker.services.model_manager.add_model(
|
ApiDependencies.invoker.services.model_manager.add_model(
|
||||||
info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
|
info.model_name,
|
||||||
|
info.base_model,
|
||||||
|
info.model_type,
|
||||||
|
model_attributes=info.model_dump(),
|
||||||
)
|
)
|
||||||
logger.info(f"Successfully added {info.model_name}")
|
logger.info(f"Successfully added {info.model_name}")
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
|
model_name=info.model_name,
|
||||||
|
base_model=info.base_model,
|
||||||
|
model_type=info.model_type,
|
||||||
)
|
)
|
||||||
return parse_obj_as(ImportModelResponse, model_raw)
|
return import_models_response_adapter.validate_python(model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
@ -223,7 +243,10 @@ async def add_model(
|
|||||||
@models_router.delete(
|
@models_router.delete(
|
||||||
"/{base_model}/{model_type}/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="del_model",
|
operation_id="del_model",
|
||||||
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
|
responses={
|
||||||
|
204: {"description": "Model deleted successfully"},
|
||||||
|
404: {"description": "Model not found"},
|
||||||
|
},
|
||||||
status_code=204,
|
status_code=204,
|
||||||
response_model=None,
|
response_model=None,
|
||||||
)
|
)
|
||||||
@ -279,7 +302,7 @@ async def convert_model(
|
|||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name, base_model=base_model, model_type=model_type
|
model_name, base_model=base_model, model_type=model_type
|
||||||
)
|
)
|
||||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
response = convert_models_response_adapter.validate_python(model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@ -302,7 +325,8 @@ async def search_for_models(
|
|||||||
) -> List[pathlib.Path]:
|
) -> List[pathlib.Path]:
|
||||||
if not search_path.is_dir():
|
if not search_path.is_dir():
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
|
status_code=404,
|
||||||
|
detail=f"The search path '{search_path}' does not exist or is not directory",
|
||||||
)
|
)
|
||||||
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
||||||
|
|
||||||
@ -337,6 +361,26 @@ async def sync_to_config() -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# There's some weird pydantic-fastapi behaviour that requires this to be a separate class
|
||||||
|
# TODO: After a few updates, see if it works inside the route operation handler?
|
||||||
|
class MergeModelsBody(BaseModel):
|
||||||
|
model_names: List[str] = Field(description="model name", min_length=2, max_length=3)
|
||||||
|
merged_model_name: Optional[str] = Field(description="Name of destination model")
|
||||||
|
alpha: Optional[float] = Field(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5)
|
||||||
|
interp: Optional[MergeInterpolationMethod] = Field(description="Interpolation method")
|
||||||
|
force: Optional[bool] = Field(
|
||||||
|
description="Force merging of models created with different versions of diffusers",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
merge_dest_directory: Optional[str] = Field(
|
||||||
|
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
@models_router.put(
|
@models_router.put(
|
||||||
"/merge/{base_model}",
|
"/merge/{base_model}",
|
||||||
operation_id="merge_models",
|
operation_id="merge_models",
|
||||||
@ -349,31 +393,23 @@ async def sync_to_config() -> bool:
|
|||||||
response_model=MergeModelResponse,
|
response_model=MergeModelResponse,
|
||||||
)
|
)
|
||||||
async def merge_models(
|
async def merge_models(
|
||||||
|
body: Annotated[MergeModelsBody, Body(description="Model configuration", embed=True)],
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
|
||||||
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
|
||||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
|
||||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
|
||||||
force: Optional[bool] = Body(
|
|
||||||
description="Force merging of models created with different versions of diffusers", default=False
|
|
||||||
),
|
|
||||||
merge_dest_directory: Optional[str] = Body(
|
|
||||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
|
||||||
default=None,
|
|
||||||
),
|
|
||||||
) -> MergeModelResponse:
|
) -> MergeModelResponse:
|
||||||
"""Convert a checkpoint model into a diffusers model"""
|
"""Convert a checkpoint model into a diffusers model"""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
try:
|
try:
|
||||||
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
logger.info(
|
||||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
f"Merging models: {body.model_names} into {body.merge_dest_directory or '<MODELS>'}/{body.merged_model_name}"
|
||||||
|
)
|
||||||
|
dest = pathlib.Path(body.merge_dest_directory) if body.merge_dest_directory else None
|
||||||
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
||||||
model_names,
|
model_names=body.model_names,
|
||||||
base_model,
|
base_model=base_model,
|
||||||
merged_model_name=merged_model_name or "+".join(model_names),
|
merged_model_name=body.merged_model_name or "+".join(body.model_names),
|
||||||
alpha=alpha,
|
alpha=body.alpha,
|
||||||
interp=interp,
|
interp=body.interp,
|
||||||
force=force,
|
force=body.force,
|
||||||
merge_dest_directory=dest,
|
merge_dest_directory=dest,
|
||||||
)
|
)
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
@ -381,9 +417,12 @@ async def merge_models(
|
|||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Main,
|
model_type=ModelType.Main,
|
||||||
)
|
)
|
||||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
response = convert_models_response_adapter.validate_python(model_raw)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"One or more of the models '{body.model_names}' not found",
|
||||||
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
return response
|
return response
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
||||||
from fastapi import Body
|
from fastapi import Body
|
||||||
@ -27,6 +27,7 @@ async def parse_dynamicprompts(
|
|||||||
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
|
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
|
||||||
) -> DynamicPromptsResponse:
|
) -> DynamicPromptsResponse:
|
||||||
"""Creates a batch process"""
|
"""Creates a batch process"""
|
||||||
|
generator: Union[RandomPromptGenerator, CombinatorialPromptGenerator]
|
||||||
try:
|
try:
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
if combinatorial:
|
if combinatorial:
|
||||||
|
@ -22,7 +22,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||||
from pydantic.schema import schema
|
from pydantic.json_schema import models_json_schema
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
@ -31,7 +31,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
|
|
||||||
from ..backend.util.logging import InvokeAILogger
|
from ..backend.util.logging import InvokeAILogger
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
|
from .api.routers import app_info, board_images, boards, images, models, session_queue, utilities
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||||
|
|
||||||
@ -51,7 +51,7 @@ mimetypes.add_type("text/css", ".css")
|
|||||||
|
|
||||||
# Create the app
|
# Create the app
|
||||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||||
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
|
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
|
||||||
|
|
||||||
# Add event handler
|
# Add event handler
|
||||||
event_handler_id: int = id(app)
|
event_handler_id: int = id(app)
|
||||||
@ -63,18 +63,18 @@ app.add_middleware(
|
|||||||
|
|
||||||
socket_io = SocketIO(app)
|
socket_io = SocketIO(app)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=app_config.allow_origins,
|
||||||
|
allow_credentials=app_config.allow_credentials,
|
||||||
|
allow_methods=app_config.allow_methods,
|
||||||
|
allow_headers=app_config.allow_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Add startup event to load dependencies
|
# Add startup event to load dependencies
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=app_config.allow_origins,
|
|
||||||
allow_credentials=app_config.allow_credentials,
|
|
||||||
allow_methods=app_config.allow_methods,
|
|
||||||
allow_headers=app_config.allow_headers,
|
|
||||||
)
|
|
||||||
|
|
||||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
@ -85,12 +85,7 @@ async def shutdown_event():
|
|||||||
|
|
||||||
|
|
||||||
# Include all routers
|
# Include all routers
|
||||||
# TODO: REMOVE
|
# app.include_router(sessions.session_router, prefix="/api")
|
||||||
# app.include_router(
|
|
||||||
# invocation.invocation_router,
|
|
||||||
# prefix = '/api')
|
|
||||||
|
|
||||||
app.include_router(sessions.session_router, prefix="/api")
|
|
||||||
|
|
||||||
app.include_router(utilities.utilities_router, prefix="/api")
|
app.include_router(utilities.utilities_router, prefix="/api")
|
||||||
|
|
||||||
@ -117,6 +112,7 @@ def custom_openapi():
|
|||||||
description="An API for invoking AI image operations",
|
description="An API for invoking AI image operations",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
routes=app.routes,
|
routes=app.routes,
|
||||||
|
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add all outputs
|
# Add all outputs
|
||||||
@ -127,29 +123,32 @@ def custom_openapi():
|
|||||||
output_type = signature(invoker.invoke).return_annotation
|
output_type = signature(invoker.invoke).return_annotation
|
||||||
output_types.add(output_type)
|
output_types.add(output_type)
|
||||||
|
|
||||||
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
|
output_schemas = models_json_schema(
|
||||||
for schema_key, output_schema in output_schemas["definitions"].items():
|
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
|
||||||
output_schema["class"] = "output"
|
)
|
||||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
for schema_key, output_schema in output_schemas[1]["$defs"].items():
|
||||||
|
|
||||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||||
# This could break in some cases, figure out a better way to do it
|
# This could break in some cases, figure out a better way to do it
|
||||||
output_type_titles[schema_key] = output_schema["title"]
|
output_type_titles[schema_key] = output_schema["title"]
|
||||||
|
|
||||||
# Add Node Editor UI helper schemas
|
# Add Node Editor UI helper schemas
|
||||||
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
|
ui_config_schemas = models_json_schema(
|
||||||
for schema_key, ui_config_schema in ui_config_schemas["definitions"].items():
|
[(UIConfigBase, "serialization"), (_InputField, "serialization"), (_OutputField, "serialization")],
|
||||||
|
ref_template="#/components/schemas/{model}",
|
||||||
|
)
|
||||||
|
for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
|
||||||
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
|
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
|
||||||
|
|
||||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||||
for invoker in all_invocations:
|
for invoker in all_invocations:
|
||||||
invoker_name = invoker.__name__
|
invoker_name = invoker.__name__
|
||||||
output_type = signature(invoker.invoke).return_annotation
|
output_type = signature(obj=invoker.invoke).return_annotation
|
||||||
output_type_title = output_type_titles[output_type.__name__]
|
output_type_title = output_type_titles[output_type.__name__]
|
||||||
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
|
||||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||||
invoker_schema["output"] = outputs_ref
|
invoker_schema["output"] = outputs_ref
|
||||||
invoker_schema["class"] = "invocation"
|
invoker_schema["class"] = "invocation"
|
||||||
|
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
|
||||||
|
|
||||||
from invokeai.backend.model_management.models import get_model_config_enums
|
from invokeai.backend.model_management.models import get_model_config_enums
|
||||||
|
|
||||||
@ -172,7 +171,7 @@ def custom_openapi():
|
|||||||
return app.openapi_schema
|
return app.openapi_schema
|
||||||
|
|
||||||
|
|
||||||
app.openapi = custom_openapi
|
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
|
||||||
|
|
||||||
# Override API doc favicons
|
# Override API doc favicons
|
||||||
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static")
|
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static")
|
||||||
|
@ -24,8 +24,8 @@ def add_field_argument(command_parser, name: str, field, default_override=None):
|
|||||||
if field.default_factory is None
|
if field.default_factory is None
|
||||||
else field.default_factory()
|
else field.default_factory()
|
||||||
)
|
)
|
||||||
if get_origin(field.type_) == Literal:
|
if get_origin(field.annotation) == Literal:
|
||||||
allowed_values = get_args(field.type_)
|
allowed_values = get_args(field.annotation)
|
||||||
allowed_types = set()
|
allowed_types = set()
|
||||||
for val in allowed_values:
|
for val in allowed_values:
|
||||||
allowed_types.add(type(val))
|
allowed_types.add(type(val))
|
||||||
@ -38,15 +38,15 @@ def add_field_argument(command_parser, name: str, field, default_override=None):
|
|||||||
type=field_type,
|
type=field_type,
|
||||||
default=default,
|
default=default,
|
||||||
choices=allowed_values,
|
choices=allowed_values,
|
||||||
help=field.field_info.description,
|
help=field.description,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
command_parser.add_argument(
|
command_parser.add_argument(
|
||||||
f"--{name}",
|
f"--{name}",
|
||||||
dest=name,
|
dest=name,
|
||||||
type=field.type_,
|
type=field.annotation,
|
||||||
default=default,
|
default=default,
|
||||||
help=field.field_info.description,
|
help=field.description,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -142,7 +142,6 @@ class BaseCommand(ABC, BaseModel):
|
|||||||
"""A CLI command"""
|
"""A CLI command"""
|
||||||
|
|
||||||
# All commands must include a type name like this:
|
# All commands must include a type name like this:
|
||||||
# type: Literal['your_command_name'] = 'your_command_name'
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_all_subclasses(cls):
|
def get_all_subclasses(cls):
|
||||||
|
@ -7,28 +7,16 @@ import re
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import (
|
from types import UnionType
|
||||||
TYPE_CHECKING,
|
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||||
AbstractSet,
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
ClassVar,
|
|
||||||
Literal,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
get_args,
|
|
||||||
get_type_hints,
|
|
||||||
)
|
|
||||||
|
|
||||||
import semver
|
import semver
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, ConfigDict, Field, create_model, field_validator
|
||||||
from pydantic.fields import ModelField, Undefined
|
from pydantic.fields import _Unset
|
||||||
from pydantic.typing import NoArgAnyCallable
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
@ -211,6 +199,11 @@ class _InputField(BaseModel):
|
|||||||
ui_choice_labels: Optional[dict[str, str]]
|
ui_choice_labels: Optional[dict[str, str]]
|
||||||
item_default: Optional[Any]
|
item_default: Optional[Any]
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
validate_assignment=True,
|
||||||
|
json_schema_serialization_defaults_required=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class _OutputField(BaseModel):
|
class _OutputField(BaseModel):
|
||||||
"""
|
"""
|
||||||
@ -224,34 +217,36 @@ class _OutputField(BaseModel):
|
|||||||
ui_type: Optional[UIType]
|
ui_type: Optional[UIType]
|
||||||
ui_order: Optional[int]
|
ui_order: Optional[int]
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
validate_assignment=True,
|
||||||
|
json_schema_serialization_defaults_required=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_type(klass: BaseModel) -> str:
|
||||||
|
"""Helper function to get an invocation or invocation output's type. This is the default value of the `type` field."""
|
||||||
|
return klass.model_fields["type"].default
|
||||||
|
|
||||||
|
|
||||||
def InputField(
|
def InputField(
|
||||||
*args: Any,
|
# copied from pydantic's Field
|
||||||
default: Any = Undefined,
|
default: Any = _Unset,
|
||||||
default_factory: Optional[NoArgAnyCallable] = None,
|
default_factory: Callable[[], Any] | None = _Unset,
|
||||||
alias: Optional[str] = None,
|
title: str | None = _Unset,
|
||||||
title: Optional[str] = None,
|
description: str | None = _Unset,
|
||||||
description: Optional[str] = None,
|
pattern: str | None = _Unset,
|
||||||
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
strict: bool | None = _Unset,
|
||||||
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
gt: float | None = _Unset,
|
||||||
const: Optional[bool] = None,
|
ge: float | None = _Unset,
|
||||||
gt: Optional[float] = None,
|
lt: float | None = _Unset,
|
||||||
ge: Optional[float] = None,
|
le: float | None = _Unset,
|
||||||
lt: Optional[float] = None,
|
multiple_of: float | None = _Unset,
|
||||||
le: Optional[float] = None,
|
allow_inf_nan: bool | None = _Unset,
|
||||||
multiple_of: Optional[float] = None,
|
max_digits: int | None = _Unset,
|
||||||
allow_inf_nan: Optional[bool] = None,
|
decimal_places: int | None = _Unset,
|
||||||
max_digits: Optional[int] = None,
|
min_length: int | None = _Unset,
|
||||||
decimal_places: Optional[int] = None,
|
max_length: int | None = _Unset,
|
||||||
min_items: Optional[int] = None,
|
# custom
|
||||||
max_items: Optional[int] = None,
|
|
||||||
unique_items: Optional[bool] = None,
|
|
||||||
min_length: Optional[int] = None,
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
allow_mutation: bool = True,
|
|
||||||
regex: Optional[str] = None,
|
|
||||||
discriminator: Optional[str] = None,
|
|
||||||
repr: bool = True,
|
|
||||||
input: Input = Input.Any,
|
input: Input = Input.Any,
|
||||||
ui_type: Optional[UIType] = None,
|
ui_type: Optional[UIType] = None,
|
||||||
ui_component: Optional[UIComponent] = None,
|
ui_component: Optional[UIComponent] = None,
|
||||||
@ -259,7 +254,6 @@ def InputField(
|
|||||||
ui_order: Optional[int] = None,
|
ui_order: Optional[int] = None,
|
||||||
ui_choice_labels: Optional[dict[str, str]] = None,
|
ui_choice_labels: Optional[dict[str, str]] = None,
|
||||||
item_default: Optional[Any] = None,
|
item_default: Optional[Any] = None,
|
||||||
**kwargs: Any,
|
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Creates an input field for an invocation.
|
Creates an input field for an invocation.
|
||||||
@ -289,18 +283,26 @@ def InputField(
|
|||||||
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||||
|
|
||||||
: param bool item_default: [None] Specifies the default item value, if this is a collection input. \
|
: param bool item_default: [None] Specifies the default item value, if this is a collection input. \
|
||||||
Ignored for non-collection fields..
|
Ignored for non-collection fields.
|
||||||
"""
|
"""
|
||||||
return Field(
|
|
||||||
*args,
|
json_schema_extra_: dict[str, Any] = dict(
|
||||||
|
input=input,
|
||||||
|
ui_type=ui_type,
|
||||||
|
ui_component=ui_component,
|
||||||
|
ui_hidden=ui_hidden,
|
||||||
|
ui_order=ui_order,
|
||||||
|
item_default=item_default,
|
||||||
|
ui_choice_labels=ui_choice_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
field_args = dict(
|
||||||
default=default,
|
default=default,
|
||||||
default_factory=default_factory,
|
default_factory=default_factory,
|
||||||
alias=alias,
|
|
||||||
title=title,
|
title=title,
|
||||||
description=description,
|
description=description,
|
||||||
exclude=exclude,
|
pattern=pattern,
|
||||||
include=include,
|
strict=strict,
|
||||||
const=const,
|
|
||||||
gt=gt,
|
gt=gt,
|
||||||
ge=ge,
|
ge=ge,
|
||||||
lt=lt,
|
lt=lt,
|
||||||
@ -309,57 +311,92 @@ def InputField(
|
|||||||
allow_inf_nan=allow_inf_nan,
|
allow_inf_nan=allow_inf_nan,
|
||||||
max_digits=max_digits,
|
max_digits=max_digits,
|
||||||
decimal_places=decimal_places,
|
decimal_places=decimal_places,
|
||||||
min_items=min_items,
|
|
||||||
max_items=max_items,
|
|
||||||
unique_items=unique_items,
|
|
||||||
min_length=min_length,
|
min_length=min_length,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
allow_mutation=allow_mutation,
|
)
|
||||||
regex=regex,
|
|
||||||
discriminator=discriminator,
|
"""
|
||||||
repr=repr,
|
Invocation definitions have their fields typed correctly for their `invoke()` functions.
|
||||||
input=input,
|
This typing is often more specific than the actual invocation definition requires, because
|
||||||
ui_type=ui_type,
|
fields may have values provided only by connections.
|
||||||
ui_component=ui_component,
|
|
||||||
ui_hidden=ui_hidden,
|
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
|
||||||
ui_order=ui_order,
|
|
||||||
item_default=item_default,
|
`image` is required during the call to `invoke()`, but when the python class is instantiated,
|
||||||
ui_choice_labels=ui_choice_labels,
|
the field may not be present. This is fine, because that image field will be provided by a
|
||||||
**kwargs,
|
an ancestor node that outputs the image.
|
||||||
|
|
||||||
|
So we'd like to type that `image` field as `Optional[ImageField]`. If we do that, however, then
|
||||||
|
we need to handle a lot of extra logic in the `invoke()` function to check if the field has a
|
||||||
|
value or not. This is very tedious.
|
||||||
|
|
||||||
|
Ideally, the invocation definition would be able to specify that the field is required during
|
||||||
|
invocation, but optional during instantiation. So the field would be typed as `image: ImageField`,
|
||||||
|
but when calling the `invoke()` function, we raise an error if the field is not present.
|
||||||
|
|
||||||
|
To do this, we need to do a bit of fanagling to make the pydantic field optional, and then do
|
||||||
|
extra validation when calling `invoke()`.
|
||||||
|
|
||||||
|
There is some additional logic here to cleaning create the pydantic field via the wrapper.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Filter out field args not provided
|
||||||
|
provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}
|
||||||
|
|
||||||
|
if (default is not PydanticUndefined) and (default_factory is not PydanticUndefined):
|
||||||
|
raise ValueError("Cannot specify both default and default_factory")
|
||||||
|
|
||||||
|
# because we are manually making fields optional, we need to store the original required bool for reference later
|
||||||
|
if default is PydanticUndefined and default_factory is PydanticUndefined:
|
||||||
|
json_schema_extra_.update(dict(orig_required=True))
|
||||||
|
else:
|
||||||
|
json_schema_extra_.update(dict(orig_required=False))
|
||||||
|
|
||||||
|
# make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
||||||
|
if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
|
||||||
|
default_ = None if default is PydanticUndefined else default
|
||||||
|
provided_args.update(dict(default=default_))
|
||||||
|
if default is not PydanticUndefined:
|
||||||
|
# before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
|
||||||
|
json_schema_extra_.update(dict(default=default))
|
||||||
|
json_schema_extra_.update(dict(orig_default=default))
|
||||||
|
elif default is not PydanticUndefined and default_factory is PydanticUndefined:
|
||||||
|
default_ = default
|
||||||
|
provided_args.update(dict(default=default_))
|
||||||
|
json_schema_extra_.update(dict(orig_default=default_))
|
||||||
|
elif default_factory is not PydanticUndefined:
|
||||||
|
provided_args.update(dict(default_factory=default_factory))
|
||||||
|
# TODO: cannot serialize default_factory...
|
||||||
|
# json_schema_extra_.update(dict(orig_default_factory=default_factory))
|
||||||
|
|
||||||
|
return Field(
|
||||||
|
**provided_args,
|
||||||
|
json_schema_extra=json_schema_extra_,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def OutputField(
|
def OutputField(
|
||||||
*args: Any,
|
# copied from pydantic's Field
|
||||||
default: Any = Undefined,
|
default: Any = _Unset,
|
||||||
default_factory: Optional[NoArgAnyCallable] = None,
|
default_factory: Callable[[], Any] | None = _Unset,
|
||||||
alias: Optional[str] = None,
|
title: str | None = _Unset,
|
||||||
title: Optional[str] = None,
|
description: str | None = _Unset,
|
||||||
description: Optional[str] = None,
|
pattern: str | None = _Unset,
|
||||||
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
strict: bool | None = _Unset,
|
||||||
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
gt: float | None = _Unset,
|
||||||
const: Optional[bool] = None,
|
ge: float | None = _Unset,
|
||||||
gt: Optional[float] = None,
|
lt: float | None = _Unset,
|
||||||
ge: Optional[float] = None,
|
le: float | None = _Unset,
|
||||||
lt: Optional[float] = None,
|
multiple_of: float | None = _Unset,
|
||||||
le: Optional[float] = None,
|
allow_inf_nan: bool | None = _Unset,
|
||||||
multiple_of: Optional[float] = None,
|
max_digits: int | None = _Unset,
|
||||||
allow_inf_nan: Optional[bool] = None,
|
decimal_places: int | None = _Unset,
|
||||||
max_digits: Optional[int] = None,
|
min_length: int | None = _Unset,
|
||||||
decimal_places: Optional[int] = None,
|
max_length: int | None = _Unset,
|
||||||
min_items: Optional[int] = None,
|
# custom
|
||||||
max_items: Optional[int] = None,
|
|
||||||
unique_items: Optional[bool] = None,
|
|
||||||
min_length: Optional[int] = None,
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
allow_mutation: bool = True,
|
|
||||||
regex: Optional[str] = None,
|
|
||||||
discriminator: Optional[str] = None,
|
|
||||||
repr: bool = True,
|
|
||||||
ui_type: Optional[UIType] = None,
|
ui_type: Optional[UIType] = None,
|
||||||
ui_hidden: bool = False,
|
ui_hidden: bool = False,
|
||||||
ui_order: Optional[int] = None,
|
ui_order: Optional[int] = None,
|
||||||
**kwargs: Any,
|
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Creates an output field for an invocation output.
|
Creates an output field for an invocation output.
|
||||||
@ -379,15 +416,12 @@ def OutputField(
|
|||||||
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||||
"""
|
"""
|
||||||
return Field(
|
return Field(
|
||||||
*args,
|
|
||||||
default=default,
|
default=default,
|
||||||
default_factory=default_factory,
|
default_factory=default_factory,
|
||||||
alias=alias,
|
|
||||||
title=title,
|
title=title,
|
||||||
description=description,
|
description=description,
|
||||||
exclude=exclude,
|
pattern=pattern,
|
||||||
include=include,
|
strict=strict,
|
||||||
const=const,
|
|
||||||
gt=gt,
|
gt=gt,
|
||||||
ge=ge,
|
ge=ge,
|
||||||
lt=lt,
|
lt=lt,
|
||||||
@ -396,19 +430,13 @@ def OutputField(
|
|||||||
allow_inf_nan=allow_inf_nan,
|
allow_inf_nan=allow_inf_nan,
|
||||||
max_digits=max_digits,
|
max_digits=max_digits,
|
||||||
decimal_places=decimal_places,
|
decimal_places=decimal_places,
|
||||||
min_items=min_items,
|
|
||||||
max_items=max_items,
|
|
||||||
unique_items=unique_items,
|
|
||||||
min_length=min_length,
|
min_length=min_length,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
allow_mutation=allow_mutation,
|
json_schema_extra=dict(
|
||||||
regex=regex,
|
ui_type=ui_type,
|
||||||
discriminator=discriminator,
|
ui_hidden=ui_hidden,
|
||||||
repr=repr,
|
ui_order=ui_order,
|
||||||
ui_type=ui_type,
|
),
|
||||||
ui_hidden=ui_hidden,
|
|
||||||
ui_order=ui_order,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -422,7 +450,13 @@ class UIConfigBase(BaseModel):
|
|||||||
title: Optional[str] = Field(default=None, description="The node's display name")
|
title: Optional[str] = Field(default=None, description="The node's display name")
|
||||||
category: Optional[str] = Field(default=None, description="The node's category")
|
category: Optional[str] = Field(default=None, description="The node's category")
|
||||||
version: Optional[str] = Field(
|
version: Optional[str] = Field(
|
||||||
default=None, description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".'
|
default=None,
|
||||||
|
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
validate_assignment=True,
|
||||||
|
json_schema_serialization_defaults_required=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -457,23 +491,38 @@ class BaseInvocationOutput(BaseModel):
|
|||||||
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
|
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||||
def get_all_subclasses_tuple(cls):
|
|
||||||
subclasses = []
|
|
||||||
toprocess = [cls]
|
|
||||||
while len(toprocess) > 0:
|
|
||||||
next = toprocess.pop(0)
|
|
||||||
next_subclasses = next.__subclasses__()
|
|
||||||
subclasses.extend(next_subclasses)
|
|
||||||
toprocess.extend(next_subclasses)
|
|
||||||
return tuple(subclasses)
|
|
||||||
|
|
||||||
class Config:
|
@classmethod
|
||||||
@staticmethod
|
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||||
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
cls._output_classes.add(output)
|
||||||
if "required" not in schema or not isinstance(schema["required"], list):
|
|
||||||
schema["required"] = list()
|
@classmethod
|
||||||
schema["required"].extend(["type"])
|
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
|
||||||
|
return cls._output_classes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_outputs_union(cls) -> UnionType:
|
||||||
|
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
|
||||||
|
return outputs_union # type: ignore [return-value]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_output_types(cls) -> Iterable[str]:
|
||||||
|
return map(lambda i: get_type(i), BaseInvocationOutput.get_outputs())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
|
# Because we use a pydantic Literal field with default value for the invocation type,
|
||||||
|
# it will be typed as optional in the OpenAPI schema. Make it required manually.
|
||||||
|
if "required" not in schema or not isinstance(schema["required"], list):
|
||||||
|
schema["required"] = list()
|
||||||
|
schema["required"].extend(["type"])
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
validate_assignment=True,
|
||||||
|
json_schema_serialization_defaults_required=True,
|
||||||
|
json_schema_extra=json_schema_extra,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RequiredConnectionException(Exception):
|
class RequiredConnectionException(Exception):
|
||||||
@ -498,104 +547,91 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
All invocations must use the `@invocation` decorator to provide their unique type.
|
All invocations must use the `@invocation` decorator to provide their unique type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_all_subclasses(cls):
|
def register_invocation(cls, invocation: BaseInvocation) -> None:
|
||||||
|
cls._invocation_classes.add(invocation)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_invocations_union(cls) -> UnionType:
|
||||||
|
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
|
||||||
|
return invocations_union # type: ignore [return-value]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
subclasses = []
|
allowed_invocations: set[BaseInvocation] = set()
|
||||||
toprocess = [cls]
|
for sc in cls._invocation_classes:
|
||||||
while len(toprocess) > 0:
|
invocation_type = get_type(sc)
|
||||||
next = toprocess.pop(0)
|
|
||||||
next_subclasses = next.__subclasses__()
|
|
||||||
subclasses.extend(next_subclasses)
|
|
||||||
toprocess.extend(next_subclasses)
|
|
||||||
allowed_invocations = []
|
|
||||||
for sc in subclasses:
|
|
||||||
is_in_allowlist = (
|
is_in_allowlist = (
|
||||||
sc.__fields__.get("type").default in app_config.allow_nodes
|
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
|
||||||
if isinstance(app_config.allow_nodes, list)
|
|
||||||
else True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
is_in_denylist = (
|
is_in_denylist = (
|
||||||
sc.__fields__.get("type").default in app_config.deny_nodes
|
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
|
||||||
if isinstance(app_config.deny_nodes, list)
|
|
||||||
else False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_in_allowlist and not is_in_denylist:
|
if is_in_allowlist and not is_in_denylist:
|
||||||
allowed_invocations.append(sc)
|
allowed_invocations.add(sc)
|
||||||
return allowed_invocations
|
return allowed_invocations
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_invocations(cls):
|
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
|
||||||
return tuple(BaseInvocation.get_all_subclasses())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_invocations_map(cls):
|
|
||||||
# Get the type strings out of the literals and into a dictionary
|
# Get the type strings out of the literals and into a dictionary
|
||||||
return dict(
|
return dict(
|
||||||
map(
|
map(
|
||||||
lambda t: (get_args(get_type_hints(t)["type"])[0], t),
|
lambda i: (get_type(i), i),
|
||||||
BaseInvocation.get_all_subclasses(),
|
BaseInvocation.get_invocations(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_output_type(cls):
|
def get_invocation_types(cls) -> Iterable[str]:
|
||||||
|
return map(lambda i: get_type(i), BaseInvocation.get_invocations())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_output_type(cls) -> BaseInvocationOutput:
|
||||||
return signature(cls.invoke).return_annotation
|
return signature(cls.invoke).return_annotation
|
||||||
|
|
||||||
class Config:
|
@staticmethod
|
||||||
validate_assignment = True
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
validate_all = True
|
# Add the various UI-facing attributes to the schema. These are used to build the invocation templates.
|
||||||
|
uiconfig = getattr(model_class, "UIConfig", None)
|
||||||
@staticmethod
|
if uiconfig and hasattr(uiconfig, "title"):
|
||||||
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
schema["title"] = uiconfig.title
|
||||||
uiconfig = getattr(model_class, "UIConfig", None)
|
if uiconfig and hasattr(uiconfig, "tags"):
|
||||||
if uiconfig and hasattr(uiconfig, "title"):
|
schema["tags"] = uiconfig.tags
|
||||||
schema["title"] = uiconfig.title
|
if uiconfig and hasattr(uiconfig, "category"):
|
||||||
if uiconfig and hasattr(uiconfig, "tags"):
|
schema["category"] = uiconfig.category
|
||||||
schema["tags"] = uiconfig.tags
|
if uiconfig and hasattr(uiconfig, "version"):
|
||||||
if uiconfig and hasattr(uiconfig, "category"):
|
schema["version"] = uiconfig.version
|
||||||
schema["category"] = uiconfig.category
|
if "required" not in schema or not isinstance(schema["required"], list):
|
||||||
if uiconfig and hasattr(uiconfig, "version"):
|
schema["required"] = list()
|
||||||
schema["version"] = uiconfig.version
|
schema["required"].extend(["type", "id"])
|
||||||
if "required" not in schema or not isinstance(schema["required"], list):
|
|
||||||
schema["required"] = list()
|
|
||||||
schema["required"].extend(["type", "id"])
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||||
"""Invoke with provided context and return outputs."""
|
"""Invoke with provided context and return outputs."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __init__(self, **data):
|
|
||||||
# nodes may have required fields, that can accept input from connections
|
|
||||||
# on instantiation of the model, we need to exclude these from validation
|
|
||||||
restore = dict()
|
|
||||||
try:
|
|
||||||
field_names = list(self.__fields__.keys())
|
|
||||||
for field_name in field_names:
|
|
||||||
# if the field is required and may get its value from a connection, exclude it from validation
|
|
||||||
field = self.__fields__[field_name]
|
|
||||||
_input = field.field_info.extra.get("input", None)
|
|
||||||
if _input in [Input.Connection, Input.Any] and field.required:
|
|
||||||
if field_name not in data:
|
|
||||||
restore[field_name] = self.__fields__.pop(field_name)
|
|
||||||
# instantiate the node, which will validate the data
|
|
||||||
super().__init__(**data)
|
|
||||||
finally:
|
|
||||||
# restore the removed fields
|
|
||||||
for field_name, field in restore.items():
|
|
||||||
self.__fields__[field_name] = field
|
|
||||||
|
|
||||||
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||||
for field_name, field in self.__fields__.items():
|
for field_name, field in self.model_fields.items():
|
||||||
_input = field.field_info.extra.get("input", None)
|
if not field.json_schema_extra or callable(field.json_schema_extra):
|
||||||
if field.required and not hasattr(self, field_name):
|
# something has gone terribly awry, we should always have this and it should be a dict
|
||||||
if _input == Input.Connection:
|
continue
|
||||||
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
|
|
||||||
elif _input == Input.Any:
|
# Here we handle the case where the field is optional in the pydantic class, but required
|
||||||
raise MissingInputException(self.__fields__["type"].default, field_name)
|
# in the `invoke()` method.
|
||||||
|
|
||||||
|
orig_default = field.json_schema_extra.get("orig_default", PydanticUndefined)
|
||||||
|
orig_required = field.json_schema_extra.get("orig_required", True)
|
||||||
|
input_ = field.json_schema_extra.get("input", None)
|
||||||
|
if orig_default is not PydanticUndefined and not hasattr(self, field_name):
|
||||||
|
setattr(self, field_name, orig_default)
|
||||||
|
if orig_required and orig_default is PydanticUndefined and getattr(self, field_name) is None:
|
||||||
|
if input_ == Input.Connection:
|
||||||
|
raise RequiredConnectionException(self.model_fields["type"].default, field_name)
|
||||||
|
elif input_ == Input.Any:
|
||||||
|
raise MissingInputException(self.model_fields["type"].default, field_name)
|
||||||
|
|
||||||
# skip node cache codepath if it's disabled
|
# skip node cache codepath if it's disabled
|
||||||
if context.services.configuration.node_cache_size == 0:
|
if context.services.configuration.node_cache_size == 0:
|
||||||
@ -618,23 +654,31 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
return self.invoke(context)
|
return self.invoke(context)
|
||||||
|
|
||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
return self.__fields__["type"].default
|
return self.model_fields["type"].default
|
||||||
|
|
||||||
id: str = Field(
|
id: str = Field(
|
||||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations."
|
default_factory=uuid_string,
|
||||||
|
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
||||||
)
|
)
|
||||||
is_intermediate: bool = InputField(
|
is_intermediate: Optional[bool] = Field(
|
||||||
default=False, description="Whether or not this is an intermediate invocation.", ui_type=UIType.IsIntermediate
|
default=False,
|
||||||
|
description="Whether or not this is an intermediate invocation.",
|
||||||
|
json_schema_extra=dict(ui_type=UIType.IsIntermediate),
|
||||||
)
|
)
|
||||||
workflow: Optional[str] = InputField(
|
workflow: Optional[str] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The workflow to save with the image",
|
description="The workflow to save with the image",
|
||||||
ui_type=UIType.WorkflowField,
|
json_schema_extra=dict(ui_type=UIType.WorkflowField),
|
||||||
|
)
|
||||||
|
use_cache: Optional[bool] = Field(
|
||||||
|
default=True,
|
||||||
|
description="Whether or not to use the cache",
|
||||||
)
|
)
|
||||||
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
|
|
||||||
|
|
||||||
@validator("workflow", pre=True)
|
@field_validator("workflow", mode="before")
|
||||||
|
@classmethod
|
||||||
def validate_workflow_is_json(cls, v):
|
def validate_workflow_is_json(cls, v):
|
||||||
|
"""We don't have a workflow schema in the backend, so we just check that it's valid JSON"""
|
||||||
if v is None:
|
if v is None:
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
@ -645,8 +689,14 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
|
|
||||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
validate_assignment=True,
|
||||||
|
json_schema_extra=json_schema_extra,
|
||||||
|
json_schema_serialization_defaults_required=True,
|
||||||
|
)
|
||||||
|
|
||||||
GenericBaseInvocation = TypeVar("GenericBaseInvocation", bound=BaseInvocation)
|
|
||||||
|
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
||||||
|
|
||||||
|
|
||||||
def invocation(
|
def invocation(
|
||||||
@ -656,7 +706,7 @@ def invocation(
|
|||||||
category: Optional[str] = None,
|
category: Optional[str] = None,
|
||||||
version: Optional[str] = None,
|
version: Optional[str] = None,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
|
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
|
||||||
"""
|
"""
|
||||||
Adds metadata to an invocation.
|
Adds metadata to an invocation.
|
||||||
|
|
||||||
@ -668,12 +718,15 @@ def invocation(
|
|||||||
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
|
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(cls: Type[GenericBaseInvocation]) -> Type[GenericBaseInvocation]:
|
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
|
||||||
# Validate invocation types on creation of invocation classes
|
# Validate invocation types on creation of invocation classes
|
||||||
# TODO: ensure unique?
|
# TODO: ensure unique?
|
||||||
if re.compile(r"^\S+$").match(invocation_type) is None:
|
if re.compile(r"^\S+$").match(invocation_type) is None:
|
||||||
raise ValueError(f'"invocation_type" must consist of non-whitespace characters, got "{invocation_type}"')
|
raise ValueError(f'"invocation_type" must consist of non-whitespace characters, got "{invocation_type}"')
|
||||||
|
|
||||||
|
if invocation_type in BaseInvocation.get_invocation_types():
|
||||||
|
raise ValueError(f'Invocation type "{invocation_type}" already exists')
|
||||||
|
|
||||||
# Add OpenAPI schema extras
|
# Add OpenAPI schema extras
|
||||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||||
@ -691,59 +744,83 @@ def invocation(
|
|||||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||||
cls.UIConfig.version = version
|
cls.UIConfig.version = version
|
||||||
if use_cache is not None:
|
if use_cache is not None:
|
||||||
cls.__fields__["use_cache"].default = use_cache
|
cls.model_fields["use_cache"].default = use_cache
|
||||||
|
|
||||||
|
# Add the invocation type to the model.
|
||||||
|
|
||||||
|
# You'd be tempted to just add the type field and rebuild the model, like this:
|
||||||
|
# cls.model_fields.update(type=FieldInfo.from_annotated_attribute(Literal[invocation_type], invocation_type))
|
||||||
|
# cls.model_rebuild() or cls.model_rebuild(force=True)
|
||||||
|
|
||||||
|
# Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does
|
||||||
|
# not work. Instead, we have to create a new class with the type field and patch the original class with it.
|
||||||
|
|
||||||
# Add the invocation type to the pydantic model of the invocation
|
|
||||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||||
invocation_type_field = ModelField.infer(
|
invocation_type_field = Field(
|
||||||
name="type",
|
title="type",
|
||||||
value=invocation_type,
|
default=invocation_type,
|
||||||
annotation=invocation_type_annotation,
|
|
||||||
class_validators=None,
|
|
||||||
config=cls.__config__,
|
|
||||||
)
|
)
|
||||||
cls.__fields__.update({"type": invocation_type_field})
|
|
||||||
# to support 3.9, 3.10 and 3.11, as described in https://docs.python.org/3/howto/annotations.html
|
docstring = cls.__doc__
|
||||||
if annotations := cls.__dict__.get("__annotations__", None):
|
cls = create_model(
|
||||||
annotations.update({"type": invocation_type_annotation})
|
cls.__qualname__,
|
||||||
|
__base__=cls,
|
||||||
|
__module__=cls.__module__,
|
||||||
|
type=(invocation_type_annotation, invocation_type_field),
|
||||||
|
)
|
||||||
|
cls.__doc__ = docstring
|
||||||
|
|
||||||
|
# TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic
|
||||||
|
BaseInvocation.register_invocation(cls) # type: ignore
|
||||||
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
GenericBaseInvocationOutput = TypeVar("GenericBaseInvocationOutput", bound=BaseInvocationOutput)
|
TBaseInvocationOutput = TypeVar("TBaseInvocationOutput", bound=BaseInvocationOutput)
|
||||||
|
|
||||||
|
|
||||||
def invocation_output(
|
def invocation_output(
|
||||||
output_type: str,
|
output_type: str,
|
||||||
) -> Callable[[Type[GenericBaseInvocationOutput]], Type[GenericBaseInvocationOutput]]:
|
) -> Callable[[Type[TBaseInvocationOutput]], Type[TBaseInvocationOutput]]:
|
||||||
"""
|
"""
|
||||||
Adds metadata to an invocation output.
|
Adds metadata to an invocation output.
|
||||||
|
|
||||||
:param str output_type: The type of the invocation output. Must be unique among all invocation outputs.
|
:param str output_type: The type of the invocation output. Must be unique among all invocation outputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(cls: Type[GenericBaseInvocationOutput]) -> Type[GenericBaseInvocationOutput]:
|
def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]:
|
||||||
# Validate output types on creation of invocation output classes
|
# Validate output types on creation of invocation output classes
|
||||||
# TODO: ensure unique?
|
# TODO: ensure unique?
|
||||||
if re.compile(r"^\S+$").match(output_type) is None:
|
if re.compile(r"^\S+$").match(output_type) is None:
|
||||||
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
|
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
|
||||||
|
|
||||||
# Add the output type to the pydantic model of the invocation output
|
if output_type in BaseInvocationOutput.get_output_types():
|
||||||
output_type_annotation = Literal[output_type] # type: ignore
|
raise ValueError(f'Invocation type "{output_type}" already exists')
|
||||||
output_type_field = ModelField.infer(
|
|
||||||
name="type",
|
|
||||||
value=output_type,
|
|
||||||
annotation=output_type_annotation,
|
|
||||||
class_validators=None,
|
|
||||||
config=cls.__config__,
|
|
||||||
)
|
|
||||||
cls.__fields__.update({"type": output_type_field})
|
|
||||||
|
|
||||||
# to support 3.9, 3.10 and 3.11, as described in https://docs.python.org/3/howto/annotations.html
|
# Add the output type to the model.
|
||||||
if annotations := cls.__dict__.get("__annotations__", None):
|
|
||||||
annotations.update({"type": output_type_annotation})
|
output_type_annotation = Literal[output_type] # type: ignore
|
||||||
|
output_type_field = Field(
|
||||||
|
title="type",
|
||||||
|
default=output_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
docstring = cls.__doc__
|
||||||
|
cls = create_model(
|
||||||
|
cls.__qualname__,
|
||||||
|
__base__=cls,
|
||||||
|
__module__=cls.__module__,
|
||||||
|
type=(output_type_annotation, output_type_field),
|
||||||
|
)
|
||||||
|
cls.__doc__ = docstring
|
||||||
|
|
||||||
|
BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly?
|
||||||
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import validator
|
from pydantic import ValidationInfo, field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
@ -20,9 +20,9 @@ class RangeInvocation(BaseInvocation):
|
|||||||
stop: int = InputField(default=10, description="The stop of the range")
|
stop: int = InputField(default=10, description="The stop of the range")
|
||||||
step: int = InputField(default=1, description="The step of the range")
|
step: int = InputField(default=1, description="The step of the range")
|
||||||
|
|
||||||
@validator("stop")
|
@field_validator("stop")
|
||||||
def stop_gt_start(cls, v, values):
|
def stop_gt_start(cls, v: int, info: ValidationInfo):
|
||||||
if "start" in values and v <= values["start"]:
|
if "start" in info.data and v <= info.data["start"]:
|
||||||
raise ValueError("stop must be greater than start")
|
raise ValueError("stop must be greater than start")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compel import Compel, ReturnedEmbeddingsType
|
from compel import Compel, ReturnedEmbeddingsType
|
||||||
@ -43,7 +43,13 @@ class ConditioningFieldData:
|
|||||||
# PerpNeg = "perp_neg"
|
# PerpNeg = "perp_neg"
|
||||||
|
|
||||||
|
|
||||||
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning", version="1.0.0")
|
@invocation(
|
||||||
|
"compel",
|
||||||
|
title="Prompt",
|
||||||
|
tags=["prompt", "compel"],
|
||||||
|
category="conditioning",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class CompelInvocation(BaseInvocation):
|
class CompelInvocation(BaseInvocation):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
@ -61,17 +67,19 @@ class CompelInvocation(BaseInvocation):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**self.clip.tokenizer.dict(),
|
**self.clip.tokenizer.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**self.clip.text_encoder.dict(),
|
**self.clip.text_encoder.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.clip.loras:
|
for lora in self.clip.loras:
|
||||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
lora_info = context.services.model_manager.get_model(
|
||||||
|
**lora.model_dump(exclude={"weight"}), context=context
|
||||||
|
)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
@ -160,11 +168,11 @@ class SDXLPromptInvocationBase:
|
|||||||
zero_on_empty: bool,
|
zero_on_empty: bool,
|
||||||
):
|
):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**clip_field.text_encoder.dict(),
|
**clip_field.text_encoder.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -172,7 +180,11 @@ class SDXLPromptInvocationBase:
|
|||||||
if prompt == "" and zero_on_empty:
|
if prompt == "" and zero_on_empty:
|
||||||
cpu_text_encoder = text_encoder_info.context.model
|
cpu_text_encoder = text_encoder_info.context.model
|
||||||
c = torch.zeros(
|
c = torch.zeros(
|
||||||
(1, cpu_text_encoder.config.max_position_embeddings, cpu_text_encoder.config.hidden_size),
|
(
|
||||||
|
1,
|
||||||
|
cpu_text_encoder.config.max_position_embeddings,
|
||||||
|
cpu_text_encoder.config.hidden_size,
|
||||||
|
),
|
||||||
dtype=text_encoder_info.context.cache.precision,
|
dtype=text_encoder_info.context.cache.precision,
|
||||||
)
|
)
|
||||||
if get_pooled:
|
if get_pooled:
|
||||||
@ -186,7 +198,9 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
lora_info = context.services.model_manager.get_model(
|
||||||
|
**lora.model_dump(exclude={"weight"}), context=context
|
||||||
|
)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
@ -273,8 +287,16 @@ class SDXLPromptInvocationBase:
|
|||||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
prompt: str = InputField(
|
||||||
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
default="",
|
||||||
|
description=FieldDescriptions.compel_prompt,
|
||||||
|
ui_component=UIComponent.Textarea,
|
||||||
|
)
|
||||||
|
style: str = InputField(
|
||||||
|
default="",
|
||||||
|
description=FieldDescriptions.compel_prompt,
|
||||||
|
ui_component=UIComponent.Textarea,
|
||||||
|
)
|
||||||
original_width: int = InputField(default=1024, description="")
|
original_width: int = InputField(default=1024, description="")
|
||||||
original_height: int = InputField(default=1024, description="")
|
original_height: int = InputField(default=1024, description="")
|
||||||
crop_top: int = InputField(default=0, description="")
|
crop_top: int = InputField(default=0, description="")
|
||||||
@ -310,7 +332,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
[
|
[
|
||||||
c1,
|
c1,
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
(c1.shape[0], c2.shape[1] - c1.shape[1], c1.shape[2]), device=c1.device, dtype=c1.dtype
|
(c1.shape[0], c2.shape[1] - c1.shape[1], c1.shape[2]),
|
||||||
|
device=c1.device,
|
||||||
|
dtype=c1.dtype,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
dim=1,
|
dim=1,
|
||||||
@ -321,7 +345,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
[
|
[
|
||||||
c2,
|
c2,
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
(c2.shape[0], c1.shape[1] - c2.shape[1], c2.shape[2]), device=c2.device, dtype=c2.dtype
|
(c2.shape[0], c1.shape[1] - c2.shape[1], c2.shape[2]),
|
||||||
|
device=c2.device,
|
||||||
|
dtype=c2.dtype,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
dim=1,
|
dim=1,
|
||||||
@ -359,7 +385,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
style: str = InputField(
|
style: str = InputField(
|
||||||
default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
|
default="",
|
||||||
|
description=FieldDescriptions.compel_prompt,
|
||||||
|
ui_component=UIComponent.Textarea,
|
||||||
) # TODO: ?
|
) # TODO: ?
|
||||||
original_width: int = InputField(default=1024, description="")
|
original_width: int = InputField(default=1024, description="")
|
||||||
original_height: int = InputField(default=1024, description="")
|
original_height: int = InputField(default=1024, description="")
|
||||||
@ -403,10 +431,16 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||||
"""Clip skip node output"""
|
"""Clip skip node output"""
|
||||||
|
|
||||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
|
|
||||||
|
|
||||||
@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning", version="1.0.0")
|
@invocation(
|
||||||
|
"clip_skip",
|
||||||
|
title="CLIP Skip",
|
||||||
|
tags=["clipskip", "clip", "skip"],
|
||||||
|
category="conditioning",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ClipSkipInvocation(BaseInvocation):
|
class ClipSkipInvocation(BaseInvocation):
|
||||||
"""Skip layers in clip text_encoder model."""
|
"""Skip layers in clip text_encoder model."""
|
||||||
|
|
||||||
@ -421,7 +455,9 @@ class ClipSkipInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
def get_max_token_count(
|
def get_max_token_count(
|
||||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
|
tokenizer,
|
||||||
|
prompt: Union[FlattenedPrompt, Blend, Conjunction],
|
||||||
|
truncate_if_too_long=False,
|
||||||
) -> int:
|
) -> int:
|
||||||
if type(prompt) is Blend:
|
if type(prompt) is Blend:
|
||||||
blend: Blend = prompt
|
blend: Blend = prompt
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
# initial implementation by Gregg Helt, 2023
|
# initial implementation by Gregg Helt, 2023
|
||||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||||
from builtins import bool, float
|
from builtins import bool, float
|
||||||
from typing import Dict, List, Literal, Optional, Union
|
from typing import Dict, List, Literal, Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -24,7 +24,7 @@ from controlnet_aux import (
|
|||||||
)
|
)
|
||||||
from controlnet_aux.util import HWC3, ade_palette
|
from controlnet_aux.util import HWC3, ade_palette
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
@ -57,6 +57,8 @@ class ControlNetModelField(BaseModel):
|
|||||||
model_name: str = Field(description="Name of the ControlNet model")
|
model_name: str = Field(description="Name of the ControlNet model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
image: ImageField = Field(description="The control image")
|
image: ImageField = Field(description="The control image")
|
||||||
@ -71,7 +73,7 @@ class ControlField(BaseModel):
|
|||||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||||
|
|
||||||
@validator("control_weight")
|
@field_validator("control_weight")
|
||||||
def validate_control_weight(cls, v):
|
def validate_control_weight(cls, v):
|
||||||
"""Validate that all control weights in the valid range"""
|
"""Validate that all control weights in the valid range"""
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
@ -124,9 +126,7 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
||||||
"image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet", version="1.0.0"
|
|
||||||
)
|
|
||||||
class ImageProcessorInvocation(BaseInvocation):
|
class ImageProcessorInvocation(BaseInvocation):
|
||||||
"""Base class for invocations that preprocess images for ControlNet"""
|
"""Base class for invocations that preprocess images for ControlNet"""
|
||||||
|
|
||||||
@ -393,9 +393,9 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
|
|
||||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
h: int = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||||
w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||||
f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
content_shuffle_processor = ContentShuffleDetector()
|
content_shuffle_processor = ContentShuffleDetector()
|
||||||
@ -575,14 +575,14 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
|
|
||||||
def run_processor(self, image: Image.Image):
|
def run_processor(self, image: Image.Image):
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
image = np.array(image, dtype=np.uint8)
|
np_image = np.array(image, dtype=np.uint8)
|
||||||
height, width = image.shape[:2]
|
height, width = np_image.shape[:2]
|
||||||
|
|
||||||
width_tile_size = min(self.color_map_tile_size, width)
|
width_tile_size = min(self.color_map_tile_size, width)
|
||||||
height_tile_size = min(self.color_map_tile_size, height)
|
height_tile_size = min(self.color_map_tile_size, height)
|
||||||
|
|
||||||
color_map = cv2.resize(
|
color_map = cv2.resize(
|
||||||
image,
|
np_image,
|
||||||
(width // width_tile_size, height // height_tile_size),
|
(width // width_tile_size, height // height_tile_size),
|
||||||
interpolation=cv2.INTER_CUBIC,
|
interpolation=cv2.INTER_CUBIC,
|
||||||
)
|
)
|
||||||
|
@ -8,7 +8,7 @@ import numpy as np
|
|||||||
from mediapipe.python.solutions.face_mesh import FaceMesh # type: ignore[import]
|
from mediapipe.python.solutions.face_mesh import FaceMesh # type: ignore[import]
|
||||||
from PIL import Image, ImageDraw, ImageFilter, ImageFont, ImageOps
|
from PIL import Image, ImageDraw, ImageFilter, ImageFont, ImageOps
|
||||||
from PIL.Image import Image as ImageType
|
from PIL.Image import Image as ImageType
|
||||||
from pydantic import validator
|
from pydantic import field_validator
|
||||||
|
|
||||||
import invokeai.assets.fonts as font_assets
|
import invokeai.assets.fonts as font_assets
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
@ -550,7 +550,7 @@ class FaceMaskInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
invert_mask: bool = InputField(default=False, description="Toggle to invert the mask")
|
invert_mask: bool = InputField(default=False, description="Toggle to invert the mask")
|
||||||
|
|
||||||
@validator("face_ids")
|
@field_validator("face_ids")
|
||||||
def validate_comma_separated_ints(cls, v) -> str:
|
def validate_comma_separated_ints(cls, v) -> str:
|
||||||
comma_separated_ints_regex = re.compile(r"^\d*(,\d+)*$")
|
comma_separated_ints_regex = re.compile(r"^\d*(,\d+)*$")
|
||||||
if comma_separated_ints_regex.match(v) is None:
|
if comma_separated_ints_regex.match(v) is None:
|
||||||
|
@ -36,7 +36,13 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"blank_image",
|
||||||
|
title="Blank Image",
|
||||||
|
tags=["image"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class BlankImageInvocation(BaseInvocation):
|
class BlankImageInvocation(BaseInvocation):
|
||||||
"""Creates a blank image and forwards it to the pipeline"""
|
"""Creates a blank image and forwards it to the pipeline"""
|
||||||
|
|
||||||
@ -65,7 +71,13 @@ class BlankImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_crop",
|
||||||
|
title="Crop Image",
|
||||||
|
tags=["image", "crop"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageCropInvocation(BaseInvocation):
|
class ImageCropInvocation(BaseInvocation):
|
||||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||||
|
|
||||||
@ -98,7 +110,13 @@ class ImageCropInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.0.1")
|
@invocation(
|
||||||
|
"img_paste",
|
||||||
|
title="Paste Image",
|
||||||
|
tags=["image", "paste"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.1",
|
||||||
|
)
|
||||||
class ImagePasteInvocation(BaseInvocation):
|
class ImagePasteInvocation(BaseInvocation):
|
||||||
"""Pastes an image into another image."""
|
"""Pastes an image into another image."""
|
||||||
|
|
||||||
@ -151,7 +169,13 @@ class ImagePasteInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"tomask",
|
||||||
|
title="Mask from Alpha",
|
||||||
|
tags=["image", "mask"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class MaskFromAlphaInvocation(BaseInvocation):
|
class MaskFromAlphaInvocation(BaseInvocation):
|
||||||
"""Extracts the alpha channel of an image as a mask."""
|
"""Extracts the alpha channel of an image as a mask."""
|
||||||
|
|
||||||
@ -182,7 +206,13 @@ class MaskFromAlphaInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_mul",
|
||||||
|
title="Multiply Images",
|
||||||
|
tags=["image", "multiply"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageMultiplyInvocation(BaseInvocation):
|
class ImageMultiplyInvocation(BaseInvocation):
|
||||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||||
|
|
||||||
@ -215,7 +245,13 @@ class ImageMultiplyInvocation(BaseInvocation):
|
|||||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_chan",
|
||||||
|
title="Extract Image Channel",
|
||||||
|
tags=["image", "channel"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageChannelInvocation(BaseInvocation):
|
class ImageChannelInvocation(BaseInvocation):
|
||||||
"""Gets a channel from an image."""
|
"""Gets a channel from an image."""
|
||||||
|
|
||||||
@ -247,7 +283,13 @@ class ImageChannelInvocation(BaseInvocation):
|
|||||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_conv",
|
||||||
|
title="Convert Image Mode",
|
||||||
|
tags=["image", "convert"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageConvertInvocation(BaseInvocation):
|
class ImageConvertInvocation(BaseInvocation):
|
||||||
"""Converts an image to a different mode."""
|
"""Converts an image to a different mode."""
|
||||||
|
|
||||||
@ -276,7 +318,13 @@ class ImageConvertInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_blur",
|
||||||
|
title="Blur Image",
|
||||||
|
tags=["image", "blur"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageBlurInvocation(BaseInvocation):
|
class ImageBlurInvocation(BaseInvocation):
|
||||||
"""Blurs an image"""
|
"""Blurs an image"""
|
||||||
|
|
||||||
@ -330,7 +378,13 @@ PIL_RESAMPLING_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_resize",
|
||||||
|
title="Resize Image",
|
||||||
|
tags=["image", "resize"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageResizeInvocation(BaseInvocation):
|
class ImageResizeInvocation(BaseInvocation):
|
||||||
"""Resizes an image to specific dimensions"""
|
"""Resizes an image to specific dimensions"""
|
||||||
|
|
||||||
@ -359,7 +413,7 @@ class ImageResizeInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -370,7 +424,13 @@ class ImageResizeInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_scale",
|
||||||
|
title="Scale Image",
|
||||||
|
tags=["image", "scale"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageScaleInvocation(BaseInvocation):
|
class ImageScaleInvocation(BaseInvocation):
|
||||||
"""Scales an image by a factor"""
|
"""Scales an image by a factor"""
|
||||||
|
|
||||||
@ -411,7 +471,13 @@ class ImageScaleInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_lerp",
|
||||||
|
title="Lerp Image",
|
||||||
|
tags=["image", "lerp"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageLerpInvocation(BaseInvocation):
|
class ImageLerpInvocation(BaseInvocation):
|
||||||
"""Linear interpolation of all pixels of an image"""
|
"""Linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
@ -444,7 +510,13 @@ class ImageLerpInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_ilerp",
|
||||||
|
title="Inverse Lerp Image",
|
||||||
|
tags=["image", "ilerp"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageInverseLerpInvocation(BaseInvocation):
|
class ImageInverseLerpInvocation(BaseInvocation):
|
||||||
"""Inverse linear interpolation of all pixels of an image"""
|
"""Inverse linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
@ -456,7 +528,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
|||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||||
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255
|
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255 # type: ignore [assignment]
|
||||||
|
|
||||||
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||||
|
|
||||||
@ -477,7 +549,13 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_nsfw",
|
||||||
|
title="Blur NSFW Image",
|
||||||
|
tags=["image", "nsfw"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageNSFWBlurInvocation(BaseInvocation):
|
class ImageNSFWBlurInvocation(BaseInvocation):
|
||||||
"""Add blur to NSFW-flagged images"""
|
"""Add blur to NSFW-flagged images"""
|
||||||
|
|
||||||
@ -505,7 +583,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -515,7 +593,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
|||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_caution_img(self) -> Image:
|
def _get_caution_img(self) -> Image.Image:
|
||||||
import invokeai.app.assets.images as image_assets
|
import invokeai.app.assets.images as image_assets
|
||||||
|
|
||||||
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
|
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
|
||||||
@ -523,7 +601,11 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image", version="1.0.0"
|
"img_watermark",
|
||||||
|
title="Add Invisible Watermark",
|
||||||
|
tags=["image", "watermark"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ImageWatermarkInvocation(BaseInvocation):
|
class ImageWatermarkInvocation(BaseInvocation):
|
||||||
"""Add an invisible watermark to an image"""
|
"""Add an invisible watermark to an image"""
|
||||||
@ -544,7 +626,7 @@ class ImageWatermarkInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -555,7 +637,13 @@ class ImageWatermarkInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"mask_edge",
|
||||||
|
title="Mask Edge",
|
||||||
|
tags=["image", "mask", "inpaint"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class MaskEdgeInvocation(BaseInvocation):
|
class MaskEdgeInvocation(BaseInvocation):
|
||||||
"""Applies an edge mask to an image"""
|
"""Applies an edge mask to an image"""
|
||||||
|
|
||||||
@ -601,7 +689,11 @@ class MaskEdgeInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image", version="1.0.0"
|
"mask_combine",
|
||||||
|
title="Combine Masks",
|
||||||
|
tags=["image", "mask", "multiply"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class MaskCombineInvocation(BaseInvocation):
|
class MaskCombineInvocation(BaseInvocation):
|
||||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||||
@ -632,7 +724,13 @@ class MaskCombineInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"color_correct",
|
||||||
|
title="Color Correct",
|
||||||
|
tags=["image", "color"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ColorCorrectInvocation(BaseInvocation):
|
class ColorCorrectInvocation(BaseInvocation):
|
||||||
"""
|
"""
|
||||||
Shifts the colors of a target image to match the reference image, optionally
|
Shifts the colors of a target image to match the reference image, optionally
|
||||||
@ -742,7 +840,13 @@ class ColorCorrectInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_hue_adjust",
|
||||||
|
title="Adjust Image Hue",
|
||||||
|
tags=["image", "hue"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageHueAdjustmentInvocation(BaseInvocation):
|
class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||||
"""Adjusts the Hue of an image."""
|
"""Adjusts the Hue of an image."""
|
||||||
|
|
||||||
@ -980,7 +1084,7 @@ class SaveImageInvocation(BaseInvocation):
|
|||||||
|
|
||||||
image: ImageField = InputField(description=FieldDescriptions.image)
|
image: ImageField = InputField(description=FieldDescriptions.image)
|
||||||
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
|
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
|
||||||
metadata: CoreMetadata = InputField(
|
metadata: Optional[CoreMetadata] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description=FieldDescriptions.core_metadata,
|
description=FieldDescriptions.core_metadata,
|
||||||
ui_hidden=True,
|
ui_hidden=True,
|
||||||
@ -997,7 +1101,7 @@ class SaveImageInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ import os
|
|||||||
from builtins import float
|
from builtins import float
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -25,11 +25,15 @@ class IPAdapterModelField(BaseModel):
|
|||||||
model_name: str = Field(description="Name of the IP-Adapter model")
|
model_name: str = Field(description="Name of the IP-Adapter model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionModelField(BaseModel):
|
class CLIPVisionModelField(BaseModel):
|
||||||
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
|
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
|
||||||
base_model: BaseModelType = Field(description="Base model (usually 'Any')")
|
base_model: BaseModelType = Field(description="Base model (usually 'Any')")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterField(BaseModel):
|
class IPAdapterField(BaseModel):
|
||||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||||
|
@ -19,7 +19,7 @@ from diffusers.models.attention_processor import (
|
|||||||
)
|
)
|
||||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
from pydantic import validator
|
from pydantic import field_validator
|
||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
|
|
||||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||||
@ -84,12 +84,20 @@ class SchedulerOutput(BaseInvocationOutput):
|
|||||||
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
||||||
|
|
||||||
|
|
||||||
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents", version="1.0.0")
|
@invocation(
|
||||||
|
"scheduler",
|
||||||
|
title="Scheduler",
|
||||||
|
tags=["scheduler"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class SchedulerInvocation(BaseInvocation):
|
class SchedulerInvocation(BaseInvocation):
|
||||||
"""Selects a scheduler."""
|
"""Selects a scheduler."""
|
||||||
|
|
||||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||||
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
|
default="euler",
|
||||||
|
description=FieldDescriptions.scheduler,
|
||||||
|
ui_type=UIType.Scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SchedulerOutput:
|
def invoke(self, context: InvocationContext) -> SchedulerOutput:
|
||||||
@ -97,7 +105,11 @@ class SchedulerInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents", version="1.0.0"
|
"create_denoise_mask",
|
||||||
|
title="Create Denoise Mask",
|
||||||
|
tags=["mask", "denoise"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||||
"""Creates mask for denoising model run."""
|
"""Creates mask for denoising model run."""
|
||||||
@ -106,7 +118,11 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
|||||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||||
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
||||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
|
||||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32, ui_order=4)
|
fp32: bool = InputField(
|
||||||
|
default=DEFAULT_PRECISION == "float32",
|
||||||
|
description=FieldDescriptions.fp32,
|
||||||
|
ui_order=4,
|
||||||
|
)
|
||||||
|
|
||||||
def prep_mask_tensor(self, mask_image):
|
def prep_mask_tensor(self, mask_image):
|
||||||
if mask_image.mode != "L":
|
if mask_image.mode != "L":
|
||||||
@ -134,7 +150,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
|||||||
|
|
||||||
if image is not None:
|
if image is not None:
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(),
|
**self.vae.vae.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -167,7 +183,7 @@ def get_scheduler(
|
|||||||
) -> Scheduler:
|
) -> Scheduler:
|
||||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||||
orig_scheduler_info = context.services.model_manager.get_model(
|
orig_scheduler_info = context.services.model_manager.get_model(
|
||||||
**scheduler_info.dict(),
|
**scheduler_info.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
with orig_scheduler_info as orig_scheduler:
|
with orig_scheduler_info as orig_scheduler:
|
||||||
@ -209,34 +225,64 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
negative_conditioning: ConditioningField = InputField(
|
negative_conditioning: ConditioningField = InputField(
|
||||||
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
||||||
)
|
)
|
||||||
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=3)
|
noise: Optional[LatentsField] = InputField(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.noise,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=3,
|
||||||
|
)
|
||||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||||
cfg_scale: Union[float, List[float]] = InputField(
|
cfg_scale: Union[float, List[float]] = InputField(
|
||||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale"
|
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale"
|
||||||
)
|
)
|
||||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
denoising_start: float = InputField(
|
||||||
|
default=0.0,
|
||||||
|
ge=0,
|
||||||
|
le=1,
|
||||||
|
description=FieldDescriptions.denoising_start,
|
||||||
|
)
|
||||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||||
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
|
default="euler",
|
||||||
|
description=FieldDescriptions.scheduler,
|
||||||
|
ui_type=UIType.Scheduler,
|
||||||
)
|
)
|
||||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
|
unet: UNetField = InputField(
|
||||||
control: Union[ControlField, list[ControlField]] = InputField(
|
description=FieldDescriptions.unet,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="UNet",
|
||||||
|
ui_order=2,
|
||||||
|
)
|
||||||
|
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
ui_order=5,
|
ui_order=5,
|
||||||
)
|
)
|
||||||
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
|
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
|
||||||
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
|
description=FieldDescriptions.ip_adapter,
|
||||||
|
title="IP-Adapter",
|
||||||
|
default=None,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=6,
|
||||||
)
|
)
|
||||||
t2i_adapter: Union[T2IAdapterField, list[T2IAdapterField]] = InputField(
|
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
|
||||||
description=FieldDescriptions.t2i_adapter, title="T2I-Adapter", default=None, input=Input.Connection, ui_order=7
|
description=FieldDescriptions.t2i_adapter,
|
||||||
|
title="T2I-Adapter",
|
||||||
|
default=None,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=7,
|
||||||
|
)
|
||||||
|
latents: Optional[LatentsField] = InputField(
|
||||||
|
default=None, description=FieldDescriptions.latents, input=Input.Connection
|
||||||
)
|
)
|
||||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
|
||||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||||
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=8
|
default=None,
|
||||||
|
description=FieldDescriptions.mask,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=8,
|
||||||
)
|
)
|
||||||
|
|
||||||
@validator("cfg_scale")
|
@field_validator("cfg_scale")
|
||||||
def ge_one(cls, v):
|
def ge_one(cls, v):
|
||||||
"""validate that all cfg_scale values are >= 1"""
|
"""validate that all cfg_scale values are >= 1"""
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
@ -259,7 +305,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
stable_diffusion_step_callback(
|
stable_diffusion_step_callback(
|
||||||
context=context,
|
context=context,
|
||||||
intermediate_state=intermediate_state,
|
intermediate_state=intermediate_state,
|
||||||
node=self.dict(),
|
node=self.model_dump(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
)
|
)
|
||||||
@ -451,9 +497,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||||
with image_encoder_model_info as image_encoder_model:
|
with image_encoder_model_info as image_encoder_model:
|
||||||
# Get image embeddings from CLIP and ImageProjModel.
|
# Get image embeddings from CLIP and ImageProjModel.
|
||||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
(
|
||||||
input_image, image_encoder_model
|
image_prompt_embeds,
|
||||||
)
|
uncond_image_prompt_embeds,
|
||||||
|
) = ip_adapter_model.get_image_embeds(input_image, image_encoder_model)
|
||||||
conditioning_data.ip_adapter_conditioning.append(
|
conditioning_data.ip_adapter_conditioning.append(
|
||||||
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
|
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
|
||||||
)
|
)
|
||||||
@ -628,7 +675,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||||
# below. Investigate whether this is appropriate.
|
# below. Investigate whether this is appropriate.
|
||||||
t2i_adapter_data = self.run_t2i_adapters(
|
t2i_adapter_data = self.run_t2i_adapters(
|
||||||
context, self.t2i_adapter, latents.shape, do_classifier_free_guidance=True
|
context,
|
||||||
|
self.t2i_adapter,
|
||||||
|
latents.shape,
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
@ -641,7 +691,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}),
|
**lora.model_dump(exclude={"weight"}),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
@ -649,7 +699,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict(),
|
**self.unet.unet.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
with (
|
with (
|
||||||
@ -700,7 +750,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
denoising_end=self.denoising_end,
|
denoising_end=self.denoising_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
(
|
||||||
|
result_latents,
|
||||||
|
result_attention_map_saver,
|
||||||
|
) = pipeline.latents_from_embeddings(
|
||||||
latents=latents,
|
latents=latents,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
init_timestep=init_timestep,
|
init_timestep=init_timestep,
|
||||||
@ -728,7 +781,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", version="1.0.0"
|
"l2i",
|
||||||
|
title="Latents to Image",
|
||||||
|
tags=["latents", "image", "vae", "l2i"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class LatentsToImageInvocation(BaseInvocation):
|
class LatentsToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
@ -743,7 +800,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||||
metadata: CoreMetadata = InputField(
|
metadata: Optional[CoreMetadata] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description=FieldDescriptions.core_metadata,
|
description=FieldDescriptions.core_metadata,
|
||||||
ui_hidden=True,
|
ui_hidden=True,
|
||||||
@ -754,7 +811,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(),
|
**self.vae.vae.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -816,7 +873,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -830,7 +887,13 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||||
|
|
||||||
|
|
||||||
@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
|
@invocation(
|
||||||
|
"lresize",
|
||||||
|
title="Resize Latents",
|
||||||
|
tags=["latents", "resize"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ResizeLatentsInvocation(BaseInvocation):
|
class ResizeLatentsInvocation(BaseInvocation):
|
||||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||||
|
|
||||||
@ -876,7 +939,13 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
|
|
||||||
|
|
||||||
@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
|
@invocation(
|
||||||
|
"lscale",
|
||||||
|
title="Scale Latents",
|
||||||
|
tags=["latents", "resize"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ScaleLatentsInvocation(BaseInvocation):
|
class ScaleLatentsInvocation(BaseInvocation):
|
||||||
"""Scales latents by a given factor."""
|
"""Scales latents by a given factor."""
|
||||||
|
|
||||||
@ -915,7 +984,11 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", version="1.0.0"
|
"i2l",
|
||||||
|
title="Image to Latents",
|
||||||
|
tags=["latents", "image", "vae", "i2l"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ImageToLatentsInvocation(BaseInvocation):
|
class ImageToLatentsInvocation(BaseInvocation):
|
||||||
"""Encodes an image into latents."""
|
"""Encodes an image into latents."""
|
||||||
@ -979,7 +1052,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(),
|
**self.vae.vae.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1007,7 +1080,13 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
return vae.encode(image_tensor).latents
|
return vae.encode(image_tensor).latents
|
||||||
|
|
||||||
|
|
||||||
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
|
@invocation(
|
||||||
|
"lblend",
|
||||||
|
title="Blend Latents",
|
||||||
|
tags=["latents", "blend"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class BlendLatentsInvocation(BaseInvocation):
|
class BlendLatentsInvocation(BaseInvocation):
|
||||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import validator
|
from pydantic import field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
|
from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
|
||||||
|
|
||||||
@ -72,7 +72,14 @@ class RandomIntInvocation(BaseInvocation):
|
|||||||
return IntegerOutput(value=np.random.randint(self.low, self.high))
|
return IntegerOutput(value=np.random.randint(self.low, self.high))
|
||||||
|
|
||||||
|
|
||||||
@invocation("rand_float", title="Random Float", tags=["math", "float", "random"], category="math", version="1.0.0")
|
@invocation(
|
||||||
|
"rand_float",
|
||||||
|
title="Random Float",
|
||||||
|
tags=["math", "float", "random"],
|
||||||
|
category="math",
|
||||||
|
version="1.0.1",
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
class RandomFloatInvocation(BaseInvocation):
|
class RandomFloatInvocation(BaseInvocation):
|
||||||
"""Outputs a single random float"""
|
"""Outputs a single random float"""
|
||||||
|
|
||||||
@ -178,7 +185,7 @@ class IntegerMathInvocation(BaseInvocation):
|
|||||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
|
|
||||||
@validator("b")
|
@field_validator("b")
|
||||||
def no_unrepresentable_results(cls, v, values):
|
def no_unrepresentable_results(cls, v, values):
|
||||||
if values["operation"] == "DIV" and v == 0:
|
if values["operation"] == "DIV" and v == 0:
|
||||||
raise ValueError("Cannot divide by zero")
|
raise ValueError("Cannot divide by zero")
|
||||||
@ -252,7 +259,7 @@ class FloatMathInvocation(BaseInvocation):
|
|||||||
a: float = InputField(default=0, description=FieldDescriptions.num_1)
|
a: float = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
b: float = InputField(default=0, description=FieldDescriptions.num_2)
|
b: float = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
|
|
||||||
@validator("b")
|
@field_validator("b")
|
||||||
def no_unrepresentable_results(cls, v, values):
|
def no_unrepresentable_results(cls, v, values):
|
||||||
if values["operation"] == "DIV" and v == 0:
|
if values["operation"] == "DIV" and v == 0:
|
||||||
raise ValueError("Cannot divide by zero")
|
raise ValueError("Cannot divide by zero")
|
||||||
|
@ -223,4 +223,4 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
||||||
"""Collects and outputs a CoreMetadata object"""
|
"""Collects and outputs a CoreMetadata object"""
|
||||||
|
|
||||||
return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))
|
return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.model_dump()))
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
@ -24,6 +24,8 @@ class ModelInfo(BaseModel):
|
|||||||
model_type: ModelType = Field(description="Info to load submodel")
|
model_type: ModelType = Field(description="Info to load submodel")
|
||||||
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class LoraInfo(ModelInfo):
|
class LoraInfo(ModelInfo):
|
||||||
weight: float = Field(description="Lora's weight which to use when apply to model")
|
weight: float = Field(description="Lora's weight which to use when apply to model")
|
||||||
@ -65,6 +67,8 @@ class MainModelField(BaseModel):
|
|||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
model_type: ModelType = Field(description="Model Type")
|
model_type: ModelType = Field(description="Model Type")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class LoRAModelField(BaseModel):
|
class LoRAModelField(BaseModel):
|
||||||
"""LoRA model field"""
|
"""LoRA model field"""
|
||||||
@ -72,8 +76,16 @@ class LoRAModelField(BaseModel):
|
|||||||
model_name: str = Field(description="Name of the LoRA model")
|
model_name: str = Field(description="Name of the LoRA model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0")
|
|
||||||
|
@invocation(
|
||||||
|
"main_model_loader",
|
||||||
|
title="Main Model",
|
||||||
|
tags=["model"],
|
||||||
|
category="model",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class MainModelLoaderInvocation(BaseInvocation):
|
class MainModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
@ -180,10 +192,16 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
unet: Optional[UNetField] = InputField(
|
unet: Optional[UNetField] = InputField(
|
||||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
default=None,
|
||||||
|
description=FieldDescriptions.unet,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="UNet",
|
||||||
)
|
)
|
||||||
clip: Optional[ClipField] = InputField(
|
clip: Optional[ClipField] = InputField(
|
||||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP"
|
default=None,
|
||||||
|
description=FieldDescriptions.clip,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="CLIP",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||||
@ -244,20 +262,35 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
|||||||
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
|
|
||||||
|
|
||||||
@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model", version="1.0.0")
|
@invocation(
|
||||||
|
"sdxl_lora_loader",
|
||||||
|
title="SDXL LoRA",
|
||||||
|
tags=["lora", "model"],
|
||||||
|
category="model",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
unet: Optional[UNetField] = InputField(
|
unet: Optional[UNetField] = InputField(
|
||||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
default=None,
|
||||||
|
description=FieldDescriptions.unet,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="UNet",
|
||||||
)
|
)
|
||||||
clip: Optional[ClipField] = InputField(
|
clip: Optional[ClipField] = InputField(
|
||||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
|
default=None,
|
||||||
|
description=FieldDescriptions.clip,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="CLIP 1",
|
||||||
)
|
)
|
||||||
clip2: Optional[ClipField] = InputField(
|
clip2: Optional[ClipField] = InputField(
|
||||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
|
default=None,
|
||||||
|
description=FieldDescriptions.clip,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="CLIP 2",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||||
@ -330,6 +363,8 @@ class VAEModelField(BaseModel):
|
|||||||
model_name: str = Field(description="Name of the model")
|
model_name: str = Field(description="Name of the model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("vae_loader_output")
|
@invocation_output("vae_loader_output")
|
||||||
class VaeLoaderOutput(BaseInvocationOutput):
|
class VaeLoaderOutput(BaseInvocationOutput):
|
||||||
@ -343,7 +378,10 @@ class VaeLoaderInvocation(BaseInvocation):
|
|||||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||||
|
|
||||||
vae_model: VAEModelField = InputField(
|
vae_model: VAEModelField = InputField(
|
||||||
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
|
description=FieldDescriptions.vae_model,
|
||||||
|
input=Input.Direct,
|
||||||
|
ui_type=UIType.VaeModel,
|
||||||
|
title="VAE",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||||
@ -372,19 +410,31 @@ class VaeLoaderInvocation(BaseInvocation):
|
|||||||
class SeamlessModeOutput(BaseInvocationOutput):
|
class SeamlessModeOutput(BaseInvocationOutput):
|
||||||
"""Modified Seamless Model output"""
|
"""Modified Seamless Model output"""
|
||||||
|
|
||||||
unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
|
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||||
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model", version="1.0.0")
|
@invocation(
|
||||||
|
"seamless",
|
||||||
|
title="Seamless",
|
||||||
|
tags=["seamless", "model"],
|
||||||
|
category="model",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class SeamlessModeInvocation(BaseInvocation):
|
class SeamlessModeInvocation(BaseInvocation):
|
||||||
"""Applies the seamless transformation to the Model UNet and VAE."""
|
"""Applies the seamless transformation to the Model UNet and VAE."""
|
||||||
|
|
||||||
unet: Optional[UNetField] = InputField(
|
unet: Optional[UNetField] = InputField(
|
||||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
default=None,
|
||||||
|
description=FieldDescriptions.unet,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="UNet",
|
||||||
)
|
)
|
||||||
vae: Optional[VaeField] = InputField(
|
vae: Optional[VaeField] = InputField(
|
||||||
default=None, description=FieldDescriptions.vae_model, input=Input.Connection, title="VAE"
|
default=None,
|
||||||
|
description=FieldDescriptions.vae_model,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="VAE",
|
||||||
)
|
)
|
||||||
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
|
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
|
||||||
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import validator
|
from pydantic import field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.latent import LatentsField
|
from invokeai.app.invocations.latent import LatentsField
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
@ -65,7 +65,7 @@ Nodes
|
|||||||
class NoiseOutput(BaseInvocationOutput):
|
class NoiseOutput(BaseInvocationOutput):
|
||||||
"""Invocation noise output"""
|
"""Invocation noise output"""
|
||||||
|
|
||||||
noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
|
noise: LatentsField = OutputField(description=FieldDescriptions.noise)
|
||||||
width: int = OutputField(description=FieldDescriptions.width)
|
width: int = OutputField(description=FieldDescriptions.width)
|
||||||
height: int = OutputField(description=FieldDescriptions.height)
|
height: int = OutputField(description=FieldDescriptions.height)
|
||||||
|
|
||||||
@ -78,7 +78,13 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents", version="1.0.0")
|
@invocation(
|
||||||
|
"noise",
|
||||||
|
title="Noise",
|
||||||
|
tags=["latents", "noise"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class NoiseInvocation(BaseInvocation):
|
class NoiseInvocation(BaseInvocation):
|
||||||
"""Generates latent noise."""
|
"""Generates latent noise."""
|
||||||
|
|
||||||
@ -105,7 +111,7 @@ class NoiseInvocation(BaseInvocation):
|
|||||||
description="Use CPU for noise generation (for reproducible results across platforms)",
|
description="Use CPU for noise generation (for reproducible results across platforms)",
|
||||||
)
|
)
|
||||||
|
|
||||||
@validator("seed", pre=True)
|
@field_validator("seed", mode="before")
|
||||||
def modulo_seed(cls, v):
|
def modulo_seed(cls, v):
|
||||||
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||||
return v % (SEED_MAX + 1)
|
return v % (SEED_MAX + 1)
|
||||||
|
@ -9,7 +9,7 @@ from typing import List, Literal, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
@ -63,14 +63,17 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**self.clip.tokenizer.dict(),
|
**self.clip.tokenizer.model_dump(),
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**self.clip.text_encoder.dict(),
|
**self.clip.text_encoder.model_dump(),
|
||||||
)
|
)
|
||||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
||||||
loras = [
|
loras = [
|
||||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
(
|
||||||
|
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
||||||
|
lora.weight,
|
||||||
|
)
|
||||||
for lora in self.clip.loras
|
for lora in self.clip.loras
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -175,14 +178,14 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
description=FieldDescriptions.unet,
|
description=FieldDescriptions.unet,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
)
|
)
|
||||||
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
control: Union[ControlField, list[ControlField]] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description=FieldDescriptions.control,
|
description=FieldDescriptions.control,
|
||||||
)
|
)
|
||||||
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
|
|
||||||
@validator("cfg_scale")
|
@field_validator("cfg_scale")
|
||||||
def ge_one(cls, v):
|
def ge_one(cls, v):
|
||||||
"""validate that all cfg_scale values are >= 1"""
|
"""validate that all cfg_scale values are >= 1"""
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
@ -241,7 +244,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
stable_diffusion_step_callback(
|
stable_diffusion_step_callback(
|
||||||
context=context,
|
context=context,
|
||||||
intermediate_state=intermediate_state,
|
intermediate_state=intermediate_state,
|
||||||
node=self.dict(),
|
node=self.model_dump(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -254,12 +257,15 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
eta=0.0,
|
eta=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump())
|
||||||
|
|
||||||
with unet_info as unet: # , ExitStack() as stack:
|
with unet_info as unet: # , ExitStack() as stack:
|
||||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||||
loras = [
|
loras = [
|
||||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
(
|
||||||
|
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
||||||
|
lora.weight,
|
||||||
|
)
|
||||||
for lora in self.unet.loras
|
for lora in self.unet.loras
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -346,7 +352,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
|||||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(),
|
**self.vae.vae.model_dump(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# clear memory as vae decode can request a lot
|
# clear memory as vae decode can request a lot
|
||||||
@ -375,7 +381,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -403,6 +409,8 @@ class OnnxModelField(BaseModel):
|
|||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
model_type: ModelType = Field(description="Model Type")
|
model_type: ModelType = Field(description="Model Type")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
|
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
|
||||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||||
|
@ -44,13 +44,22 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput
|
|||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
@invocation("float_range", title="Float Range", tags=["math", "range"], category="math", version="1.0.0")
|
@invocation(
|
||||||
|
"float_range",
|
||||||
|
title="Float Range",
|
||||||
|
tags=["math", "range"],
|
||||||
|
category="math",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class FloatLinearRangeInvocation(BaseInvocation):
|
class FloatLinearRangeInvocation(BaseInvocation):
|
||||||
"""Creates a range"""
|
"""Creates a range"""
|
||||||
|
|
||||||
start: float = InputField(default=5, description="The first value of the range")
|
start: float = InputField(default=5, description="The first value of the range")
|
||||||
stop: float = InputField(default=10, description="The last value of the range")
|
stop: float = InputField(default=10, description="The last value of the range")
|
||||||
steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
|
steps: int = InputField(
|
||||||
|
default=30,
|
||||||
|
description="number of values to interpolate over (including start and stop)",
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||||
@ -95,7 +104,13 @@ EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
|||||||
|
|
||||||
|
|
||||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||||
@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step", version="1.0.0")
|
@invocation(
|
||||||
|
"step_param_easing",
|
||||||
|
title="Step Param Easing",
|
||||||
|
tags=["step", "easing"],
|
||||||
|
category="step",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class StepParamEasingInvocation(BaseInvocation):
|
class StepParamEasingInvocation(BaseInvocation):
|
||||||
"""Experimental per-step parameter easing for denoising steps"""
|
"""Experimental per-step parameter easing for denoising steps"""
|
||||||
|
|
||||||
@ -159,7 +174,9 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
||||||
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
|
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
|
||||||
easing_function = easing_class(
|
easing_function = easing_class(
|
||||||
start=self.start_value, end=self.end_value, duration=base_easing_duration - 1
|
start=self.start_value,
|
||||||
|
end=self.end_value,
|
||||||
|
duration=base_easing_duration - 1,
|
||||||
)
|
)
|
||||||
base_easing_vals = list()
|
base_easing_vals = list()
|
||||||
for step_index in range(base_easing_duration):
|
for step_index in range(base_easing_duration):
|
||||||
@ -199,7 +216,11 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
#
|
#
|
||||||
|
|
||||||
else: # no mirroring (default)
|
else: # no mirroring (default)
|
||||||
easing_function = easing_class(start=self.start_value, end=self.end_value, duration=num_easing_steps - 1)
|
easing_function = easing_class(
|
||||||
|
start=self.start_value,
|
||||||
|
end=self.end_value,
|
||||||
|
duration=num_easing_steps - 1,
|
||||||
|
)
|
||||||
for step_index in range(num_easing_steps):
|
for step_index in range(num_easing_steps):
|
||||||
step_val = easing_function.ease(step_index)
|
step_val = easing_function.ease(step_index)
|
||||||
easing_list.append(step_val)
|
easing_list.append(step_val)
|
||||||
|
@ -3,7 +3,7 @@ from typing import Optional, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
||||||
from pydantic import validator
|
from pydantic import field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import StringCollectionOutput
|
from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||||
|
|
||||||
@ -21,7 +21,10 @@ from .baseinvocation import BaseInvocation, InputField, InvocationContext, UICom
|
|||||||
class DynamicPromptInvocation(BaseInvocation):
|
class DynamicPromptInvocation(BaseInvocation):
|
||||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||||
|
|
||||||
prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
|
prompt: str = InputField(
|
||||||
|
description="The prompt to parse with dynamicprompts",
|
||||||
|
ui_component=UIComponent.Textarea,
|
||||||
|
)
|
||||||
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
|
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
|
||||||
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
|
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
|
||||||
|
|
||||||
@ -36,21 +39,31 @@ class DynamicPromptInvocation(BaseInvocation):
|
|||||||
return StringCollectionOutput(collection=prompts)
|
return StringCollectionOutput(collection=prompts)
|
||||||
|
|
||||||
|
|
||||||
@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt", version="1.0.0")
|
@invocation(
|
||||||
|
"prompt_from_file",
|
||||||
|
title="Prompts from File",
|
||||||
|
tags=["prompt", "file"],
|
||||||
|
category="prompt",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class PromptsFromFileInvocation(BaseInvocation):
|
class PromptsFromFileInvocation(BaseInvocation):
|
||||||
"""Loads prompts from a text file"""
|
"""Loads prompts from a text file"""
|
||||||
|
|
||||||
file_path: str = InputField(description="Path to prompt text file")
|
file_path: str = InputField(description="Path to prompt text file")
|
||||||
pre_prompt: Optional[str] = InputField(
|
pre_prompt: Optional[str] = InputField(
|
||||||
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
|
default=None,
|
||||||
|
description="String to prepend to each prompt",
|
||||||
|
ui_component=UIComponent.Textarea,
|
||||||
)
|
)
|
||||||
post_prompt: Optional[str] = InputField(
|
post_prompt: Optional[str] = InputField(
|
||||||
default=None, description="String to append to each prompt", ui_component=UIComponent.Textarea
|
default=None,
|
||||||
|
description="String to append to each prompt",
|
||||||
|
ui_component=UIComponent.Textarea,
|
||||||
)
|
)
|
||||||
start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
|
start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
|
||||||
max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
|
max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||||
|
|
||||||
@validator("file_path")
|
@field_validator("file_path")
|
||||||
def file_path_exists(cls, v):
|
def file_path_exists(cls, v):
|
||||||
if not exists(v):
|
if not exists(v):
|
||||||
raise ValueError(FileNotFoundError)
|
raise ValueError(FileNotFoundError)
|
||||||
@ -79,6 +92,10 @@ class PromptsFromFileInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||||
prompts = self.promptsFromFile(
|
prompts = self.promptsFromFile(
|
||||||
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
|
self.file_path,
|
||||||
|
self.pre_prompt,
|
||||||
|
self.post_prompt,
|
||||||
|
self.start_line,
|
||||||
|
self.max_prompts,
|
||||||
)
|
)
|
||||||
return StringCollectionOutput(collection=prompts)
|
return StringCollectionOutput(collection=prompts)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -23,6 +23,8 @@ class T2IAdapterModelField(BaseModel):
|
|||||||
model_name: str = Field(description="Name of the T2I-Adapter model")
|
model_name: str = Field(description="Name of the T2I-Adapter model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterField(BaseModel):
|
class T2IAdapterField(BaseModel):
|
||||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||||
|
@ -7,6 +7,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from pydantic import ConfigDict
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
@ -38,6 +39,8 @@ class ESRGANInvocation(BaseInvocation):
|
|||||||
default=400, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
|
default=400, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
models_path = context.services.configuration.models_path
|
models_path = context.services.configuration.models_path
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
@ -18,9 +18,9 @@ class BoardRecord(BaseModelExcludeNull):
|
|||||||
"""The created timestamp of the image."""
|
"""The created timestamp of the image."""
|
||||||
updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
|
updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
|
||||||
"""The updated timestamp of the image."""
|
"""The updated timestamp of the image."""
|
||||||
deleted_at: Union[datetime, str, None] = Field(description="The deleted timestamp of the board.")
|
deleted_at: Optional[Union[datetime, str]] = Field(default=None, description="The deleted timestamp of the board.")
|
||||||
"""The updated timestamp of the image."""
|
"""The updated timestamp of the image."""
|
||||||
cover_image_name: Optional[str] = Field(description="The name of the cover image of the board.")
|
cover_image_name: Optional[str] = Field(default=None, description="The name of the cover image of the board.")
|
||||||
"""The name of the cover image of the board."""
|
"""The name of the cover image of the board."""
|
||||||
|
|
||||||
|
|
||||||
@ -46,9 +46,9 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BoardChanges(BaseModel, extra=Extra.forbid):
|
class BoardChanges(BaseModel, extra="forbid"):
|
||||||
board_name: Optional[str] = Field(description="The board's new name.")
|
board_name: Optional[str] = Field(default=None, description="The board's new name.")
|
||||||
cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.")
|
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
|
||||||
|
|
||||||
|
|
||||||
class BoardRecordNotFoundException(Exception):
|
class BoardRecordNotFoundException(Exception):
|
||||||
|
@ -17,7 +17,7 @@ class BoardDTO(BoardRecord):
|
|||||||
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
|
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
|
||||||
"""Converts a board record to a board DTO."""
|
"""Converts a board record to a board DTO."""
|
||||||
return BoardDTO(
|
return BoardDTO(
|
||||||
**board_record.dict(exclude={"cover_image_name"}),
|
**board_record.model_dump(exclude={"cover_image_name"}),
|
||||||
cover_image_name=cover_image_name,
|
cover_image_name=cover_image_name,
|
||||||
image_count=image_count,
|
image_count=image_count,
|
||||||
)
|
)
|
||||||
|
@ -18,7 +18,7 @@ from pathlib import Path
|
|||||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||||
from pydantic import BaseSettings
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
|
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
|
||||||
|
|
||||||
@ -32,12 +32,14 @@ class InvokeAISettings(BaseSettings):
|
|||||||
initconf: ClassVar[Optional[DictConfig]] = None
|
initconf: ClassVar[Optional[DictConfig]] = None
|
||||||
argparse_groups: ClassVar[Dict] = {}
|
argparse_groups: ClassVar[Dict] = {}
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
|
||||||
|
|
||||||
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
||||||
parser = self.get_parser()
|
parser = self.get_parser()
|
||||||
opt, unknown_opts = parser.parse_known_args(argv)
|
opt, unknown_opts = parser.parse_known_args(argv)
|
||||||
if len(unknown_opts) > 0:
|
if len(unknown_opts) > 0:
|
||||||
print("Unknown args:", unknown_opts)
|
print("Unknown args:", unknown_opts)
|
||||||
for name in self.__fields__:
|
for name in self.model_fields:
|
||||||
if name not in self._excluded():
|
if name not in self._excluded():
|
||||||
value = getattr(opt, name)
|
value = getattr(opt, name)
|
||||||
if isinstance(value, ListConfig):
|
if isinstance(value, ListConfig):
|
||||||
@ -54,10 +56,12 @@ class InvokeAISettings(BaseSettings):
|
|||||||
cls = self.__class__
|
cls = self.__class__
|
||||||
type = get_args(get_type_hints(cls)["type"])[0]
|
type = get_args(get_type_hints(cls)["type"])[0]
|
||||||
field_dict = dict({type: dict()})
|
field_dict = dict({type: dict()})
|
||||||
for name, field in self.__fields__.items():
|
for name, field in self.model_fields.items():
|
||||||
if name in cls._excluded_from_yaml():
|
if name in cls._excluded_from_yaml():
|
||||||
continue
|
continue
|
||||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
category = (
|
||||||
|
field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized"
|
||||||
|
)
|
||||||
value = getattr(self, name)
|
value = getattr(self, name)
|
||||||
if category not in field_dict[type]:
|
if category not in field_dict[type]:
|
||||||
field_dict[type][category] = dict()
|
field_dict[type][category] = dict()
|
||||||
@ -73,7 +77,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
else:
|
else:
|
||||||
settings_stanza = "Uncategorized"
|
settings_stanza = "Uncategorized"
|
||||||
|
|
||||||
env_prefix = getattr(cls.Config, "env_prefix", None)
|
env_prefix = getattr(cls.model_config, "env_prefix", None)
|
||||||
env_prefix = env_prefix if env_prefix is not None else settings_stanza.upper()
|
env_prefix = env_prefix if env_prefix is not None else settings_stanza.upper()
|
||||||
|
|
||||||
initconf = (
|
initconf = (
|
||||||
@ -89,14 +93,18 @@ class InvokeAISettings(BaseSettings):
|
|||||||
for key, value in os.environ.items():
|
for key, value in os.environ.items():
|
||||||
upcase_environ[key.upper()] = value
|
upcase_environ[key.upper()] = value
|
||||||
|
|
||||||
fields = cls.__fields__
|
fields = cls.model_fields
|
||||||
cls.argparse_groups = {}
|
cls.argparse_groups = {}
|
||||||
|
|
||||||
for name, field in fields.items():
|
for name, field in fields.items():
|
||||||
if name not in cls._excluded():
|
if name not in cls._excluded():
|
||||||
current_default = field.default
|
current_default = field.default
|
||||||
|
|
||||||
category = field.field_info.extra.get("category", "Uncategorized")
|
category = (
|
||||||
|
field.json_schema_extra.get("category", "Uncategorized")
|
||||||
|
if field.json_schema_extra
|
||||||
|
else "Uncategorized"
|
||||||
|
)
|
||||||
env_name = env_prefix + "_" + name
|
env_name = env_prefix + "_" + name
|
||||||
if category in initconf and name in initconf.get(category):
|
if category in initconf and name in initconf.get(category):
|
||||||
field.default = initconf.get(category).get(name)
|
field.default = initconf.get(category).get(name)
|
||||||
@ -146,11 +154,6 @@ class InvokeAISettings(BaseSettings):
|
|||||||
"tiled_decode",
|
"tiled_decode",
|
||||||
]
|
]
|
||||||
|
|
||||||
class Config:
|
|
||||||
env_file_encoding = "utf-8"
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
case_sensitive = True
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||||
field_type = get_type_hints(cls).get(name)
|
field_type = get_type_hints(cls).get(name)
|
||||||
@ -161,7 +164,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
if field.default_factory is None
|
if field.default_factory is None
|
||||||
else field.default_factory()
|
else field.default_factory()
|
||||||
)
|
)
|
||||||
if category := field.field_info.extra.get("category"):
|
if category := (field.json_schema_extra.get("category", None) if field.json_schema_extra else None):
|
||||||
if category not in cls.argparse_groups:
|
if category not in cls.argparse_groups:
|
||||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||||
argparse_group = cls.argparse_groups[category]
|
argparse_group = cls.argparse_groups[category]
|
||||||
@ -169,7 +172,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
argparse_group = command_parser
|
argparse_group = command_parser
|
||||||
|
|
||||||
if get_origin(field_type) == Literal:
|
if get_origin(field_type) == Literal:
|
||||||
allowed_values = get_args(field.type_)
|
allowed_values = get_args(field.annotation)
|
||||||
allowed_types = set()
|
allowed_types = set()
|
||||||
for val in allowed_values:
|
for val in allowed_values:
|
||||||
allowed_types.add(type(val))
|
allowed_types.add(type(val))
|
||||||
@ -182,7 +185,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
type=field_type,
|
type=field_type,
|
||||||
default=default,
|
default=default,
|
||||||
choices=allowed_values,
|
choices=allowed_values,
|
||||||
help=field.field_info.description,
|
help=field.description,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif get_origin(field_type) == Union:
|
elif get_origin(field_type) == Union:
|
||||||
@ -191,7 +194,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
dest=name,
|
dest=name,
|
||||||
type=int_or_float_or_str,
|
type=int_or_float_or_str,
|
||||||
default=default,
|
default=default,
|
||||||
help=field.field_info.description,
|
help=field.description,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif get_origin(field_type) == list:
|
elif get_origin(field_type) == list:
|
||||||
@ -199,17 +202,17 @@ class InvokeAISettings(BaseSettings):
|
|||||||
f"--{name}",
|
f"--{name}",
|
||||||
dest=name,
|
dest=name,
|
||||||
nargs="*",
|
nargs="*",
|
||||||
type=field.type_,
|
type=field.annotation,
|
||||||
default=default,
|
default=default,
|
||||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
action=argparse.BooleanOptionalAction if field.annotation == bool else "store",
|
||||||
help=field.field_info.description,
|
help=field.description,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
argparse_group.add_argument(
|
argparse_group.add_argument(
|
||||||
f"--{name}",
|
f"--{name}",
|
||||||
dest=name,
|
dest=name,
|
||||||
type=field.type_,
|
type=field.annotation,
|
||||||
default=default,
|
default=default,
|
||||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
action=argparse.BooleanOptionalAction if field.annotation == bool else "store",
|
||||||
help=field.field_info.description,
|
help=field.description,
|
||||||
)
|
)
|
||||||
|
@ -144,8 +144,8 @@ which is set to the desired top-level name. For example, to create a
|
|||||||
|
|
||||||
class InvokeBatch(InvokeAISettings):
|
class InvokeBatch(InvokeAISettings):
|
||||||
type: Literal["InvokeBatch"] = "InvokeBatch"
|
type: Literal["InvokeBatch"] = "InvokeBatch"
|
||||||
node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources')
|
node_count : int = Field(default=1, description="Number of nodes to run on", json_schema_extra=dict(category='Resources'))
|
||||||
cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", category='Resources')
|
cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", json_schema_extra=dict(category='Resources'))
|
||||||
|
|
||||||
This will now read and write from the "InvokeBatch" section of the
|
This will now read and write from the "InvokeBatch" section of the
|
||||||
config file, look for environment variables named INVOKEBATCH_*, and
|
config file, look for environment variables named INVOKEBATCH_*, and
|
||||||
@ -175,7 +175,8 @@ from pathlib import Path
|
|||||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
|
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pydantic import Field, parse_obj_as
|
from pydantic import Field, TypeAdapter
|
||||||
|
from pydantic_settings import SettingsConfigDict
|
||||||
|
|
||||||
from .config_base import InvokeAISettings
|
from .config_base import InvokeAISettings
|
||||||
|
|
||||||
@ -185,6 +186,21 @@ LEGACY_INIT_FILE = Path("invokeai.init")
|
|||||||
DEFAULT_MAX_VRAM = 0.5
|
DEFAULT_MAX_VRAM = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class Categories(object):
|
||||||
|
WebServer = dict(category="Web Server")
|
||||||
|
Features = dict(category="Features")
|
||||||
|
Paths = dict(category="Paths")
|
||||||
|
Logging = dict(category="Logging")
|
||||||
|
Development = dict(category="Development")
|
||||||
|
Other = dict(category="Other")
|
||||||
|
ModelCache = dict(category="Model Cache")
|
||||||
|
Device = dict(category="Device")
|
||||||
|
Generation = dict(category="Generation")
|
||||||
|
Queue = dict(category="Queue")
|
||||||
|
Nodes = dict(category="Nodes")
|
||||||
|
MemoryPerformance = dict(category="Memory/Performance")
|
||||||
|
|
||||||
|
|
||||||
class InvokeAIAppConfig(InvokeAISettings):
|
class InvokeAIAppConfig(InvokeAISettings):
|
||||||
"""
|
"""
|
||||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||||
@ -201,86 +217,88 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
type: Literal["InvokeAI"] = "InvokeAI"
|
type: Literal["InvokeAI"] = "InvokeAI"
|
||||||
|
|
||||||
# WEB
|
# WEB
|
||||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
host : str = Field(default="127.0.0.1", description="IP address to bind to", json_schema_extra=Categories.WebServer)
|
||||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
port : int = Field(default=9090, description="Port to bind to", json_schema_extra=Categories.WebServer)
|
||||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
|
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", json_schema_extra=Categories.WebServer)
|
||||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Web Server')
|
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
|
||||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
|
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
|
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||||
|
|
||||||
# FEATURES
|
# FEATURES
|
||||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
|
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)
|
||||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", json_schema_extra=Categories.Features)
|
||||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", json_schema_extra=Categories.Features)
|
||||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", json_schema_extra=Categories.Features)
|
||||||
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
|
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', json_schema_extra=Categories.Features)
|
||||||
|
|
||||||
# PATHS
|
# PATHS
|
||||||
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
|
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
|
||||||
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
|
autoimport_dir : Optional[Path] = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||||
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||||
embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
|
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||||
controlnet_dir : Path = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths')
|
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
conf_path : Optional[Path] = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
|
||||||
models_dir : Path = Field(default='models', description='Path to the models directory', category='Paths')
|
models_dir : Optional[Path] = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
|
||||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
legacy_conf_dir : Optional[Path] = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
|
||||||
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
db_dir : Optional[Path] = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
|
||||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
outdir : Optional[Path] = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
|
||||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
|
||||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
|
||||||
|
|
||||||
# LOGGING
|
# LOGGING
|
||||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', json_schema_extra=Categories.Logging)
|
||||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||||
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', json_schema_extra=Categories.Logging)
|
||||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", json_schema_extra=Categories.Logging)
|
||||||
log_sql : bool = Field(default=False, description="Log SQL queries", category="Logging")
|
log_sql : bool = Field(default=False, description="Log SQL queries", json_schema_extra=Categories.Logging)
|
||||||
|
|
||||||
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", category="Development")
|
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", json_schema_extra=Categories.Development)
|
||||||
|
|
||||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
|
||||||
|
|
||||||
# CACHE
|
# CACHE
|
||||||
ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", category="Model Cache", )
|
ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||||
vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", category="Model Cache", )
|
vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
|
||||||
|
|
||||||
# DEVICE
|
# DEVICE
|
||||||
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", category="Device", )
|
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
|
||||||
precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", category="Device", )
|
precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
|
||||||
|
|
||||||
# GENERATION
|
# GENERATION
|
||||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", )
|
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", json_schema_extra=Categories.Generation)
|
||||||
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", )
|
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", json_schema_extra=Categories.Generation)
|
||||||
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', json_schema_extra=Categories.Generation)
|
||||||
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Generation)
|
||||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
|
||||||
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", category="Generation", )
|
|
||||||
|
|
||||||
# QUEUE
|
# QUEUE
|
||||||
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
|
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", json_schema_extra=Categories.Queue)
|
||||||
|
|
||||||
# NODES
|
# NODES
|
||||||
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")
|
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", json_schema_extra=Categories.Nodes)
|
||||||
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes")
|
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
|
||||||
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", category="Nodes", )
|
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
|
||||||
|
|
||||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
|
||||||
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
|
||||||
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
|
||||||
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.MemoryPerformance)
|
||||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.MemoryPerformance)
|
||||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.MemoryPerformance)
|
||||||
|
|
||||||
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
|
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config:
|
model_config = SettingsConfigDict(validate_assignment=True, env_prefix="INVOKEAI")
|
||||||
validate_assignment = True
|
|
||||||
env_prefix = "INVOKEAI"
|
|
||||||
|
|
||||||
def parse_args(self, argv: Optional[list[str]] = None, conf: Optional[DictConfig] = None, clobber=False):
|
def parse_args(
|
||||||
|
self,
|
||||||
|
argv: Optional[list[str]] = None,
|
||||||
|
conf: Optional[DictConfig] = None,
|
||||||
|
clobber=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Update settings with contents of init file, environment, and
|
Update settings with contents of init file, environment, and
|
||||||
command-line settings.
|
command-line settings.
|
||||||
@ -308,7 +326,11 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
if self.singleton_init and not clobber:
|
if self.singleton_init and not clobber:
|
||||||
hints = get_type_hints(self.__class__)
|
hints = get_type_hints(self.__class__)
|
||||||
for k in self.singleton_init:
|
for k in self.singleton_init:
|
||||||
setattr(self, k, parse_obj_as(hints[k], self.singleton_init[k]))
|
setattr(
|
||||||
|
self,
|
||||||
|
k,
|
||||||
|
TypeAdapter(hints[k]).validate_python(self.singleton_init[k]),
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
|
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from invokeai.app.invocations.model import ModelInfo
|
|
||||||
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
|
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
|
||||||
from invokeai.app.services.session_queue.session_queue_common import (
|
from invokeai.app.services.session_queue.session_queue_common import (
|
||||||
BatchStatus,
|
BatchStatus,
|
||||||
@ -11,6 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
SessionQueueStatus,
|
SessionQueueStatus,
|
||||||
)
|
)
|
||||||
from invokeai.app.util.misc import get_timestamp
|
from invokeai.app.util.misc import get_timestamp
|
||||||
|
from invokeai.backend.model_management.model_manager import ModelInfo
|
||||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||||
|
|
||||||
|
|
||||||
@ -55,7 +55,7 @@ class EventServiceBase:
|
|||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
node_id=node.get("id"),
|
node_id=node.get("id"),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
progress_image=progress_image.model_dump() if progress_image is not None else None,
|
||||||
step=step,
|
step=step,
|
||||||
order=order,
|
order=order,
|
||||||
total_steps=total_steps,
|
total_steps=total_steps,
|
||||||
@ -291,8 +291,8 @@ class EventServiceBase:
|
|||||||
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||||
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||||
),
|
),
|
||||||
batch_status=batch_status.dict(),
|
batch_status=batch_status.model_dump(),
|
||||||
queue_status=queue_status.dict(),
|
queue_status=queue_status.model_dump(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
@ -13,7 +14,7 @@ class ImageFileStorageBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
|
||||||
"""Gets the internal path to an image or thumbnail."""
|
"""Gets the internal path to an image or thumbnail."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -34,8 +34,8 @@ class ImageRecordStorageBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_many(
|
def get_many(
|
||||||
self,
|
self,
|
||||||
offset: Optional[int] = None,
|
offset: int = 0,
|
||||||
limit: Optional[int] = None,
|
limit: int = 10,
|
||||||
image_origin: Optional[ResourceOrigin] = None,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
categories: Optional[list[ImageCategory]] = None,
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
is_intermediate: Optional[bool] = None,
|
is_intermediate: Optional[bool] = None,
|
||||||
@ -69,11 +69,11 @@ class ImageRecordStorageBase(ABC):
|
|||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
width: int,
|
width: int,
|
||||||
height: int,
|
height: int,
|
||||||
session_id: Optional[str],
|
is_intermediate: Optional[bool] = False,
|
||||||
node_id: Optional[str],
|
starred: Optional[bool] = False,
|
||||||
metadata: Optional[dict],
|
session_id: Optional[str] = None,
|
||||||
is_intermediate: bool = False,
|
node_id: Optional[str] = None,
|
||||||
starred: bool = False,
|
metadata: Optional[dict] = None,
|
||||||
) -> datetime:
|
) -> datetime:
|
||||||
"""Saves an image record."""
|
"""Saves an image record."""
|
||||||
pass
|
pass
|
||||||
|
@ -3,7 +3,7 @@ import datetime
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import Extra, Field, StrictBool, StrictStr
|
from pydantic import Field, StrictBool, StrictStr
|
||||||
|
|
||||||
from invokeai.app.util.metaenum import MetaEnum
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
@ -129,7 +129,9 @@ class ImageRecord(BaseModelExcludeNull):
|
|||||||
"""The created timestamp of the image."""
|
"""The created timestamp of the image."""
|
||||||
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the image.")
|
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the image.")
|
||||||
"""The updated timestamp of the image."""
|
"""The updated timestamp of the image."""
|
||||||
deleted_at: Union[datetime.datetime, str, None] = Field(description="The deleted timestamp of the image.")
|
deleted_at: Optional[Union[datetime.datetime, str]] = Field(
|
||||||
|
default=None, description="The deleted timestamp of the image."
|
||||||
|
)
|
||||||
"""The deleted timestamp of the image."""
|
"""The deleted timestamp of the image."""
|
||||||
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
||||||
"""Whether this is an intermediate image."""
|
"""Whether this is an intermediate image."""
|
||||||
@ -147,7 +149,7 @@ class ImageRecord(BaseModelExcludeNull):
|
|||||||
"""Whether this image is starred."""
|
"""Whether this image is starred."""
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
class ImageRecordChanges(BaseModelExcludeNull, extra="allow"):
|
||||||
"""A set of changes to apply to an image record.
|
"""A set of changes to apply to an image record.
|
||||||
|
|
||||||
Only limited changes are valid:
|
Only limited changes are valid:
|
||||||
|
@ -2,7 +2,7 @@ import json
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, cast
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
@ -117,7 +117,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(self, image_name: str) -> Optional[ImageRecord]:
|
def get(self, image_name: str) -> ImageRecord:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
|
||||||
@ -223,8 +223,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
|
|
||||||
def get_many(
|
def get_many(
|
||||||
self,
|
self,
|
||||||
offset: Optional[int] = None,
|
offset: int = 0,
|
||||||
limit: Optional[int] = None,
|
limit: int = 10,
|
||||||
image_origin: Optional[ResourceOrigin] = None,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
categories: Optional[list[ImageCategory]] = None,
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
is_intermediate: Optional[bool] = None,
|
is_intermediate: Optional[bool] = None,
|
||||||
@ -249,7 +249,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
query_conditions = ""
|
query_conditions = ""
|
||||||
query_params = []
|
query_params: list[Union[int, str, bool]] = []
|
||||||
|
|
||||||
if image_origin is not None:
|
if image_origin is not None:
|
||||||
query_conditions += """--sql
|
query_conditions += """--sql
|
||||||
@ -387,13 +387,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
image_name: str,
|
image_name: str,
|
||||||
image_origin: ResourceOrigin,
|
image_origin: ResourceOrigin,
|
||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
session_id: Optional[str],
|
|
||||||
width: int,
|
width: int,
|
||||||
height: int,
|
height: int,
|
||||||
node_id: Optional[str],
|
is_intermediate: Optional[bool] = False,
|
||||||
metadata: Optional[dict],
|
starred: Optional[bool] = False,
|
||||||
is_intermediate: bool = False,
|
session_id: Optional[str] = None,
|
||||||
starred: bool = False,
|
node_id: Optional[str] = None,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
) -> datetime:
|
) -> datetime:
|
||||||
try:
|
try:
|
||||||
metadata_json = None if metadata is None else json.dumps(metadata)
|
metadata_json = None if metadata is None else json.dumps(metadata)
|
||||||
|
@ -49,7 +49,7 @@ class ImageServiceABC(ABC):
|
|||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
board_id: Optional[str] = None,
|
board_id: Optional[str] = None,
|
||||||
is_intermediate: bool = False,
|
is_intermediate: Optional[bool] = False,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
workflow: Optional[str] = None,
|
workflow: Optional[str] = None,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
|
@ -20,7 +20,9 @@ class ImageUrlsDTO(BaseModelExcludeNull):
|
|||||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||||
"""Deserialized image record, enriched for the frontend."""
|
"""Deserialized image record, enriched for the frontend."""
|
||||||
|
|
||||||
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
|
board_id: Optional[str] = Field(
|
||||||
|
default=None, description="The id of the board the image belongs to, if one exists."
|
||||||
|
)
|
||||||
"""The id of the board the image belongs to, if one exists."""
|
"""The id of the board the image belongs to, if one exists."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
@ -34,7 +36,7 @@ def image_record_to_dto(
|
|||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Converts an image record to an image DTO."""
|
"""Converts an image record to an image DTO."""
|
||||||
return ImageDTO(
|
return ImageDTO(
|
||||||
**image_record.dict(),
|
**image_record.model_dump(),
|
||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
thumbnail_url=thumbnail_url,
|
thumbnail_url=thumbnail_url,
|
||||||
board_id=board_id,
|
board_id=board_id,
|
||||||
|
@ -41,7 +41,7 @@ class ImageService(ImageServiceABC):
|
|||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
board_id: Optional[str] = None,
|
board_id: Optional[str] = None,
|
||||||
is_intermediate: bool = False,
|
is_intermediate: Optional[bool] = False,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
workflow: Optional[str] = None,
|
workflow: Optional[str] = None,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
@ -146,7 +146,7 @@ class ImageService(ImageServiceABC):
|
|||||||
self.__invoker.services.logger.error("Problem getting image DTO")
|
self.__invoker.services.logger.error("Problem getting image DTO")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
def get_metadata(self, image_name: str) -> ImageMetadata:
|
||||||
try:
|
try:
|
||||||
image_record = self.__invoker.services.image_records.get(image_name)
|
image_record = self.__invoker.services.image_records.get(image_name)
|
||||||
metadata = self.__invoker.services.image_records.get_metadata(image_name)
|
metadata = self.__invoker.services.image_records.get_metadata(image_name)
|
||||||
@ -174,7 +174,7 @@ class ImageService(ImageServiceABC):
|
|||||||
|
|
||||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
try:
|
try:
|
||||||
return self.__invoker.services.image_files.get_path(image_name, thumbnail)
|
return str(self.__invoker.services.image_files.get_path(image_name, thumbnail))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.__invoker.services.logger.error("Problem getting image path")
|
self.__invoker.services.logger.error("Problem getting image path")
|
||||||
raise e
|
raise e
|
||||||
|
@ -58,7 +58,12 @@ class MemoryInvocationCache(InvocationCacheBase):
|
|||||||
# If the cache is full, we need to remove the least used
|
# If the cache is full, we need to remove the least used
|
||||||
number_to_delete = len(self._cache) + 1 - self._max_cache_size
|
number_to_delete = len(self._cache) + 1 - self._max_cache_size
|
||||||
self._delete_oldest_access(number_to_delete)
|
self._delete_oldest_access(number_to_delete)
|
||||||
self._cache[key] = CachedItem(invocation_output, invocation_output.json())
|
self._cache[key] = CachedItem(
|
||||||
|
invocation_output,
|
||||||
|
invocation_output.model_dump_json(
|
||||||
|
warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
||||||
number_to_delete = min(number_to_delete, len(self._cache))
|
number_to_delete = min(number_to_delete, len(self._cache))
|
||||||
@ -85,7 +90,7 @@ class MemoryInvocationCache(InvocationCacheBase):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_key(invocation: BaseInvocation) -> int:
|
def create_key(invocation: BaseInvocation) -> int:
|
||||||
return hash(invocation.json(exclude={"id"}))
|
return hash(invocation.model_dump_json(exclude={"id"}, warnings=False))
|
||||||
|
|
||||||
def disable(self) -> None:
|
def disable(self) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
@ -89,7 +89,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
queue_item_id=queue_item.session_queue_item_id,
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
queue_id=queue_item.session_queue_id,
|
queue_id=queue_item.session_queue_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.model_dump(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -127,9 +127,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
queue_item_id=queue_item.session_queue_item_id,
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
queue_id=queue_item.session_queue_id,
|
queue_id=queue_item.session_queue_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.model_dump(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
result=outputs.dict(),
|
result=outputs.model_dump(),
|
||||||
)
|
)
|
||||||
self.__invoker.services.performance_statistics.log_stats()
|
self.__invoker.services.performance_statistics.log_stats()
|
||||||
|
|
||||||
@ -157,7 +157,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
queue_item_id=queue_item.session_queue_item_id,
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
queue_id=queue_item.session_queue_id,
|
queue_id=queue_item.session_queue_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.model_dump(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=error,
|
error=error,
|
||||||
@ -187,7 +187,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
queue_item_id=queue_item.session_queue_item_id,
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
queue_id=queue_item.session_queue_id,
|
queue_id=queue_item.session_queue_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.model_dump(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=traceback.format_exc(),
|
error=traceback.format_exc(),
|
||||||
|
@ -72,7 +72,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
)
|
)
|
||||||
self.collector.update_invocation_stats(
|
self.collector.update_invocation_stats(
|
||||||
graph_id=self.graph_id,
|
graph_id=self.graph_id,
|
||||||
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
invocation_type=self.invocation.type, # type: ignore # `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
||||||
time_used=time.time() - self.start_time,
|
time_used=time.time() - self.start_time,
|
||||||
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
||||||
)
|
)
|
||||||
|
@ -2,7 +2,7 @@ import sqlite3
|
|||||||
import threading
|
import threading
|
||||||
from typing import Generic, Optional, TypeVar, get_args
|
from typing import Generic, Optional, TypeVar, get_args
|
||||||
|
|
||||||
from pydantic import BaseModel, parse_raw_as
|
from pydantic import BaseModel, TypeAdapter
|
||||||
|
|
||||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
@ -18,6 +18,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_id_field: str
|
_id_field: str
|
||||||
_lock: threading.RLock
|
_lock: threading.RLock
|
||||||
|
_adapter: Optional[TypeAdapter[T]]
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
|
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -27,6 +28,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
self._id_field = id_field # TODO: validate that T has this field
|
self._id_field = id_field # TODO: validate that T has this field
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
self._adapter: Optional[TypeAdapter[T]] = None
|
||||||
|
|
||||||
self._create_table()
|
self._create_table()
|
||||||
|
|
||||||
@ -45,16 +47,21 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def _parse_item(self, item: str) -> T:
|
def _parse_item(self, item: str) -> T:
|
||||||
# __orig_class__ is technically an implementation detail of the typing module, not a supported API
|
if self._adapter is None:
|
||||||
item_type = get_args(self.__orig_class__)[0] # type: ignore
|
"""
|
||||||
return parse_raw_as(item_type, item)
|
We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
|
||||||
|
we can create it when it is first needed instead.
|
||||||
|
__orig_class__ is technically an implementation detail of the typing module, not a supported API
|
||||||
|
"""
|
||||||
|
self._adapter = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
|
||||||
|
return self._adapter.validate_json(item)
|
||||||
|
|
||||||
def set(self, item: T):
|
def set(self, item: T):
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||||
(item.json(),),
|
(item.model_dump_json(warnings=False, exclude_none=True),),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
finally:
|
finally:
|
||||||
|
@ -231,7 +231,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
def merge_models(
|
def merge_models(
|
||||||
self,
|
self,
|
||||||
model_names: List[str] = Field(
|
model_names: List[str] = Field(
|
||||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
default=None, min_length=2, max_length=3, description="List of model names to merge"
|
||||||
),
|
),
|
||||||
base_model: Union[BaseModelType, str] = Field(
|
base_model: Union[BaseModelType, str] = Field(
|
||||||
default=None, description="Base model shared by all models to be merged"
|
default=None, description="Base model shared by all models to be merged"
|
||||||
|
@ -327,7 +327,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
def merge_models(
|
def merge_models(
|
||||||
self,
|
self,
|
||||||
model_names: List[str] = Field(
|
model_names: List[str] = Field(
|
||||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
default=None, min_length=2, max_length=3, description="List of model names to merge"
|
||||||
),
|
),
|
||||||
base_model: Union[BaseModelType, str] = Field(
|
base_model: Union[BaseModelType, str] = Field(
|
||||||
default=None, description="Base model shared by all models to be merged"
|
default=None, description="Base model shared by all models to be merged"
|
||||||
|
@ -3,8 +3,8 @@ import json
|
|||||||
from itertools import chain, product
|
from itertools import chain, product
|
||||||
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
|
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, StrictStr, parse_raw_as, root_validator, validator
|
from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator
|
||||||
from pydantic.json import pydantic_encoder
|
from pydantic_core import to_jsonable_python
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
|
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
|
||||||
@ -17,7 +17,7 @@ class BatchZippedLengthError(ValueError):
|
|||||||
"""Raise when a batch has items of different lengths."""
|
"""Raise when a batch has items of different lengths."""
|
||||||
|
|
||||||
|
|
||||||
class BatchItemsTypeError(TypeError):
|
class BatchItemsTypeError(ValueError): # this cannot be a TypeError in pydantic v2
|
||||||
"""Raise when a batch has items of different types."""
|
"""Raise when a batch has items of different types."""
|
||||||
|
|
||||||
|
|
||||||
@ -70,7 +70,7 @@ class Batch(BaseModel):
|
|||||||
default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
|
default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
|
||||||
)
|
)
|
||||||
|
|
||||||
@validator("data")
|
@field_validator("data")
|
||||||
def validate_lengths(cls, v: Optional[BatchDataCollection]):
|
def validate_lengths(cls, v: Optional[BatchDataCollection]):
|
||||||
if v is None:
|
if v is None:
|
||||||
return v
|
return v
|
||||||
@ -81,7 +81,7 @@ class Batch(BaseModel):
|
|||||||
raise BatchZippedLengthError("Zipped batch items must all have the same length")
|
raise BatchZippedLengthError("Zipped batch items must all have the same length")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@validator("data")
|
@field_validator("data")
|
||||||
def validate_types(cls, v: Optional[BatchDataCollection]):
|
def validate_types(cls, v: Optional[BatchDataCollection]):
|
||||||
if v is None:
|
if v is None:
|
||||||
return v
|
return v
|
||||||
@ -94,7 +94,7 @@ class Batch(BaseModel):
|
|||||||
raise BatchItemsTypeError("All items in a batch must have the same type")
|
raise BatchItemsTypeError("All items in a batch must have the same type")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@validator("data")
|
@field_validator("data")
|
||||||
def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
|
def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
|
||||||
if v is None:
|
if v is None:
|
||||||
return v
|
return v
|
||||||
@ -107,34 +107,35 @@ class Batch(BaseModel):
|
|||||||
paths.add(pair)
|
paths.add(pair)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@root_validator(skip_on_failure=True)
|
@model_validator(mode="after")
|
||||||
def validate_batch_nodes_and_edges(cls, values):
|
def validate_batch_nodes_and_edges(cls, values):
|
||||||
batch_data_collection = cast(Optional[BatchDataCollection], values["data"])
|
batch_data_collection = cast(Optional[BatchDataCollection], values.data)
|
||||||
if batch_data_collection is None:
|
if batch_data_collection is None:
|
||||||
return values
|
return values
|
||||||
graph = cast(Graph, values["graph"])
|
graph = cast(Graph, values.graph)
|
||||||
for batch_data_list in batch_data_collection:
|
for batch_data_list in batch_data_collection:
|
||||||
for batch_data in batch_data_list:
|
for batch_data in batch_data_list:
|
||||||
try:
|
try:
|
||||||
node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
|
node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
|
||||||
except NodeNotFoundError:
|
except NodeNotFoundError:
|
||||||
raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
|
raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
|
||||||
if batch_data.field_name not in node.__fields__:
|
if batch_data.field_name not in node.model_fields:
|
||||||
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
|
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@validator("graph")
|
@field_validator("graph")
|
||||||
def validate_graph(cls, v: Graph):
|
def validate_graph(cls, v: Graph):
|
||||||
v.validate_self()
|
v.validate_self()
|
||||||
return v
|
return v
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(
|
||||||
schema_extra = {
|
json_schema_extra=dict(
|
||||||
"required": [
|
required=[
|
||||||
"graph",
|
"graph",
|
||||||
"runs",
|
"runs",
|
||||||
]
|
]
|
||||||
}
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# endregion Batch
|
# endregion Batch
|
||||||
@ -146,15 +147,21 @@ DEFAULT_QUEUE_ID = "default"
|
|||||||
|
|
||||||
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
|
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
|
||||||
|
|
||||||
|
adapter_NodeFieldValue = TypeAdapter(list[NodeFieldValue])
|
||||||
|
|
||||||
|
|
||||||
def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
|
def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
|
||||||
field_values_raw = queue_item_dict.get("field_values", None)
|
field_values_raw = queue_item_dict.get("field_values", None)
|
||||||
return parse_raw_as(list[NodeFieldValue], field_values_raw) if field_values_raw is not None else None
|
return adapter_NodeFieldValue.validate_json(field_values_raw) if field_values_raw is not None else None
|
||||||
|
|
||||||
|
|
||||||
|
adapter_GraphExecutionState = TypeAdapter(GraphExecutionState)
|
||||||
|
|
||||||
|
|
||||||
def get_session(queue_item_dict: dict) -> GraphExecutionState:
|
def get_session(queue_item_dict: dict) -> GraphExecutionState:
|
||||||
session_raw = queue_item_dict.get("session", "{}")
|
session_raw = queue_item_dict.get("session", "{}")
|
||||||
return parse_raw_as(GraphExecutionState, session_raw)
|
session = adapter_GraphExecutionState.validate_json(session_raw, strict=False)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
class SessionQueueItemWithoutGraph(BaseModel):
|
class SessionQueueItemWithoutGraph(BaseModel):
|
||||||
@ -178,14 +185,14 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
|
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
|
||||||
# must parse these manually
|
# must parse these manually
|
||||||
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
|
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
|
||||||
return SessionQueueItemDTO(**queue_item_dict)
|
return SessionQueueItemDTO(**queue_item_dict)
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(
|
||||||
schema_extra = {
|
json_schema_extra=dict(
|
||||||
"required": [
|
required=[
|
||||||
"item_id",
|
"item_id",
|
||||||
"status",
|
"status",
|
||||||
"batch_id",
|
"batch_id",
|
||||||
@ -196,7 +203,8 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
|||||||
"created_at",
|
"created_at",
|
||||||
"updated_at",
|
"updated_at",
|
||||||
]
|
]
|
||||||
}
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
|
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
|
||||||
@ -207,15 +215,15 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
|
|||||||
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
|
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
|
def queue_item_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
|
||||||
# must parse these manually
|
# must parse these manually
|
||||||
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
|
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
|
||||||
queue_item_dict["session"] = get_session(queue_item_dict)
|
queue_item_dict["session"] = get_session(queue_item_dict)
|
||||||
return SessionQueueItem(**queue_item_dict)
|
return SessionQueueItem(**queue_item_dict)
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(
|
||||||
schema_extra = {
|
json_schema_extra=dict(
|
||||||
"required": [
|
required=[
|
||||||
"item_id",
|
"item_id",
|
||||||
"status",
|
"status",
|
||||||
"batch_id",
|
"batch_id",
|
||||||
@ -227,7 +235,8 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
|
|||||||
"created_at",
|
"created_at",
|
||||||
"updated_at",
|
"updated_at",
|
||||||
]
|
]
|
||||||
}
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# endregion Queue Items
|
# endregion Queue Items
|
||||||
@ -321,7 +330,7 @@ def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) ->
|
|||||||
"""
|
"""
|
||||||
Populates the given graph with the given batch data items.
|
Populates the given graph with the given batch data items.
|
||||||
"""
|
"""
|
||||||
graph_clone = graph.copy(deep=True)
|
graph_clone = graph.model_copy(deep=True)
|
||||||
for item in node_field_values:
|
for item in node_field_values:
|
||||||
node = graph_clone.get_node(item.node_path)
|
node = graph_clone.get_node(item.node_path)
|
||||||
if node is None:
|
if node is None:
|
||||||
@ -354,7 +363,7 @@ def create_session_nfv_tuples(
|
|||||||
for item in batch_datum.items
|
for item in batch_datum.items
|
||||||
]
|
]
|
||||||
node_field_values_to_zip.append(node_field_values)
|
node_field_values_to_zip.append(node_field_values)
|
||||||
data.append(list(zip(*node_field_values_to_zip)))
|
data.append(list(zip(*node_field_values_to_zip))) # type: ignore [arg-type]
|
||||||
|
|
||||||
# create generator to yield session,nfv tuples
|
# create generator to yield session,nfv tuples
|
||||||
count = 0
|
count = 0
|
||||||
@ -409,11 +418,11 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
|
|||||||
values_to_insert.append(
|
values_to_insert.append(
|
||||||
SessionQueueValueToInsert(
|
SessionQueueValueToInsert(
|
||||||
queue_id, # queue_id
|
queue_id, # queue_id
|
||||||
session.json(), # session (json)
|
session.model_dump_json(warnings=False, exclude_none=True), # session (json)
|
||||||
session.id, # session_id
|
session.id, # session_id
|
||||||
batch.batch_id, # batch_id
|
batch.batch_id, # batch_id
|
||||||
# must use pydantic_encoder bc field_values is a list of models
|
# must use pydantic_encoder bc field_values is a list of models
|
||||||
json.dumps(field_values, default=pydantic_encoder) if field_values else None, # field_values (json)
|
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
|
||||||
priority, # priority
|
priority, # priority
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -421,3 +430,6 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
|
|||||||
|
|
||||||
|
|
||||||
# endregion Util
|
# endregion Util
|
||||||
|
|
||||||
|
Batch.model_rebuild(force=True)
|
||||||
|
SessionQueueItem.model_rebuild(force=True)
|
||||||
|
@ -277,8 +277,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
if result is None:
|
if result is None:
|
||||||
raise SessionQueueItemNotFoundError(f"No queue item with batch id {enqueue_result.batch.batch_id}")
|
raise SessionQueueItemNotFoundError(f"No queue item with batch id {enqueue_result.batch.batch_id}")
|
||||||
return EnqueueGraphResult(
|
return EnqueueGraphResult(
|
||||||
**enqueue_result.dict(),
|
**enqueue_result.model_dump(),
|
||||||
queue_item=SessionQueueItemDTO.from_dict(dict(result)),
|
queue_item=SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)),
|
||||||
)
|
)
|
||||||
|
|
||||||
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||||
@ -351,7 +351,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
self.__lock.release()
|
self.__lock.release()
|
||||||
if result is None:
|
if result is None:
|
||||||
return None
|
return None
|
||||||
queue_item = SessionQueueItem.from_dict(dict(result))
|
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
|
||||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
|
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
|
||||||
return queue_item
|
return queue_item
|
||||||
|
|
||||||
@ -380,7 +380,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
self.__lock.release()
|
self.__lock.release()
|
||||||
if result is None:
|
if result is None:
|
||||||
return None
|
return None
|
||||||
return SessionQueueItem.from_dict(dict(result))
|
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||||
|
|
||||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||||
try:
|
try:
|
||||||
@ -404,7 +404,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
self.__lock.release()
|
self.__lock.release()
|
||||||
if result is None:
|
if result is None:
|
||||||
return None
|
return None
|
||||||
return SessionQueueItem.from_dict(dict(result))
|
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||||
|
|
||||||
def _set_queue_item_status(
|
def _set_queue_item_status(
|
||||||
self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
|
self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
|
||||||
@ -564,7 +564,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
queue_item = self.get_queue_item(item_id)
|
queue_item = self.get_queue_item(item_id)
|
||||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||||
status = "failed" if error is not None else "canceled"
|
status = "failed" if error is not None else "canceled"
|
||||||
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error)
|
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here
|
||||||
self.__invoker.services.queue.cancel(queue_item.session_id)
|
self.__invoker.services.queue.cancel(queue_item.session_id)
|
||||||
self.__invoker.services.events.emit_session_canceled(
|
self.__invoker.services.events.emit_session_canceled(
|
||||||
queue_item_id=queue_item.item_id,
|
queue_item_id=queue_item.item_id,
|
||||||
@ -699,7 +699,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
self.__lock.release()
|
self.__lock.release()
|
||||||
if result is None:
|
if result is None:
|
||||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||||
return SessionQueueItem.from_dict(dict(result))
|
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||||
|
|
||||||
def list_queue_items(
|
def list_queue_items(
|
||||||
self,
|
self,
|
||||||
@ -751,7 +751,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
params.append(limit + 1)
|
params.append(limit + 1)
|
||||||
self.__cursor.execute(query, params)
|
self.__cursor.execute(query, params)
|
||||||
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||||
items = [SessionQueueItemDTO.from_dict(dict(result)) for result in results]
|
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
|
||||||
has_more = False
|
has_more = False
|
||||||
if len(items) > limit:
|
if len(items) > limit:
|
||||||
# remove the extra item
|
# remove the extra item
|
||||||
|
@ -80,10 +80,10 @@ def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[Li
|
|||||||
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
||||||
graphs: list[LibraryGraph] = list()
|
graphs: list[LibraryGraph] = list()
|
||||||
|
|
||||||
# text_to_image = graph_library.get(default_text_to_image_graph_id)
|
text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||||
|
|
||||||
# # TODO: Check if the graph is the same as the default one, and if not, update it
|
# TODO: Check if the graph is the same as the default one, and if not, update it
|
||||||
# #if text_to_image is None:
|
# if text_to_image is None:
|
||||||
text_to_image = create_text_to_image()
|
text_to_image = create_text_to_image()
|
||||||
graph_library.set(text_to_image)
|
graph_library.set(text_to_image)
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import itertools
|
|||||||
from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
|
from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pydantic import BaseModel, root_validator, validator
|
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
# Importing * is bad karma but needed here for node detection
|
# Importing * is bad karma but needed here for node detection
|
||||||
@ -235,7 +235,8 @@ class CollectInvocationOutput(BaseInvocationOutput):
|
|||||||
class CollectInvocation(BaseInvocation):
|
class CollectInvocation(BaseInvocation):
|
||||||
"""Collects values into a collection"""
|
"""Collects values into a collection"""
|
||||||
|
|
||||||
item: Any = InputField(
|
item: Optional[Any] = InputField(
|
||||||
|
default=None,
|
||||||
description="The item to collect (all inputs must be of the same type)",
|
description="The item to collect (all inputs must be of the same type)",
|
||||||
ui_type=UIType.CollectionItem,
|
ui_type=UIType.CollectionItem,
|
||||||
title="Collection Item",
|
title="Collection Item",
|
||||||
@ -250,8 +251,8 @@ class CollectInvocation(BaseInvocation):
|
|||||||
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
||||||
|
|
||||||
|
|
||||||
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
InvocationsUnion: Any = BaseInvocation.get_invocations_union()
|
||||||
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
|
InvocationOutputsUnion: Any = BaseInvocationOutput.get_outputs_union()
|
||||||
|
|
||||||
|
|
||||||
class Graph(BaseModel):
|
class Graph(BaseModel):
|
||||||
@ -378,13 +379,13 @@ class Graph(BaseModel):
|
|||||||
raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph")
|
raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph")
|
||||||
|
|
||||||
# output fields are not on the node object directly, they are on the output type
|
# output fields are not on the node object directly, they are on the output type
|
||||||
if edge.source.field not in source_node.get_output_type().__fields__:
|
if edge.source.field not in source_node.get_output_type().model_fields:
|
||||||
raise NodeFieldNotFoundError(
|
raise NodeFieldNotFoundError(
|
||||||
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
|
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# input fields are on the node
|
# input fields are on the node
|
||||||
if edge.destination.field not in destination_node.__fields__:
|
if edge.destination.field not in destination_node.model_fields:
|
||||||
raise NodeFieldNotFoundError(
|
raise NodeFieldNotFoundError(
|
||||||
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
|
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
|
||||||
)
|
)
|
||||||
@ -395,24 +396,24 @@ class Graph(BaseModel):
|
|||||||
raise CyclicalGraphError("Graph contains cycles")
|
raise CyclicalGraphError("Graph contains cycles")
|
||||||
|
|
||||||
# Validate all edge connections are valid
|
# Validate all edge connections are valid
|
||||||
for e in self.edges:
|
for edge in self.edges:
|
||||||
if not are_connections_compatible(
|
if not are_connections_compatible(
|
||||||
self.get_node(e.source.node_id),
|
self.get_node(edge.source.node_id),
|
||||||
e.source.field,
|
edge.source.field,
|
||||||
self.get_node(e.destination.node_id),
|
self.get_node(edge.destination.node_id),
|
||||||
e.destination.field,
|
edge.destination.field,
|
||||||
):
|
):
|
||||||
raise InvalidEdgeError(
|
raise InvalidEdgeError(
|
||||||
f"Invalid edge from {e.source.node_id}.{e.source.field} to {e.destination.node_id}.{e.destination.field}"
|
f"Invalid edge from {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate all iterators & collectors
|
# Validate all iterators & collectors
|
||||||
# TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available
|
# TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available
|
||||||
for n in self.nodes.values():
|
for node in self.nodes.values():
|
||||||
if isinstance(n, IterateInvocation) and not self._is_iterator_connection_valid(n.id):
|
if isinstance(node, IterateInvocation) and not self._is_iterator_connection_valid(node.id):
|
||||||
raise InvalidEdgeError(f"Invalid iterator node {n.id}")
|
raise InvalidEdgeError(f"Invalid iterator node {node.id}")
|
||||||
if isinstance(n, CollectInvocation) and not self._is_collector_connection_valid(n.id):
|
if isinstance(node, CollectInvocation) and not self._is_collector_connection_valid(node.id):
|
||||||
raise InvalidEdgeError(f"Invalid collector node {n.id}")
|
raise InvalidEdgeError(f"Invalid collector node {node.id}")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -594,7 +595,7 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
def _get_input_edges_and_graphs(
|
def _get_input_edges_and_graphs(
|
||||||
self, node_path: str, prefix: Optional[str] = None
|
self, node_path: str, prefix: Optional[str] = None
|
||||||
) -> list[tuple["Graph", str, Edge]]:
|
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
||||||
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
||||||
edges = list()
|
edges = list()
|
||||||
|
|
||||||
@ -636,7 +637,7 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
def _get_output_edges_and_graphs(
|
def _get_output_edges_and_graphs(
|
||||||
self, node_path: str, prefix: Optional[str] = None
|
self, node_path: str, prefix: Optional[str] = None
|
||||||
) -> list[tuple["Graph", str, Edge]]:
|
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
||||||
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
||||||
edges = list()
|
edges = list()
|
||||||
|
|
||||||
@ -817,15 +818,15 @@ class GraphExecutionState(BaseModel):
|
|||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@validator("graph")
|
@field_validator("graph")
|
||||||
def graph_is_valid(cls, v: Graph):
|
def graph_is_valid(cls, v: Graph):
|
||||||
"""Validates that the graph is valid"""
|
"""Validates that the graph is valid"""
|
||||||
v.validate_self()
|
v.validate_self()
|
||||||
return v
|
return v
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(
|
||||||
schema_extra = {
|
json_schema_extra=dict(
|
||||||
"required": [
|
required=[
|
||||||
"id",
|
"id",
|
||||||
"graph",
|
"graph",
|
||||||
"execution_graph",
|
"execution_graph",
|
||||||
@ -836,7 +837,8 @@ class GraphExecutionState(BaseModel):
|
|||||||
"prepared_source_mapping",
|
"prepared_source_mapping",
|
||||||
"source_prepared_mapping",
|
"source_prepared_mapping",
|
||||||
]
|
]
|
||||||
}
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def next(self) -> Optional[BaseInvocation]:
|
def next(self) -> Optional[BaseInvocation]:
|
||||||
"""Gets the next node ready to execute."""
|
"""Gets the next node ready to execute."""
|
||||||
@ -910,7 +912,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
|
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
|
||||||
self_iteration_count = len(input_collection)
|
self_iteration_count = len(input_collection)
|
||||||
|
|
||||||
new_nodes = list()
|
new_nodes: list[str] = list()
|
||||||
if self_iteration_count == 0:
|
if self_iteration_count == 0:
|
||||||
# TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
|
# TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
|
||||||
return new_nodes
|
return new_nodes
|
||||||
@ -920,7 +922,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
|
|
||||||
# Create new edges for this iteration
|
# Create new edges for this iteration
|
||||||
# For collect nodes, this may contain multiple inputs to the same field
|
# For collect nodes, this may contain multiple inputs to the same field
|
||||||
new_edges = list()
|
new_edges: list[Edge] = list()
|
||||||
for edge in input_edges:
|
for edge in input_edges:
|
||||||
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
|
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
|
||||||
new_edge = Edge(
|
new_edge = Edge(
|
||||||
@ -1179,18 +1181,18 @@ class LibraryGraph(BaseModel):
|
|||||||
description="The outputs exposed by this graph", default_factory=list
|
description="The outputs exposed by this graph", default_factory=list
|
||||||
)
|
)
|
||||||
|
|
||||||
@validator("exposed_inputs", "exposed_outputs")
|
@field_validator("exposed_inputs", "exposed_outputs")
|
||||||
def validate_exposed_aliases(cls, v):
|
def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
|
||||||
if len(v) != len(set(i.alias for i in v)):
|
if len(v) != len(set(i.alias for i in v)):
|
||||||
raise ValueError("Duplicate exposed alias")
|
raise ValueError("Duplicate exposed alias")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@root_validator
|
@model_validator(mode="after")
|
||||||
def validate_exposed_nodes(cls, values):
|
def validate_exposed_nodes(cls, values):
|
||||||
graph = values["graph"]
|
graph = values.graph
|
||||||
|
|
||||||
# Validate exposed inputs
|
# Validate exposed inputs
|
||||||
for exposed_input in values["exposed_inputs"]:
|
for exposed_input in values.exposed_inputs:
|
||||||
if not graph.has_node(exposed_input.node_path):
|
if not graph.has_node(exposed_input.node_path):
|
||||||
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
||||||
node = graph.get_node(exposed_input.node_path)
|
node = graph.get_node(exposed_input.node_path)
|
||||||
@ -1200,7 +1202,7 @@ class LibraryGraph(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Validate exposed outputs
|
# Validate exposed outputs
|
||||||
for exposed_output in values["exposed_outputs"]:
|
for exposed_output in values.exposed_outputs:
|
||||||
if not graph.has_node(exposed_output.node_path):
|
if not graph.has_node(exposed_output.node_path):
|
||||||
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
||||||
node = graph.get_node(exposed_output.node_path)
|
node = graph.get_node(exposed_output.node_path)
|
||||||
@ -1212,4 +1214,6 @@ class LibraryGraph(BaseModel):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
GraphInvocation.update_forward_refs()
|
GraphInvocation.model_rebuild(force=True)
|
||||||
|
Graph.model_rebuild(force=True)
|
||||||
|
GraphExecutionState.model_rebuild(force=True)
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
from typing import Generic, TypeVar
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from pydantic.generics import GenericModel
|
|
||||||
|
|
||||||
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
class CursorPaginatedResults(GenericModel, Generic[GenericBaseModel]):
|
class CursorPaginatedResults(BaseModel, Generic[GenericBaseModel]):
|
||||||
"""
|
"""
|
||||||
Cursor-paginated results
|
Cursor-paginated results
|
||||||
Generic must be a Pydantic model
|
Generic must be a Pydantic model
|
||||||
@ -17,7 +16,7 @@ class CursorPaginatedResults(GenericModel, Generic[GenericBaseModel]):
|
|||||||
items: list[GenericBaseModel] = Field(..., description="Items")
|
items: list[GenericBaseModel] = Field(..., description="Items")
|
||||||
|
|
||||||
|
|
||||||
class OffsetPaginatedResults(GenericModel, Generic[GenericBaseModel]):
|
class OffsetPaginatedResults(BaseModel, Generic[GenericBaseModel]):
|
||||||
"""
|
"""
|
||||||
Offset-paginated results
|
Offset-paginated results
|
||||||
Generic must be a Pydantic model
|
Generic must be a Pydantic model
|
||||||
@ -29,7 +28,7 @@ class OffsetPaginatedResults(GenericModel, Generic[GenericBaseModel]):
|
|||||||
items: list[GenericBaseModel] = Field(description="Items")
|
items: list[GenericBaseModel] = Field(description="Items")
|
||||||
|
|
||||||
|
|
||||||
class PaginatedResults(GenericModel, Generic[GenericBaseModel]):
|
class PaginatedResults(BaseModel, Generic[GenericBaseModel]):
|
||||||
"""
|
"""
|
||||||
Paginated results
|
Paginated results
|
||||||
Generic must be a Pydantic model
|
Generic must be a Pydantic model
|
||||||
|
@ -265,7 +265,7 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
|
|||||||
|
|
||||||
|
|
||||||
def prepare_control_image(
|
def prepare_control_image(
|
||||||
image: Image,
|
image: Image.Image,
|
||||||
width: int,
|
width: int,
|
||||||
height: int,
|
height: int,
|
||||||
num_channels: int = 3,
|
num_channels: int = 3,
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
import typing
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -27,3 +28,8 @@ def get_random_seed():
|
|||||||
def uuid_string():
|
def uuid_string():
|
||||||
res = uuid.uuid4()
|
res = uuid.uuid4()
|
||||||
return str(res)
|
return str(res)
|
||||||
|
|
||||||
|
|
||||||
|
def is_optional(value: typing.Any):
|
||||||
|
"""Checks if a value is typed as Optional. Note that Optional is sugar for Union[x, None]."""
|
||||||
|
return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value)
|
||||||
|
@ -13,11 +13,11 @@ From https://github.com/tiangolo/fastapi/discussions/8882#discussioncomment-5154
|
|||||||
|
|
||||||
|
|
||||||
class BaseModelExcludeNull(BaseModel):
|
class BaseModelExcludeNull(BaseModel):
|
||||||
def dict(self, *args, **kwargs) -> dict[str, Any]:
|
def model_dump(self, *args, **kwargs) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Override the default dict method to exclude None values in the response
|
Override the default dict method to exclude None values in the response
|
||||||
"""
|
"""
|
||||||
kwargs.pop("exclude_none", None)
|
kwargs.pop("exclude_none", None)
|
||||||
return super().dict(*args, exclude_none=True, **kwargs)
|
return super().model_dump(*args, exclude_none=True, **kwargs)
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
0
invokeai/assets/__init__.py
Normal file
0
invokeai/assets/__init__.py
Normal file
@ -41,18 +41,18 @@ config = InvokeAIAppConfig.get_config()
|
|||||||
|
|
||||||
|
|
||||||
class SegmentedGrayscale(object):
|
class SegmentedGrayscale(object):
|
||||||
def __init__(self, image: Image, heatmap: torch.Tensor):
|
def __init__(self, image: Image.Image, heatmap: torch.Tensor):
|
||||||
self.heatmap = heatmap
|
self.heatmap = heatmap
|
||||||
self.image = image
|
self.image = image
|
||||||
|
|
||||||
def to_grayscale(self, invert: bool = False) -> Image:
|
def to_grayscale(self, invert: bool = False) -> Image.Image:
|
||||||
return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))
|
return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))
|
||||||
|
|
||||||
def to_mask(self, threshold: float = 0.5) -> Image:
|
def to_mask(self, threshold: float = 0.5) -> Image.Image:
|
||||||
discrete_heatmap = self.heatmap.lt(threshold).int()
|
discrete_heatmap = self.heatmap.lt(threshold).int()
|
||||||
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L"))
|
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L"))
|
||||||
|
|
||||||
def to_transparent(self, invert: bool = False) -> Image:
|
def to_transparent(self, invert: bool = False) -> Image.Image:
|
||||||
transparent_image = self.image.copy()
|
transparent_image = self.image.copy()
|
||||||
# For img2img, we want the selected regions to be transparent,
|
# For img2img, we want the selected regions to be transparent,
|
||||||
# but to_grayscale() returns the opposite. Thus invert.
|
# but to_grayscale() returns the opposite. Thus invert.
|
||||||
@ -61,7 +61,7 @@ class SegmentedGrayscale(object):
|
|||||||
return transparent_image
|
return transparent_image
|
||||||
|
|
||||||
# unscales and uncrops the 352x352 heatmap so that it matches the image again
|
# unscales and uncrops the 352x352 heatmap so that it matches the image again
|
||||||
def _rescale(self, heatmap: Image) -> Image:
|
def _rescale(self, heatmap: Image.Image) -> Image.Image:
|
||||||
size = self.image.width if (self.image.width > self.image.height) else self.image.height
|
size = self.image.width if (self.image.width > self.image.height) else self.image.height
|
||||||
resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
|
resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
|
||||||
return resized_image.crop((0, 0, self.image.width, self.image.height))
|
return resized_image.crop((0, 0, self.image.width, self.image.height))
|
||||||
@ -82,7 +82,7 @@ class Txt2Mask(object):
|
|||||||
self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
|
self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def segment(self, image, prompt: str) -> SegmentedGrayscale:
|
def segment(self, image: Image.Image, prompt: str) -> SegmentedGrayscale:
|
||||||
"""
|
"""
|
||||||
Given a prompt string such as "a bagel", tries to identify the object in the
|
Given a prompt string such as "a bagel", tries to identify the object in the
|
||||||
provided image and returns a SegmentedGrayscale object in which the brighter
|
provided image and returns a SegmentedGrayscale object in which the brighter
|
||||||
@ -99,7 +99,7 @@ class Txt2Mask(object):
|
|||||||
heatmap = torch.sigmoid(outputs.logits)
|
heatmap = torch.sigmoid(outputs.logits)
|
||||||
return SegmentedGrayscale(image, heatmap)
|
return SegmentedGrayscale(image, heatmap)
|
||||||
|
|
||||||
def _scale_and_crop(self, image: Image) -> Image:
|
def _scale_and_crop(self, image: Image.Image) -> Image.Image:
|
||||||
scaled_image = Image.new("RGB", (CLIPSEG_SIZE, CLIPSEG_SIZE))
|
scaled_image = Image.new("RGB", (CLIPSEG_SIZE, CLIPSEG_SIZE))
|
||||||
if image.width > image.height: # width is constraint
|
if image.width > image.height: # width is constraint
|
||||||
scale = CLIPSEG_SIZE / image.width
|
scale = CLIPSEG_SIZE / image.width
|
||||||
|
@ -9,7 +9,7 @@ class InitImageResizer:
|
|||||||
def __init__(self, Image):
|
def __init__(self, Image):
|
||||||
self.image = Image
|
self.image = Image
|
||||||
|
|
||||||
def resize(self, width=None, height=None) -> Image:
|
def resize(self, width=None, height=None) -> Image.Image:
|
||||||
"""
|
"""
|
||||||
Return a copy of the image resized to fit within
|
Return a copy of the image resized to fit within
|
||||||
a box width x height. The aspect ratio is
|
a box width x height. The aspect ratio is
|
||||||
|
@ -793,7 +793,11 @@ def migrate_init_file(legacy_format: Path):
|
|||||||
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
|
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
|
||||||
new = InvokeAIAppConfig.get_config()
|
new = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
fields = [x for x, y in InvokeAIAppConfig.__fields__.items() if y.field_info.extra.get("category") != "DEPRECATED"]
|
fields = [
|
||||||
|
x
|
||||||
|
for x, y in InvokeAIAppConfig.model_fields.items()
|
||||||
|
if (y.json_schema_extra.get("category", None) if y.json_schema_extra else None) != "DEPRECATED"
|
||||||
|
]
|
||||||
for attr in fields:
|
for attr in fields:
|
||||||
if hasattr(old, attr):
|
if hasattr(old, attr):
|
||||||
try:
|
try:
|
||||||
|
@ -236,13 +236,13 @@ import types
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import move, rmtree
|
from shutil import move, rmtree
|
||||||
from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union
|
from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from omegaconf.dictconfig import DictConfig
|
from omegaconf.dictconfig import DictConfig
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
@ -294,6 +294,8 @@ class AddModelResult(BaseModel):
|
|||||||
base_model: BaseModelType = Field(description="The base model")
|
base_model: BaseModelType = Field(description="The base model")
|
||||||
config: ModelConfigBase = Field(description="The configuration of the model")
|
config: ModelConfigBase = Field(description="The configuration of the model")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
MAX_CACHE_SIZE = 6.0 # GB
|
MAX_CACHE_SIZE = 6.0 # GB
|
||||||
|
|
||||||
@ -576,7 +578,7 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
if model_key in self.models:
|
if model_key in self.models:
|
||||||
return self.models[model_key].dict(exclude_defaults=True)
|
return self.models[model_key].model_dump(exclude_defaults=True)
|
||||||
else:
|
else:
|
||||||
return None # TODO: None or empty dict on not found
|
return None # TODO: None or empty dict on not found
|
||||||
|
|
||||||
@ -632,7 +634,7 @@ class ModelManager(object):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
model_dict = dict(
|
model_dict = dict(
|
||||||
**model_config.dict(exclude_defaults=True),
|
**model_config.model_dump(exclude_defaults=True),
|
||||||
# OpenAPIModelInfoBase
|
# OpenAPIModelInfoBase
|
||||||
model_name=cur_model_name,
|
model_name=cur_model_name,
|
||||||
base_model=cur_base_model,
|
base_model=cur_base_model,
|
||||||
@ -900,14 +902,16 @@ class ModelManager(object):
|
|||||||
Write current configuration out to the indicated file.
|
Write current configuration out to the indicated file.
|
||||||
"""
|
"""
|
||||||
data_to_save = dict()
|
data_to_save = dict()
|
||||||
data_to_save["__metadata__"] = self.config_meta.dict()
|
data_to_save["__metadata__"] = self.config_meta.model_dump()
|
||||||
|
|
||||||
for model_key, model_config in self.models.items():
|
for model_key, model_config in self.models.items():
|
||||||
model_name, base_model, model_type = self.parse_key(model_key)
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
model_class = self._get_implementation(base_model, model_type)
|
model_class = self._get_implementation(base_model, model_type)
|
||||||
if model_class.save_to_config:
|
if model_class.save_to_config:
|
||||||
# TODO: or exclude_unset better fits here?
|
# TODO: or exclude_unset better fits here?
|
||||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
|
data_to_save[model_key] = cast(BaseModel, model_config).model_dump(
|
||||||
|
exclude_defaults=True, exclude={"error"}, mode="json"
|
||||||
|
)
|
||||||
# alias for config file
|
# alias for config file
|
||||||
data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
|
data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ import inspect
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal, get_origin
|
from typing import Literal, get_origin
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, ConfigDict, create_model
|
||||||
|
|
||||||
from .base import ( # noqa: F401
|
from .base import ( # noqa: F401
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@ -106,6 +106,8 @@ class OpenAPIModelInfoBase(BaseModel):
|
|||||||
base_model: BaseModelType
|
base_model: BaseModelType
|
||||||
model_type: ModelType
|
model_type: ModelType
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
for base_model, models in MODEL_CLASSES.items():
|
for base_model, models in MODEL_CLASSES.items():
|
||||||
for model_type, model_class in models.items():
|
for model_type, model_class in models.items():
|
||||||
@ -121,17 +123,11 @@ for base_model, models in MODEL_CLASSES.items():
|
|||||||
if openapi_cfg_name in vars():
|
if openapi_cfg_name in vars():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
api_wrapper = type(
|
api_wrapper = create_model(
|
||||||
openapi_cfg_name,
|
openapi_cfg_name,
|
||||||
(cfg, OpenAPIModelInfoBase),
|
__base__=(cfg, OpenAPIModelInfoBase),
|
||||||
dict(
|
model_type=(Literal[model_type], model_type), # type: ignore
|
||||||
__annotations__=dict(
|
|
||||||
model_type=Literal[model_type.value],
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# globals()[openapi_cfg_name] = api_wrapper
|
|
||||||
vars()[openapi_cfg_name] = api_wrapper
|
vars()[openapi_cfg_name] = api_wrapper
|
||||||
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
|
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ from diffusers import logging as diffusers_logging
|
|||||||
from onnx import numpy_helper
|
from onnx import numpy_helper
|
||||||
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from transformers import logging as transformers_logging
|
from transformers import logging as transformers_logging
|
||||||
|
|
||||||
|
|
||||||
@ -86,14 +86,21 @@ class ModelError(str, Enum):
|
|||||||
NotFound = "not_found"
|
NotFound = "not_found"
|
||||||
|
|
||||||
|
|
||||||
|
def model_config_json_schema_extra(schema: dict[str, Any]) -> None:
|
||||||
|
if "required" not in schema:
|
||||||
|
schema["required"] = []
|
||||||
|
schema["required"].append("model_type")
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigBase(BaseModel):
|
class ModelConfigBase(BaseModel):
|
||||||
path: str # or Path
|
path: str # or Path
|
||||||
description: Optional[str] = Field(None)
|
description: Optional[str] = Field(None)
|
||||||
model_format: Optional[str] = Field(None)
|
model_format: Optional[str] = Field(None)
|
||||||
error: Optional[ModelError] = Field(None)
|
error: Optional[ModelError] = Field(None)
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(
|
||||||
use_enum_values = True
|
use_enum_values=True, protected_namespaces=(), json_schema_extra=model_config_json_schema_extra
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EmptyConfigLoader(ConfigMixin):
|
class EmptyConfigLoader(ConfigMixin):
|
||||||
|
@ -58,14 +58,16 @@ class IPAdapterModel(ModelBase):
|
|||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
torch_dtype: Optional[torch.dtype],
|
torch_dtype: torch.dtype,
|
||||||
child_type: Optional[SubModelType] = None,
|
child_type: Optional[SubModelType] = None,
|
||||||
) -> typing.Union[IPAdapter, IPAdapterPlus]:
|
) -> typing.Union[IPAdapter, IPAdapterPlus]:
|
||||||
if child_type is not None:
|
if child_type is not None:
|
||||||
raise ValueError("There are no child models in an IP-Adapter model.")
|
raise ValueError("There are no child models in an IP-Adapter model.")
|
||||||
|
|
||||||
model = build_ip_adapter(
|
model = build_ip_adapter(
|
||||||
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
|
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"),
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
dtype=torch_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_size = model.calc_size()
|
self.model_size = model.calc_size()
|
||||||
|
@ -96,7 +96,7 @@ def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axe
|
|||||||
finally:
|
finally:
|
||||||
for module, orig_conv_forward in to_restore:
|
for module, orig_conv_forward in to_restore:
|
||||||
module._conv_forward = orig_conv_forward
|
module._conv_forward = orig_conv_forward
|
||||||
if hasattr(m, "asymmetric_padding_mode"):
|
if hasattr(module, "asymmetric_padding_mode"):
|
||||||
del m.asymmetric_padding_mode
|
del module.asymmetric_padding_mode
|
||||||
if hasattr(m, "asymmetric_padding"):
|
if hasattr(module, "asymmetric_padding"):
|
||||||
del m.asymmetric_padding
|
del module.asymmetric_padding
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import math
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import PIL
|
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
|
|
||||||
@ -11,7 +12,7 @@ class AttentionMapSaver:
|
|||||||
self.token_ids = token_ids
|
self.token_ids = token_ids
|
||||||
self.latents_shape = latents_shape
|
self.latents_shape = latents_shape
|
||||||
# self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
|
# self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
|
||||||
self.collated_maps = {}
|
self.collated_maps: dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
def clear_maps(self):
|
def clear_maps(self):
|
||||||
self.collated_maps = {}
|
self.collated_maps = {}
|
||||||
@ -38,9 +39,10 @@ class AttentionMapSaver:
|
|||||||
|
|
||||||
def write_maps_to_disk(self, path: str):
|
def write_maps_to_disk(self, path: str):
|
||||||
pil_image = self.get_stacked_maps_image()
|
pil_image = self.get_stacked_maps_image()
|
||||||
pil_image.save(path, "PNG")
|
if pil_image is not None:
|
||||||
|
pil_image.save(path, "PNG")
|
||||||
|
|
||||||
def get_stacked_maps_image(self) -> PIL.Image:
|
def get_stacked_maps_image(self) -> Optional[Image.Image]:
|
||||||
"""
|
"""
|
||||||
Scale all collected attention maps to the same size, blend them together and return as an image.
|
Scale all collected attention maps to the same size, blend them together and return as an image.
|
||||||
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
|
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
|
||||||
@ -95,4 +97,4 @@ class AttentionMapSaver:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
merged_bytes = merged.mul(0xFF).byte()
|
merged_bytes = merged.mul(0xFF).byte()
|
||||||
return PIL.Image.fromarray(merged_bytes.numpy(), mode="L")
|
return Image.fromarray(merged_bytes.numpy(), mode="L")
|
||||||
|
@ -151,7 +151,9 @@ export const addRequestedSingleImageDeletionListener = () => {
|
|||||||
|
|
||||||
if (wasImageDeleted) {
|
if (wasImageDeleted) {
|
||||||
dispatch(
|
dispatch(
|
||||||
api.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id }])
|
api.util.invalidateTags([
|
||||||
|
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
|
||||||
|
])
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -6,7 +6,7 @@ import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMulti
|
|||||||
import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
|
import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
|
||||||
|
|
||||||
type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
|
type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
|
||||||
tooltip?: string;
|
tooltip?: string | null;
|
||||||
inputRef?: RefObject<HTMLInputElement>;
|
inputRef?: RefObject<HTMLInputElement>;
|
||||||
label?: string;
|
label?: string;
|
||||||
};
|
};
|
||||||
|
@ -12,7 +12,7 @@ export type IAISelectDataType = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
type IAISelectProps = Omit<SelectProps, 'label'> & {
|
type IAISelectProps = Omit<SelectProps, 'label'> & {
|
||||||
tooltip?: string;
|
tooltip?: string | null;
|
||||||
label?: string;
|
label?: string;
|
||||||
inputRef?: RefObject<HTMLInputElement>;
|
inputRef?: RefObject<HTMLInputElement>;
|
||||||
};
|
};
|
||||||
|
@ -10,7 +10,7 @@ export type IAISelectDataType = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export type IAISelectProps = Omit<SelectProps, 'label'> & {
|
export type IAISelectProps = Omit<SelectProps, 'label'> & {
|
||||||
tooltip?: string;
|
tooltip?: string | null;
|
||||||
inputRef?: RefObject<HTMLInputElement>;
|
inputRef?: RefObject<HTMLInputElement>;
|
||||||
label?: string;
|
label?: string;
|
||||||
};
|
};
|
||||||
|
@ -39,7 +39,10 @@ export const dynamicPromptsSlice = createSlice({
|
|||||||
promptsChanged: (state, action: PayloadAction<string[]>) => {
|
promptsChanged: (state, action: PayloadAction<string[]>) => {
|
||||||
state.prompts = action.payload;
|
state.prompts = action.payload;
|
||||||
},
|
},
|
||||||
parsingErrorChanged: (state, action: PayloadAction<string | undefined>) => {
|
parsingErrorChanged: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<string | null | undefined>
|
||||||
|
) => {
|
||||||
state.parsingError = action.payload;
|
state.parsingError = action.payload;
|
||||||
},
|
},
|
||||||
isErrorChanged: (state, action: PayloadAction<boolean>) => {
|
isErrorChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
|
@ -10,7 +10,7 @@ import {
|
|||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
import i18n from 'i18next';
|
import i18n from 'i18next';
|
||||||
import { has, keyBy } from 'lodash-es';
|
import { has, keyBy } from 'lodash-es';
|
||||||
import { OpenAPIV3 } from 'openapi-types';
|
import { OpenAPIV3_1 } from 'openapi-types';
|
||||||
import { RgbaColor } from 'react-colorful';
|
import { RgbaColor } from 'react-colorful';
|
||||||
import { Node } from 'reactflow';
|
import { Node } from 'reactflow';
|
||||||
import { Graph, _InputField, _OutputField } from 'services/api/types';
|
import { Graph, _InputField, _OutputField } from 'services/api/types';
|
||||||
@ -791,9 +791,9 @@ export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
default: number;
|
default: number;
|
||||||
multipleOf?: number;
|
multipleOf?: number;
|
||||||
maximum?: number;
|
maximum?: number;
|
||||||
exclusiveMaximum?: boolean;
|
exclusiveMaximum?: number;
|
||||||
minimum?: number;
|
minimum?: number;
|
||||||
exclusiveMinimum?: boolean;
|
exclusiveMinimum?: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type IntegerCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
export type IntegerCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
@ -814,9 +814,9 @@ export type FloatInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
default: number;
|
default: number;
|
||||||
multipleOf?: number;
|
multipleOf?: number;
|
||||||
maximum?: number;
|
maximum?: number;
|
||||||
exclusiveMaximum?: boolean;
|
exclusiveMaximum?: number;
|
||||||
minimum?: number;
|
minimum?: number;
|
||||||
exclusiveMinimum?: boolean;
|
exclusiveMinimum?: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type FloatCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
export type FloatCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
@ -1163,20 +1163,20 @@ export type TypeHints = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export type InvocationSchemaExtra = {
|
export type InvocationSchemaExtra = {
|
||||||
output: OpenAPIV3.ReferenceObject; // the output of the invocation
|
output: OpenAPIV3_1.ReferenceObject; // the output of the invocation
|
||||||
title: string;
|
title: string;
|
||||||
category?: string;
|
category?: string;
|
||||||
tags?: string[];
|
tags?: string[];
|
||||||
version?: string;
|
version?: string;
|
||||||
properties: Omit<
|
properties: Omit<
|
||||||
NonNullable<OpenAPIV3.SchemaObject['properties']> &
|
NonNullable<OpenAPIV3_1.SchemaObject['properties']> &
|
||||||
(_InputField | _OutputField),
|
(_InputField | _OutputField),
|
||||||
'type'
|
'type'
|
||||||
> & {
|
> & {
|
||||||
type: Omit<OpenAPIV3.SchemaObject, 'default'> & {
|
type: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
|
||||||
default: AnyInvocationType;
|
default: AnyInvocationType;
|
||||||
};
|
};
|
||||||
use_cache: Omit<OpenAPIV3.SchemaObject, 'default'> & {
|
use_cache: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
|
||||||
default: boolean;
|
default: boolean;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
@ -1187,17 +1187,17 @@ export type InvocationSchemaType = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export type InvocationBaseSchemaObject = Omit<
|
export type InvocationBaseSchemaObject = Omit<
|
||||||
OpenAPIV3.BaseSchemaObject,
|
OpenAPIV3_1.BaseSchemaObject,
|
||||||
'title' | 'type' | 'properties'
|
'title' | 'type' | 'properties'
|
||||||
> &
|
> &
|
||||||
InvocationSchemaExtra;
|
InvocationSchemaExtra;
|
||||||
|
|
||||||
export type InvocationOutputSchemaObject = Omit<
|
export type InvocationOutputSchemaObject = Omit<
|
||||||
OpenAPIV3.SchemaObject,
|
OpenAPIV3_1.SchemaObject,
|
||||||
'properties'
|
'properties'
|
||||||
> & {
|
> & {
|
||||||
properties: OpenAPIV3.SchemaObject['properties'] & {
|
properties: OpenAPIV3_1.SchemaObject['properties'] & {
|
||||||
type: Omit<OpenAPIV3.SchemaObject, 'default'> & {
|
type: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
|
||||||
default: string;
|
default: string;
|
||||||
};
|
};
|
||||||
} & {
|
} & {
|
||||||
@ -1205,14 +1205,18 @@ export type InvocationOutputSchemaObject = Omit<
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
export type InvocationFieldSchema = OpenAPIV3.SchemaObject & _InputField;
|
export type InvocationFieldSchema = OpenAPIV3_1.SchemaObject & _InputField;
|
||||||
|
|
||||||
|
export type OpenAPIV3_1SchemaOrRef =
|
||||||
|
| OpenAPIV3_1.ReferenceObject
|
||||||
|
| OpenAPIV3_1.SchemaObject;
|
||||||
|
|
||||||
export interface ArraySchemaObject extends InvocationBaseSchemaObject {
|
export interface ArraySchemaObject extends InvocationBaseSchemaObject {
|
||||||
type: OpenAPIV3.ArraySchemaObjectType;
|
type: OpenAPIV3_1.ArraySchemaObjectType;
|
||||||
items: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject;
|
items: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject;
|
||||||
}
|
}
|
||||||
export interface NonArraySchemaObject extends InvocationBaseSchemaObject {
|
export interface NonArraySchemaObject extends InvocationBaseSchemaObject {
|
||||||
type?: OpenAPIV3.NonArraySchemaObjectType;
|
type?: OpenAPIV3_1.NonArraySchemaObjectType;
|
||||||
}
|
}
|
||||||
|
|
||||||
export type InvocationSchemaObject = (
|
export type InvocationSchemaObject = (
|
||||||
@ -1221,41 +1225,41 @@ export type InvocationSchemaObject = (
|
|||||||
) & { class: 'invocation' };
|
) & { class: 'invocation' };
|
||||||
|
|
||||||
export const isSchemaObject = (
|
export const isSchemaObject = (
|
||||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
|
||||||
): obj is OpenAPIV3.SchemaObject => Boolean(obj && !('$ref' in obj));
|
): obj is OpenAPIV3_1.SchemaObject => Boolean(obj && !('$ref' in obj));
|
||||||
|
|
||||||
export const isArraySchemaObject = (
|
export const isArraySchemaObject = (
|
||||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
|
||||||
): obj is OpenAPIV3.ArraySchemaObject =>
|
): obj is OpenAPIV3_1.ArraySchemaObject =>
|
||||||
Boolean(obj && !('$ref' in obj) && obj.type === 'array');
|
Boolean(obj && !('$ref' in obj) && obj.type === 'array');
|
||||||
|
|
||||||
export const isNonArraySchemaObject = (
|
export const isNonArraySchemaObject = (
|
||||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
|
||||||
): obj is OpenAPIV3.NonArraySchemaObject =>
|
): obj is OpenAPIV3_1.NonArraySchemaObject =>
|
||||||
Boolean(obj && !('$ref' in obj) && obj.type !== 'array');
|
Boolean(obj && !('$ref' in obj) && obj.type !== 'array');
|
||||||
|
|
||||||
export const isRefObject = (
|
export const isRefObject = (
|
||||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
|
||||||
): obj is OpenAPIV3.ReferenceObject => Boolean(obj && '$ref' in obj);
|
): obj is OpenAPIV3_1.ReferenceObject => Boolean(obj && '$ref' in obj);
|
||||||
|
|
||||||
export const isInvocationSchemaObject = (
|
export const isInvocationSchemaObject = (
|
||||||
obj:
|
obj:
|
||||||
| OpenAPIV3.ReferenceObject
|
| OpenAPIV3_1.ReferenceObject
|
||||||
| OpenAPIV3.SchemaObject
|
| OpenAPIV3_1.SchemaObject
|
||||||
| InvocationSchemaObject
|
| InvocationSchemaObject
|
||||||
): obj is InvocationSchemaObject =>
|
): obj is InvocationSchemaObject =>
|
||||||
'class' in obj && obj.class === 'invocation';
|
'class' in obj && obj.class === 'invocation';
|
||||||
|
|
||||||
export const isInvocationOutputSchemaObject = (
|
export const isInvocationOutputSchemaObject = (
|
||||||
obj:
|
obj:
|
||||||
| OpenAPIV3.ReferenceObject
|
| OpenAPIV3_1.ReferenceObject
|
||||||
| OpenAPIV3.SchemaObject
|
| OpenAPIV3_1.SchemaObject
|
||||||
| InvocationOutputSchemaObject
|
| InvocationOutputSchemaObject
|
||||||
): obj is InvocationOutputSchemaObject =>
|
): obj is InvocationOutputSchemaObject =>
|
||||||
'class' in obj && obj.class === 'output';
|
'class' in obj && obj.class === 'output';
|
||||||
|
|
||||||
export const isInvocationFieldSchema = (
|
export const isInvocationFieldSchema = (
|
||||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
|
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject
|
||||||
): obj is InvocationFieldSchema => !('$ref' in obj);
|
): obj is InvocationFieldSchema => !('$ref' in obj);
|
||||||
|
|
||||||
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
|
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
|
||||||
|
@ -1,5 +1,12 @@
|
|||||||
import { isBoolean, isInteger, isNumber, isString } from 'lodash-es';
|
import {
|
||||||
import { OpenAPIV3 } from 'openapi-types';
|
isArray,
|
||||||
|
isBoolean,
|
||||||
|
isInteger,
|
||||||
|
isNumber,
|
||||||
|
isString,
|
||||||
|
startCase,
|
||||||
|
} from 'lodash-es';
|
||||||
|
import { OpenAPIV3_1 } from 'openapi-types';
|
||||||
import {
|
import {
|
||||||
COLLECTION_MAP,
|
COLLECTION_MAP,
|
||||||
POLYMORPHIC_TYPES,
|
POLYMORPHIC_TYPES,
|
||||||
@ -72,6 +79,7 @@ import {
|
|||||||
T2IAdapterCollectionInputFieldTemplate,
|
T2IAdapterCollectionInputFieldTemplate,
|
||||||
BoardInputFieldTemplate,
|
BoardInputFieldTemplate,
|
||||||
InputFieldTemplate,
|
InputFieldTemplate,
|
||||||
|
OpenAPIV3_1SchemaOrRef,
|
||||||
} from '../types/types';
|
} from '../types/types';
|
||||||
import { ControlField } from 'services/api/types';
|
import { ControlField } from 'services/api/types';
|
||||||
|
|
||||||
@ -90,7 +98,7 @@ export type BuildInputFieldArg = {
|
|||||||
* @example
|
* @example
|
||||||
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
|
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
|
||||||
*/
|
*/
|
||||||
export const refObjectToSchemaName = (refObject: OpenAPIV3.ReferenceObject) =>
|
export const refObjectToSchemaName = (refObject: OpenAPIV3_1.ReferenceObject) =>
|
||||||
refObject.$ref.split('/').slice(-1)[0];
|
refObject.$ref.split('/').slice(-1)[0];
|
||||||
|
|
||||||
const buildIntegerInputFieldTemplate = ({
|
const buildIntegerInputFieldTemplate = ({
|
||||||
@ -111,7 +119,10 @@ const buildIntegerInputFieldTemplate = ({
|
|||||||
template.maximum = schemaObject.maximum;
|
template.maximum = schemaObject.maximum;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (schemaObject.exclusiveMaximum !== undefined) {
|
if (
|
||||||
|
schemaObject.exclusiveMaximum !== undefined &&
|
||||||
|
isNumber(schemaObject.exclusiveMaximum)
|
||||||
|
) {
|
||||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,7 +130,10 @@ const buildIntegerInputFieldTemplate = ({
|
|||||||
template.minimum = schemaObject.minimum;
|
template.minimum = schemaObject.minimum;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (schemaObject.exclusiveMinimum !== undefined) {
|
if (
|
||||||
|
schemaObject.exclusiveMinimum !== undefined &&
|
||||||
|
isNumber(schemaObject.exclusiveMinimum)
|
||||||
|
) {
|
||||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -144,7 +158,10 @@ const buildIntegerPolymorphicInputFieldTemplate = ({
|
|||||||
template.maximum = schemaObject.maximum;
|
template.maximum = schemaObject.maximum;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (schemaObject.exclusiveMaximum !== undefined) {
|
if (
|
||||||
|
schemaObject.exclusiveMaximum !== undefined &&
|
||||||
|
isNumber(schemaObject.exclusiveMaximum)
|
||||||
|
) {
|
||||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -152,7 +169,10 @@ const buildIntegerPolymorphicInputFieldTemplate = ({
|
|||||||
template.minimum = schemaObject.minimum;
|
template.minimum = schemaObject.minimum;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (schemaObject.exclusiveMinimum !== undefined) {
|
if (
|
||||||
|
schemaObject.exclusiveMinimum !== undefined &&
|
||||||
|
isNumber(schemaObject.exclusiveMinimum)
|
||||||
|
) {
|
||||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -195,7 +215,10 @@ const buildFloatInputFieldTemplate = ({
|
|||||||
template.maximum = schemaObject.maximum;
|
template.maximum = schemaObject.maximum;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (schemaObject.exclusiveMaximum !== undefined) {
|
if (
|
||||||
|
schemaObject.exclusiveMaximum !== undefined &&
|
||||||
|
isNumber(schemaObject.exclusiveMaximum)
|
||||||
|
) {
|
||||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,7 +226,10 @@ const buildFloatInputFieldTemplate = ({
|
|||||||
template.minimum = schemaObject.minimum;
|
template.minimum = schemaObject.minimum;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (schemaObject.exclusiveMinimum !== undefined) {
|
if (
|
||||||
|
schemaObject.exclusiveMinimum !== undefined &&
|
||||||
|
isNumber(schemaObject.exclusiveMinimum)
|
||||||
|
) {
|
||||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -227,7 +253,10 @@ const buildFloatPolymorphicInputFieldTemplate = ({
|
|||||||
template.maximum = schemaObject.maximum;
|
template.maximum = schemaObject.maximum;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (schemaObject.exclusiveMaximum !== undefined) {
|
if (
|
||||||
|
schemaObject.exclusiveMaximum !== undefined &&
|
||||||
|
isNumber(schemaObject.exclusiveMaximum)
|
||||||
|
) {
|
||||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -235,7 +264,10 @@ const buildFloatPolymorphicInputFieldTemplate = ({
|
|||||||
template.minimum = schemaObject.minimum;
|
template.minimum = schemaObject.minimum;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (schemaObject.exclusiveMinimum !== undefined) {
|
if (
|
||||||
|
schemaObject.exclusiveMinimum !== undefined &&
|
||||||
|
isNumber(schemaObject.exclusiveMinimum)
|
||||||
|
) {
|
||||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||||
}
|
}
|
||||||
return template;
|
return template;
|
||||||
@ -872,84 +904,106 @@ const buildSchedulerInputFieldTemplate = ({
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const getFieldType = (
|
export const getFieldType = (
|
||||||
schemaObject: InvocationFieldSchema
|
schemaObject: OpenAPIV3_1SchemaOrRef
|
||||||
): string | undefined => {
|
): string | undefined => {
|
||||||
if (schemaObject?.ui_type) {
|
if (isSchemaObject(schemaObject)) {
|
||||||
return schemaObject.ui_type;
|
if (!schemaObject.type) {
|
||||||
} else if (!schemaObject.type) {
|
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
|
||||||
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
|
|
||||||
|
|
||||||
if (schemaObject.allOf) {
|
if (schemaObject.allOf) {
|
||||||
const allOf = schemaObject.allOf;
|
const allOf = schemaObject.allOf;
|
||||||
if (allOf && allOf[0] && isRefObject(allOf[0])) {
|
if (allOf && allOf[0] && isRefObject(allOf[0])) {
|
||||||
return refObjectToSchemaName(allOf[0]);
|
return refObjectToSchemaName(allOf[0]);
|
||||||
}
|
|
||||||
} else if (schemaObject.anyOf) {
|
|
||||||
const anyOf = schemaObject.anyOf;
|
|
||||||
/**
|
|
||||||
* Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is:
|
|
||||||
* - an `anyOf` with two items
|
|
||||||
* - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items`
|
|
||||||
* - the other is a `SchemaObject` or `ReferenceObject` of type T
|
|
||||||
*
|
|
||||||
* Any other cases we ignore.
|
|
||||||
*/
|
|
||||||
|
|
||||||
let firstType: string | undefined;
|
|
||||||
let secondType: string | undefined;
|
|
||||||
|
|
||||||
if (isArraySchemaObject(anyOf[0])) {
|
|
||||||
// first is array, second is not
|
|
||||||
const first = anyOf[0].items;
|
|
||||||
const second = anyOf[1];
|
|
||||||
if (isRefObject(first) && isRefObject(second)) {
|
|
||||||
firstType = refObjectToSchemaName(first);
|
|
||||||
secondType = refObjectToSchemaName(second);
|
|
||||||
} else if (
|
|
||||||
isNonArraySchemaObject(first) &&
|
|
||||||
isNonArraySchemaObject(second)
|
|
||||||
) {
|
|
||||||
firstType = first.type;
|
|
||||||
secondType = second.type;
|
|
||||||
}
|
}
|
||||||
} else if (isArraySchemaObject(anyOf[1])) {
|
} else if (schemaObject.anyOf) {
|
||||||
// first is not array, second is
|
// ignore null types
|
||||||
const first = anyOf[0];
|
const anyOf = schemaObject.anyOf.filter((i) => {
|
||||||
const second = anyOf[1].items;
|
if (isSchemaObject(i)) {
|
||||||
if (isRefObject(first) && isRefObject(second)) {
|
if (i.type === 'null') {
|
||||||
firstType = refObjectToSchemaName(first);
|
return false;
|
||||||
secondType = refObjectToSchemaName(second);
|
}
|
||||||
} else if (
|
}
|
||||||
isNonArraySchemaObject(first) &&
|
return true;
|
||||||
isNonArraySchemaObject(second)
|
});
|
||||||
) {
|
if (anyOf.length === 1) {
|
||||||
firstType = first.type;
|
if (isRefObject(anyOf[0])) {
|
||||||
secondType = second.type;
|
return refObjectToSchemaName(anyOf[0]);
|
||||||
|
} else if (isSchemaObject(anyOf[0])) {
|
||||||
|
return getFieldType(anyOf[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is:
|
||||||
|
* - an `anyOf` with two items
|
||||||
|
* - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items`
|
||||||
|
* - the other is a `SchemaObject` or `ReferenceObject` of type T
|
||||||
|
*
|
||||||
|
* Any other cases we ignore.
|
||||||
|
*/
|
||||||
|
|
||||||
|
let firstType: string | undefined;
|
||||||
|
let secondType: string | undefined;
|
||||||
|
|
||||||
|
if (isArraySchemaObject(anyOf[0])) {
|
||||||
|
// first is array, second is not
|
||||||
|
const first = anyOf[0].items;
|
||||||
|
const second = anyOf[1];
|
||||||
|
if (isRefObject(first) && isRefObject(second)) {
|
||||||
|
firstType = refObjectToSchemaName(first);
|
||||||
|
secondType = refObjectToSchemaName(second);
|
||||||
|
} else if (
|
||||||
|
isNonArraySchemaObject(first) &&
|
||||||
|
isNonArraySchemaObject(second)
|
||||||
|
) {
|
||||||
|
firstType = first.type;
|
||||||
|
secondType = second.type;
|
||||||
|
}
|
||||||
|
} else if (isArraySchemaObject(anyOf[1])) {
|
||||||
|
// first is not array, second is
|
||||||
|
const first = anyOf[0];
|
||||||
|
const second = anyOf[1].items;
|
||||||
|
if (isRefObject(first) && isRefObject(second)) {
|
||||||
|
firstType = refObjectToSchemaName(first);
|
||||||
|
secondType = refObjectToSchemaName(second);
|
||||||
|
} else if (
|
||||||
|
isNonArraySchemaObject(first) &&
|
||||||
|
isNonArraySchemaObject(second)
|
||||||
|
) {
|
||||||
|
firstType = first.type;
|
||||||
|
secondType = second.type;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (firstType === secondType && isPolymorphicItemType(firstType)) {
|
||||||
|
return SINGLE_TO_POLYMORPHIC_MAP[firstType];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (firstType === secondType && isPolymorphicItemType(firstType)) {
|
} else if (schemaObject.enum) {
|
||||||
return SINGLE_TO_POLYMORPHIC_MAP[firstType];
|
return 'enum';
|
||||||
|
} else if (schemaObject.type) {
|
||||||
|
if (schemaObject.type === 'number') {
|
||||||
|
// floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them
|
||||||
|
return 'float';
|
||||||
|
} else if (schemaObject.type === 'array') {
|
||||||
|
const itemType = isSchemaObject(schemaObject.items)
|
||||||
|
? schemaObject.items.type
|
||||||
|
: refObjectToSchemaName(schemaObject.items);
|
||||||
|
|
||||||
|
if (isArray(itemType)) {
|
||||||
|
// This is a nested array, which we don't support
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isCollectionItemType(itemType)) {
|
||||||
|
return COLLECTION_MAP[itemType];
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
} else if (!isArray(schemaObject.type)) {
|
||||||
|
return schemaObject.type;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (schemaObject.enum) {
|
} else if (isRefObject(schemaObject)) {
|
||||||
return 'enum';
|
return refObjectToSchemaName(schemaObject);
|
||||||
} else if (schemaObject.type) {
|
|
||||||
if (schemaObject.type === 'number') {
|
|
||||||
// floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them
|
|
||||||
return 'float';
|
|
||||||
} else if (schemaObject.type === 'array') {
|
|
||||||
const itemType = isSchemaObject(schemaObject.items)
|
|
||||||
? schemaObject.items.type
|
|
||||||
: refObjectToSchemaName(schemaObject.items);
|
|
||||||
|
|
||||||
if (isCollectionItemType(itemType)) {
|
|
||||||
return COLLECTION_MAP[itemType];
|
|
||||||
}
|
|
||||||
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
return schemaObject.type;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
@ -1025,7 +1079,15 @@ export const buildInputFieldTemplate = (
|
|||||||
name: string,
|
name: string,
|
||||||
fieldType: FieldType
|
fieldType: FieldType
|
||||||
) => {
|
) => {
|
||||||
const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;
|
const {
|
||||||
|
input,
|
||||||
|
ui_hidden,
|
||||||
|
ui_component,
|
||||||
|
ui_type,
|
||||||
|
ui_order,
|
||||||
|
ui_choice_labels,
|
||||||
|
item_default,
|
||||||
|
} = fieldSchema;
|
||||||
|
|
||||||
const extra = {
|
const extra = {
|
||||||
// TODO: Can we support polymorphic inputs in the UI?
|
// TODO: Can we support polymorphic inputs in the UI?
|
||||||
@ -1035,11 +1097,13 @@ export const buildInputFieldTemplate = (
|
|||||||
ui_type,
|
ui_type,
|
||||||
required: nodeSchema.required?.includes(name) ?? false,
|
required: nodeSchema.required?.includes(name) ?? false,
|
||||||
ui_order,
|
ui_order,
|
||||||
|
ui_choice_labels,
|
||||||
|
item_default,
|
||||||
};
|
};
|
||||||
|
|
||||||
const baseField = {
|
const baseField = {
|
||||||
name,
|
name,
|
||||||
title: fieldSchema.title ?? '',
|
title: fieldSchema.title ?? (name ? startCase(name) : ''),
|
||||||
description: fieldSchema.description ?? '',
|
description: fieldSchema.description ?? '',
|
||||||
fieldKind: 'input' as const,
|
fieldKind: 'input' as const,
|
||||||
...extra,
|
...extra,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { parseify } from 'common/util/serialize';
|
import { parseify } from 'common/util/serialize';
|
||||||
import { reduce } from 'lodash-es';
|
import { reduce, startCase } from 'lodash-es';
|
||||||
import { OpenAPIV3 } from 'openapi-types';
|
import { OpenAPIV3_1 } from 'openapi-types';
|
||||||
import { AnyInvocationType } from 'services/events/types';
|
import { AnyInvocationType } from 'services/events/types';
|
||||||
import {
|
import {
|
||||||
FieldType,
|
FieldType,
|
||||||
@ -60,7 +60,7 @@ const isNotInDenylist = (schema: InvocationSchemaObject) =>
|
|||||||
!invocationDenylist.includes(schema.properties.type.default);
|
!invocationDenylist.includes(schema.properties.type.default);
|
||||||
|
|
||||||
export const parseSchema = (
|
export const parseSchema = (
|
||||||
openAPI: OpenAPIV3.Document,
|
openAPI: OpenAPIV3_1.Document,
|
||||||
nodesAllowlistExtra: string[] | undefined = undefined,
|
nodesAllowlistExtra: string[] | undefined = undefined,
|
||||||
nodesDenylistExtra: string[] | undefined = undefined
|
nodesDenylistExtra: string[] | undefined = undefined
|
||||||
): Record<string, InvocationTemplate> => {
|
): Record<string, InvocationTemplate> => {
|
||||||
@ -110,7 +110,7 @@ export const parseSchema = (
|
|||||||
return inputsAccumulator;
|
return inputsAccumulator;
|
||||||
}
|
}
|
||||||
|
|
||||||
const fieldType = getFieldType(property);
|
const fieldType = property.ui_type ?? getFieldType(property);
|
||||||
|
|
||||||
if (!isFieldType(fieldType)) {
|
if (!isFieldType(fieldType)) {
|
||||||
logger('nodes').warn(
|
logger('nodes').warn(
|
||||||
@ -209,7 +209,7 @@ export const parseSchema = (
|
|||||||
return outputsAccumulator;
|
return outputsAccumulator;
|
||||||
}
|
}
|
||||||
|
|
||||||
const fieldType = getFieldType(property);
|
const fieldType = property.ui_type ?? getFieldType(property);
|
||||||
|
|
||||||
if (!isFieldType(fieldType)) {
|
if (!isFieldType(fieldType)) {
|
||||||
logger('nodes').warn(
|
logger('nodes').warn(
|
||||||
@ -222,7 +222,8 @@ export const parseSchema = (
|
|||||||
outputsAccumulator[propertyName] = {
|
outputsAccumulator[propertyName] = {
|
||||||
fieldKind: 'output',
|
fieldKind: 'output',
|
||||||
name: propertyName,
|
name: propertyName,
|
||||||
title: property.title ?? '',
|
title:
|
||||||
|
property.title ?? (propertyName ? startCase(propertyName) : ''),
|
||||||
description: property.description ?? '',
|
description: property.description ?? '',
|
||||||
type: fieldType,
|
type: fieldType,
|
||||||
ui_hidden: property.ui_hidden ?? false,
|
ui_hidden: property.ui_hidden ?? false,
|
||||||
|
@ -7,7 +7,7 @@ const QueueItemCard = ({
|
|||||||
session_queue_item,
|
session_queue_item,
|
||||||
label,
|
label,
|
||||||
}: {
|
}: {
|
||||||
session_queue_item?: components['schemas']['SessionQueueItem'];
|
session_queue_item?: components['schemas']['SessionQueueItem'] | null;
|
||||||
label: string;
|
label: string;
|
||||||
}) => {
|
}) => {
|
||||||
return (
|
return (
|
||||||
|
@ -112,7 +112,7 @@ export default function MergeModelsPanel() {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
const mergeModelsInfo: MergeModelConfig = {
|
const mergeModelsInfo: MergeModelConfig['body'] = {
|
||||||
model_names: models_names,
|
model_names: models_names,
|
||||||
merged_model_name:
|
merged_model_name:
|
||||||
mergedModelName !== '' ? mergedModelName : models_names.join('-'),
|
mergedModelName !== '' ? mergedModelName : models_names.join('-'),
|
||||||
@ -125,7 +125,7 @@ export default function MergeModelsPanel() {
|
|||||||
|
|
||||||
mergeModels({
|
mergeModels({
|
||||||
base_model: baseModel,
|
base_model: baseModel,
|
||||||
body: mergeModelsInfo,
|
body: { body: mergeModelsInfo },
|
||||||
})
|
})
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.then((_) => {
|
.then((_) => {
|
||||||
|
@ -520,7 +520,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
// assume all images are on the same board/category
|
// assume all images are on the same board/category
|
||||||
if (images[0]) {
|
if (images[0]) {
|
||||||
const categories = getCategories(images[0]);
|
const categories = getCategories(images[0]);
|
||||||
const boardId = images[0].board_id;
|
const boardId = images[0].board_id ?? undefined;
|
||||||
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
@ -637,7 +637,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
// assume all images are on the same board/category
|
// assume all images are on the same board/category
|
||||||
if (images[0]) {
|
if (images[0]) {
|
||||||
const categories = getCategories(images[0]);
|
const categories = getCategories(images[0]);
|
||||||
const boardId = images[0].board_id;
|
const boardId = images[0].board_id ?? undefined;
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
type: 'ImageList',
|
type: 'ImageList',
|
||||||
|
3486
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
3486
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
File diff suppressed because one or more lines are too long
110
pyproject.toml
110
pyproject.toml
@ -35,10 +35,10 @@ dependencies = [
|
|||||||
"accelerate~=0.23.0",
|
"accelerate~=0.23.0",
|
||||||
"albumentations",
|
"albumentations",
|
||||||
"click",
|
"click",
|
||||||
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||||
"compel~=2.0.2",
|
"compel~=2.0.2",
|
||||||
"controlnet-aux>=0.0.6",
|
"controlnet-aux>=0.0.6",
|
||||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||||
"datasets",
|
"datasets",
|
||||||
# When bumping diffusers beyond 0.21, make sure to address this:
|
# When bumping diffusers beyond 0.21, make sure to address this:
|
||||||
# https://github.com/invoke-ai/InvokeAI/blob/fc09ab7e13cb7ca5389100d149b6422ace7b8ed3/invokeai/app/invocations/latent.py#L513
|
# https://github.com/invoke-ai/InvokeAI/blob/fc09ab7e13cb7ca5389100d149b6422ace7b8ed3/invokeai/app/invocations/latent.py#L513
|
||||||
@ -48,19 +48,20 @@ dependencies = [
|
|||||||
"easing-functions",
|
"easing-functions",
|
||||||
"einops",
|
"einops",
|
||||||
"facexlib",
|
"facexlib",
|
||||||
"fastapi==0.88.0",
|
"fastapi~=0.103.2",
|
||||||
"fastapi-events==0.8.0",
|
"fastapi-events~=0.9.1",
|
||||||
"huggingface-hub~=0.16.4",
|
"huggingface-hub~=0.16.4",
|
||||||
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||||
"matplotlib", # needed for plotting of Penner easing functions
|
"matplotlib", # needed for plotting of Penner easing functions
|
||||||
"mediapipe", # needed for "mediapipeface" controlnet model
|
"mediapipe", # needed for "mediapipeface" controlnet model
|
||||||
"numpy",
|
"numpy",
|
||||||
"npyscreen",
|
"npyscreen",
|
||||||
"omegaconf",
|
"omegaconf",
|
||||||
"onnx",
|
"onnx",
|
||||||
"onnxruntime",
|
"onnxruntime",
|
||||||
"opencv-python",
|
"opencv-python",
|
||||||
"pydantic==1.*",
|
"pydantic~=2.4.2",
|
||||||
|
"pydantic-settings~=2.0.3",
|
||||||
"picklescan",
|
"picklescan",
|
||||||
"pillow",
|
"pillow",
|
||||||
"prompt-toolkit",
|
"prompt-toolkit",
|
||||||
@ -95,33 +96,25 @@ dependencies = [
|
|||||||
"mkdocs-git-revision-date-localized-plugin",
|
"mkdocs-git-revision-date-localized-plugin",
|
||||||
"mkdocs-redirects==1.2.0",
|
"mkdocs-redirects==1.2.0",
|
||||||
]
|
]
|
||||||
"dev" = [
|
"dev" = ["jurigged", "pudb"]
|
||||||
"jurigged",
|
|
||||||
"pudb",
|
|
||||||
]
|
|
||||||
"test" = [
|
"test" = [
|
||||||
"black",
|
"black",
|
||||||
"flake8",
|
"flake8",
|
||||||
"Flake8-pyproject",
|
"Flake8-pyproject",
|
||||||
"isort",
|
"isort",
|
||||||
|
"mypy",
|
||||||
"pre-commit",
|
"pre-commit",
|
||||||
"pytest>6.0.0",
|
"pytest>6.0.0",
|
||||||
"pytest-cov",
|
"pytest-cov",
|
||||||
"pytest-datadir",
|
"pytest-datadir",
|
||||||
]
|
]
|
||||||
"xformers" = [
|
"xformers" = [
|
||||||
"xformers~=0.0.19; sys_platform!='darwin'",
|
"xformers~=0.0.19; sys_platform!='darwin'",
|
||||||
"triton; sys_platform=='linux'",
|
"triton; sys_platform=='linux'",
|
||||||
]
|
|
||||||
"onnx" = [
|
|
||||||
"onnxruntime",
|
|
||||||
]
|
|
||||||
"onnx-cuda" = [
|
|
||||||
"onnxruntime-gpu",
|
|
||||||
]
|
|
||||||
"onnx-directml" = [
|
|
||||||
"onnxruntime-directml",
|
|
||||||
]
|
]
|
||||||
|
"onnx" = ["onnxruntime"]
|
||||||
|
"onnx-cuda" = ["onnxruntime-gpu"]
|
||||||
|
"onnx-directml" = ["onnxruntime-directml"]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
||||||
@ -163,12 +156,15 @@ version = { attr = "invokeai.version.__version__" }
|
|||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
"where" = ["."]
|
"where" = ["."]
|
||||||
"include" = [
|
"include" = [
|
||||||
"invokeai.assets.fonts*","invokeai.version*",
|
"invokeai.assets.fonts*",
|
||||||
"invokeai.generator*","invokeai.backend*",
|
"invokeai.version*",
|
||||||
"invokeai.frontend*", "invokeai.frontend.web.dist*",
|
"invokeai.generator*",
|
||||||
"invokeai.frontend.web.static*",
|
"invokeai.backend*",
|
||||||
"invokeai.configs*",
|
"invokeai.frontend*",
|
||||||
"invokeai.app*",
|
"invokeai.frontend.web.dist*",
|
||||||
|
"invokeai.frontend.web.static*",
|
||||||
|
"invokeai.configs*",
|
||||||
|
"invokeai.app*",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
@ -182,7 +178,7 @@ version = { attr = "invokeai.version.__version__" }
|
|||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
|
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
|
||||||
markers = [
|
markers = [
|
||||||
"slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\"."
|
"slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".",
|
||||||
]
|
]
|
||||||
[tool.coverage.run]
|
[tool.coverage.run]
|
||||||
branch = true
|
branch = true
|
||||||
@ -190,7 +186,7 @@ source = ["invokeai"]
|
|||||||
omit = ["*tests*", "*migrations*", ".venv/*", "*.env"]
|
omit = ["*tests*", "*migrations*", ".venv/*", "*.env"]
|
||||||
[tool.coverage.report]
|
[tool.coverage.report]
|
||||||
show_missing = true
|
show_missing = true
|
||||||
fail_under = 85 # let's set something sensible on Day 1 ...
|
fail_under = 85 # let's set something sensible on Day 1 ...
|
||||||
[tool.coverage.json]
|
[tool.coverage.json]
|
||||||
output = "coverage/coverage.json"
|
output = "coverage/coverage.json"
|
||||||
pretty_print = true
|
pretty_print = true
|
||||||
@ -209,7 +205,7 @@ exclude = [
|
|||||||
"__pycache__",
|
"__pycache__",
|
||||||
"build",
|
"build",
|
||||||
"dist",
|
"dist",
|
||||||
"invokeai/frontend/web/node_modules/"
|
"invokeai/frontend/web/node_modules/",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
@ -218,3 +214,53 @@ line-length = 120
|
|||||||
[tool.isort]
|
[tool.isort]
|
||||||
profile = "black"
|
profile = "black"
|
||||||
line_length = 120
|
line_length = 120
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
ignore_missing_imports = true # ignores missing types in third-party libraries
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
follow_imports = "skip"
|
||||||
|
module = [
|
||||||
|
"invokeai.app.api.routers.models",
|
||||||
|
"invokeai.app.invocations.compel",
|
||||||
|
"invokeai.app.invocations.latent",
|
||||||
|
"invokeai.app.services.config.config_base",
|
||||||
|
"invokeai.app.services.config.config_default",
|
||||||
|
"invokeai.app.services.invocation_stats.invocation_stats_default",
|
||||||
|
"invokeai.app.services.model_manager.model_manager_base",
|
||||||
|
"invokeai.app.services.model_manager.model_manager_default",
|
||||||
|
"invokeai.app.util.controlnet_utils",
|
||||||
|
"invokeai.backend.image_util.txt2mask",
|
||||||
|
"invokeai.backend.image_util.safety_checker",
|
||||||
|
"invokeai.backend.image_util.patchmatch",
|
||||||
|
"invokeai.backend.image_util.invisible_watermark",
|
||||||
|
"invokeai.backend.install.model_install_backend",
|
||||||
|
"invokeai.backend.ip_adapter.ip_adapter",
|
||||||
|
"invokeai.backend.ip_adapter.resampler",
|
||||||
|
"invokeai.backend.ip_adapter.unet_patcher",
|
||||||
|
"invokeai.backend.model_management.convert_ckpt_to_diffusers",
|
||||||
|
"invokeai.backend.model_management.lora",
|
||||||
|
"invokeai.backend.model_management.model_cache",
|
||||||
|
"invokeai.backend.model_management.model_manager",
|
||||||
|
"invokeai.backend.model_management.model_merge",
|
||||||
|
"invokeai.backend.model_management.model_probe",
|
||||||
|
"invokeai.backend.model_management.model_search",
|
||||||
|
"invokeai.backend.model_management.models.*", # this is needed to ignore the module's `__init__.py`
|
||||||
|
"invokeai.backend.model_management.models.base",
|
||||||
|
"invokeai.backend.model_management.models.controlnet",
|
||||||
|
"invokeai.backend.model_management.models.ip_adapter",
|
||||||
|
"invokeai.backend.model_management.models.lora",
|
||||||
|
"invokeai.backend.model_management.models.sdxl",
|
||||||
|
"invokeai.backend.model_management.models.stable_diffusion",
|
||||||
|
"invokeai.backend.model_management.models.vae",
|
||||||
|
"invokeai.backend.model_management.seamless",
|
||||||
|
"invokeai.backend.model_management.util",
|
||||||
|
"invokeai.backend.stable_diffusion.diffusers_pipeline",
|
||||||
|
"invokeai.backend.stable_diffusion.diffusion.cross_attention_control",
|
||||||
|
"invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion",
|
||||||
|
"invokeai.backend.util.hotfixes",
|
||||||
|
"invokeai.backend.util.logging",
|
||||||
|
"invokeai.backend.util.mps_fixes",
|
||||||
|
"invokeai.backend.util.util",
|
||||||
|
"invokeai.frontend.install.model_install",
|
||||||
|
]
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -593,20 +594,21 @@ def test_graph_can_serialize():
|
|||||||
g.add_edge(e)
|
g.add_edge(e)
|
||||||
|
|
||||||
# Not throwing on this line is sufficient
|
# Not throwing on this line is sufficient
|
||||||
_ = g.json()
|
_ = g.model_dump_json()
|
||||||
|
|
||||||
|
|
||||||
def test_graph_can_deserialize():
|
def test_graph_can_deserialize():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||||
n2 = ESRGANInvocation(id="2")
|
n2 = ImageToImageTestInvocation(id="2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e = create_edge(n1.id, "image", n2.id, "image")
|
e = create_edge(n1.id, "image", n2.id, "image")
|
||||||
g.add_edge(e)
|
g.add_edge(e)
|
||||||
|
|
||||||
json = g.json()
|
json = g.model_dump_json()
|
||||||
g2 = Graph.parse_raw(json)
|
adapter_graph = TypeAdapter(Graph)
|
||||||
|
g2 = adapter_graph.validate_json(json)
|
||||||
|
|
||||||
assert g2 is not None
|
assert g2 is not None
|
||||||
assert g2.nodes["1"] is not None
|
assert g2.nodes["1"] is not None
|
||||||
@ -619,7 +621,7 @@ def test_graph_can_deserialize():
|
|||||||
|
|
||||||
|
|
||||||
def test_invocation_decorator():
|
def test_invocation_decorator():
|
||||||
invocation_type = "test_invocation"
|
invocation_type = "test_invocation_decorator"
|
||||||
title = "Test Invocation"
|
title = "Test Invocation"
|
||||||
tags = ["first", "second", "third"]
|
tags = ["first", "second", "third"]
|
||||||
category = "category"
|
category = "category"
|
||||||
@ -630,7 +632,7 @@ def test_invocation_decorator():
|
|||||||
def invoke(self):
|
def invoke(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
schema = TestInvocation.schema()
|
schema = TestInvocation.model_json_schema()
|
||||||
|
|
||||||
assert schema.get("title") == title
|
assert schema.get("title") == title
|
||||||
assert schema.get("tags") == tags
|
assert schema.get("tags") == tags
|
||||||
@ -640,18 +642,17 @@ def test_invocation_decorator():
|
|||||||
|
|
||||||
|
|
||||||
def test_invocation_version_must_be_semver():
|
def test_invocation_version_must_be_semver():
|
||||||
invocation_type = "test_invocation"
|
|
||||||
valid_version = "1.0.0"
|
valid_version = "1.0.0"
|
||||||
invalid_version = "not_semver"
|
invalid_version = "not_semver"
|
||||||
|
|
||||||
@invocation(invocation_type, version=valid_version)
|
@invocation("test_invocation_version_valid", version=valid_version)
|
||||||
class ValidVersionInvocation(BaseInvocation):
|
class ValidVersionInvocation(BaseInvocation):
|
||||||
def invoke(self):
|
def invoke(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
with pytest.raises(InvalidVersionError):
|
with pytest.raises(InvalidVersionError):
|
||||||
|
|
||||||
@invocation(invocation_type, version=invalid_version)
|
@invocation("test_invocation_version_invalid", version=invalid_version)
|
||||||
class InvalidVersionInvocation(BaseInvocation):
|
class InvalidVersionInvocation(BaseInvocation):
|
||||||
def invoke(self):
|
def invoke(self):
|
||||||
pass
|
pass
|
||||||
@ -694,4 +695,4 @@ def test_ints_do_not_accept_floats():
|
|||||||
def test_graph_can_generate_schema():
|
def test_graph_can_generate_schema():
|
||||||
# Not throwing on this line is sufficient
|
# Not throwing on this line is sufficient
|
||||||
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
||||||
_ = Graph.schema_json(indent=2)
|
_ = Graph.model_json_schema()
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError, parse_raw_as
|
from pydantic import TypeAdapter, ValidationError
|
||||||
|
|
||||||
from invokeai.app.services.session_queue.session_queue_common import (
|
from invokeai.app.services.session_queue.session_queue_common import (
|
||||||
Batch,
|
Batch,
|
||||||
@ -150,8 +150,9 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
|
|||||||
values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000)
|
values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000)
|
||||||
assert len(values) == 8
|
assert len(values) == 8
|
||||||
|
|
||||||
|
session_adapter = TypeAdapter(GraphExecutionState)
|
||||||
# graph should be serialized
|
# graph should be serialized
|
||||||
ges = parse_raw_as(GraphExecutionState, values[0].session)
|
ges = session_adapter.validate_json(values[0].session)
|
||||||
|
|
||||||
# graph values should be populated
|
# graph values should be populated
|
||||||
assert ges.graph.get_node("1").prompt == "Banana sushi"
|
assert ges.graph.get_node("1").prompt == "Banana sushi"
|
||||||
@ -160,15 +161,16 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
|
|||||||
assert ges.graph.get_node("4").prompt == "Nissan"
|
assert ges.graph.get_node("4").prompt == "Nissan"
|
||||||
|
|
||||||
# session ids should match deserialized graph
|
# session ids should match deserialized graph
|
||||||
assert [v.session_id for v in values] == [parse_raw_as(GraphExecutionState, v.session).id for v in values]
|
assert [v.session_id for v in values] == [session_adapter.validate_json(v.session).id for v in values]
|
||||||
|
|
||||||
# should unique session ids
|
# should unique session ids
|
||||||
sids = [v.session_id for v in values]
|
sids = [v.session_id for v in values]
|
||||||
assert len(sids) == len(set(sids))
|
assert len(sids) == len(set(sids))
|
||||||
|
|
||||||
|
nfv_list_adapter = TypeAdapter(list[NodeFieldValue])
|
||||||
# should have 3 node field values
|
# should have 3 node field values
|
||||||
assert type(values[0].field_values) is str
|
assert type(values[0].field_values) is str
|
||||||
assert len(parse_raw_as(list[NodeFieldValue], values[0].field_values)) == 3
|
assert len(nfv_list_adapter.validate_json(values[0].field_values)) == 3
|
||||||
|
|
||||||
# should have batch id and priority
|
# should have batch id and priority
|
||||||
assert all(v.batch_id == b.batch_id for v in values)
|
assert all(v.batch_id == b.batch_id for v in values)
|
||||||
|
@ -15,7 +15,8 @@ class TestModel(BaseModel):
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def db() -> SqliteItemStorage[TestModel]:
|
def db() -> SqliteItemStorage[TestModel]:
|
||||||
sqlite_db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger())
|
sqlite_db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger())
|
||||||
return SqliteItemStorage[TestModel](db=sqlite_db, table_name="test", id_field="id")
|
sqlite_item_storage = SqliteItemStorage[TestModel](db=sqlite_db, table_name="test", id_field="id")
|
||||||
|
return sqlite_item_storage
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
|
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
|
||||||
|
Loading…
Reference in New Issue
Block a user