mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into feat/dev_reload
This commit is contained in:
@ -123,6 +123,7 @@ def custom_openapi():
|
||||
|
||||
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
|
||||
for schema_key, output_schema in output_schemas["definitions"].items():
|
||||
output_schema["class"] = "output"
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
|
||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||
@ -131,8 +132,8 @@ def custom_openapi():
|
||||
|
||||
# Add Node Editor UI helper schemas
|
||||
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
|
||||
for schema_key, output_schema in ui_config_schemas["definitions"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
for schema_key, ui_config_schema in ui_config_schemas["definitions"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
@ -141,8 +142,8 @@ def custom_openapi():
|
||||
output_type_title = output_type_titles[output_type.__name__]
|
||||
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
|
||||
invoker_schema["output"] = outputs_ref
|
||||
invoker_schema["class"] = "invocation"
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
|
||||
|
@ -143,6 +143,7 @@ class UIType(str, Enum):
|
||||
# region Misc
|
||||
FilePath = "FilePath"
|
||||
Enum = "enum"
|
||||
Scheduler = "Scheduler"
|
||||
# endregion
|
||||
|
||||
|
||||
@ -392,6 +393,13 @@ class BaseInvocationOutput(BaseModel):
|
||||
toprocess.extend(next_subclasses)
|
||||
return tuple(subclasses)
|
||||
|
||||
class Config:
|
||||
@staticmethod
|
||||
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type"])
|
||||
|
||||
|
||||
class RequiredConnectionException(Exception):
|
||||
"""Raised when an field which requires a connection did not receive a value."""
|
||||
@ -452,6 +460,9 @@ class BaseInvocation(ABC, BaseModel):
|
||||
schema["title"] = uiconfig.title
|
||||
if uiconfig and hasattr(uiconfig, "tags"):
|
||||
schema["tags"] = uiconfig.tags
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type", "id"])
|
||||
|
||||
@abstractmethod
|
||||
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
|
@ -115,12 +115,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection)
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: Union[float, List[float]] = InputField(
|
||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float
|
||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float, title="CFG Scale"
|
||||
)
|
||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler)
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
|
||||
)
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet")
|
||||
control: Union[ControlField, list[ControlField]] = InputField(
|
||||
default=None, description=FieldDescriptions.control, input=Input.Connection
|
||||
)
|
||||
@ -454,7 +456,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@title("Latents to Image")
|
||||
@tags("latents", "image", "vae")
|
||||
@tags("latents", "image", "vae", "l2i")
|
||||
class LatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
@ -642,7 +644,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@title("Image to Latents")
|
||||
@tags("latents", "image", "vae")
|
||||
@tags("latents", "image", "vae", "i2l")
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
|
@ -21,7 +21,7 @@ class AddInvocation(BaseInvocation):
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a + self.b)
|
||||
return IntegerOutput(value=self.a + self.b)
|
||||
|
||||
|
||||
@title("Subtract Integers")
|
||||
@ -36,7 +36,7 @@ class SubtractInvocation(BaseInvocation):
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a - self.b)
|
||||
return IntegerOutput(value=self.a - self.b)
|
||||
|
||||
|
||||
@title("Multiply Integers")
|
||||
@ -51,7 +51,7 @@ class MultiplyInvocation(BaseInvocation):
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a * self.b)
|
||||
return IntegerOutput(value=self.a * self.b)
|
||||
|
||||
|
||||
@title("Divide Integers")
|
||||
@ -66,7 +66,7 @@ class DivideInvocation(BaseInvocation):
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=int(self.a / self.b))
|
||||
return IntegerOutput(value=int(self.a / self.b))
|
||||
|
||||
|
||||
@title("Random Integer")
|
||||
@ -81,4 +81,4 @@ class RandomIntInvocation(BaseInvocation):
|
||||
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=np.random.randint(self.low, self.high))
|
||||
return IntegerOutput(value=np.random.randint(self.low, self.high))
|
||||
|
@ -72,7 +72,7 @@ class LoRAModelField(BaseModel):
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
@title("Main Model Loader")
|
||||
@title("Main Model")
|
||||
@tags("model")
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
@ -179,7 +179,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
|
||||
# fmt: on
|
||||
|
||||
|
||||
@title("LoRA Loader")
|
||||
@title("LoRA")
|
||||
@tags("lora", "model")
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
@ -257,7 +257,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
# fmt: on
|
||||
|
||||
|
||||
@title("SDXL LoRA Loader")
|
||||
@title("SDXL LoRA")
|
||||
@tags("sdxl", "lora", "model")
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
@ -356,7 +356,7 @@ class VaeLoaderOutput(BaseInvocationOutput):
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@title("VAE Loader")
|
||||
@title("VAE")
|
||||
@tags("vae", "model")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
@ -169,7 +169,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
ui_type=UIType.Float,
|
||||
)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
|
||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
|
||||
)
|
||||
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
||||
unet: UNetField = InputField(
|
||||
@ -406,7 +406,7 @@ class OnnxModelField(BaseModel):
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
|
||||
@title("ONNX Model Loader")
|
||||
@title("ONNX Main Model")
|
||||
@tags("onnx", "model")
|
||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
@ -2,8 +2,8 @@
|
||||
|
||||
from typing import Literal, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
@ -33,7 +33,7 @@ class BooleanOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single boolean"""
|
||||
|
||||
type: Literal["boolean_output"] = "boolean_output"
|
||||
a: bool = OutputField(description="The output boolean")
|
||||
value: bool = OutputField(description="The output boolean")
|
||||
|
||||
|
||||
class BooleanCollectionOutput(BaseInvocationOutput):
|
||||
@ -42,9 +42,7 @@ class BooleanCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["boolean_collection_output"] = "boolean_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[bool] = OutputField(
|
||||
default_factory=list, description="The output boolean collection", ui_type=UIType.BooleanCollection
|
||||
)
|
||||
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
|
||||
|
||||
|
||||
@title("Boolean Primitive")
|
||||
@ -55,10 +53,10 @@ class BooleanInvocation(BaseInvocation):
|
||||
type: Literal["boolean"] = "boolean"
|
||||
|
||||
# Inputs
|
||||
a: bool = InputField(default=False, description="The boolean value")
|
||||
value: bool = InputField(default=False, description="The boolean value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BooleanOutput:
|
||||
return BooleanOutput(a=self.a)
|
||||
return BooleanOutput(value=self.value)
|
||||
|
||||
|
||||
@title("Boolean Primitive Collection")
|
||||
@ -70,7 +68,7 @@ class BooleanCollectionInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
collection: list[bool] = InputField(
|
||||
default=False, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
||||
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
||||
@ -86,7 +84,7 @@ class IntegerOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single integer"""
|
||||
|
||||
type: Literal["integer_output"] = "integer_output"
|
||||
a: int = OutputField(description="The output integer")
|
||||
value: int = OutputField(description="The output integer")
|
||||
|
||||
|
||||
class IntegerCollectionOutput(BaseInvocationOutput):
|
||||
@ -95,9 +93,7 @@ class IntegerCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["integer_collection_output"] = "integer_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[int] = OutputField(
|
||||
default_factory=list, description="The int collection", ui_type=UIType.IntegerCollection
|
||||
)
|
||||
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
|
||||
|
||||
|
||||
@title("Integer Primitive")
|
||||
@ -108,10 +104,10 @@ class IntegerInvocation(BaseInvocation):
|
||||
type: Literal["integer"] = "integer"
|
||||
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description="The integer value")
|
||||
value: int = InputField(default=0, description="The integer value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a)
|
||||
return IntegerOutput(value=self.value)
|
||||
|
||||
|
||||
@title("Integer Primitive Collection")
|
||||
@ -139,7 +135,7 @@ class FloatOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single float"""
|
||||
|
||||
type: Literal["float_output"] = "float_output"
|
||||
a: float = OutputField(description="The output float")
|
||||
value: float = OutputField(description="The output float")
|
||||
|
||||
|
||||
class FloatCollectionOutput(BaseInvocationOutput):
|
||||
@ -148,9 +144,7 @@ class FloatCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["float_collection_output"] = "float_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[float] = OutputField(
|
||||
default_factory=list, description="The float collection", ui_type=UIType.FloatCollection
|
||||
)
|
||||
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
|
||||
|
||||
|
||||
@title("Float Primitive")
|
||||
@ -161,10 +155,10 @@ class FloatInvocation(BaseInvocation):
|
||||
type: Literal["float"] = "float"
|
||||
|
||||
# Inputs
|
||||
param: float = InputField(default=0.0, description="The float value")
|
||||
value: float = InputField(default=0.0, description="The float value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
return FloatOutput(a=self.param)
|
||||
return FloatOutput(value=self.value)
|
||||
|
||||
|
||||
@title("Float Primitive Collection")
|
||||
@ -176,7 +170,7 @@ class FloatCollectionInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
collection: list[float] = InputField(
|
||||
default=0, description="The collection of float values", ui_type=UIType.FloatCollection
|
||||
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
@ -192,7 +186,7 @@ class StringOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single string"""
|
||||
|
||||
type: Literal["string_output"] = "string_output"
|
||||
text: str = OutputField(description="The output string")
|
||||
value: str = OutputField(description="The output string")
|
||||
|
||||
|
||||
class StringCollectionOutput(BaseInvocationOutput):
|
||||
@ -201,9 +195,7 @@ class StringCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["string_collection_output"] = "string_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[str] = OutputField(
|
||||
default_factory=list, description="The output strings", ui_type=UIType.StringCollection
|
||||
)
|
||||
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
|
||||
|
||||
|
||||
@title("String Primitive")
|
||||
@ -214,10 +206,10 @@ class StringInvocation(BaseInvocation):
|
||||
type: Literal["string"] = "string"
|
||||
|
||||
# Inputs
|
||||
text: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
||||
value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
return StringOutput(text=self.text)
|
||||
return StringOutput(value=self.value)
|
||||
|
||||
|
||||
@title("String Primitive Collection")
|
||||
@ -229,7 +221,7 @@ class StringCollectionInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
collection: list[str] = InputField(
|
||||
default=0, description="The collection of string values", ui_type=UIType.StringCollection
|
||||
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
@ -262,9 +254,7 @@ class ImageCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["image_collection_output"] = "image_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[ImageField] = OutputField(
|
||||
default_factory=list, description="The output images", ui_type=UIType.ImageCollection
|
||||
)
|
||||
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
|
||||
|
||||
|
||||
@title("Image Primitive")
|
||||
@ -334,7 +324,6 @@ class LatentsCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["latents_collection_output"] = "latents_collection_output"
|
||||
|
||||
collection: list[LatentsField] = OutputField(
|
||||
default_factory=list,
|
||||
description=FieldDescriptions.latents,
|
||||
ui_type=UIType.LatentsCollection,
|
||||
)
|
||||
@ -365,7 +354,7 @@ class LatentsCollectionInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
collection: list[LatentsField] = InputField(
|
||||
default=0, description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
||||
description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
||||
@ -410,9 +399,7 @@ class ColorCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["color_collection_output"] = "color_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[ColorField] = OutputField(
|
||||
default_factory=list, description="The output colors", ui_type=UIType.ColorCollection
|
||||
)
|
||||
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
|
||||
|
||||
|
||||
@title("Color Primitive")
|
||||
@ -455,7 +442,6 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
# Outputs
|
||||
collection: list[ConditioningField] = OutputField(
|
||||
default_factory=list,
|
||||
description="The output conditioning tensors",
|
||||
ui_type=UIType.ConditioningCollection,
|
||||
)
|
||||
|
@ -37,7 +37,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@title("SDXL Main Model Loader")
|
||||
@title("SDXL Main Model")
|
||||
@tags("model", "sdxl")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
@ -122,7 +122,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("SDXL Refiner Model Loader")
|
||||
@title("SDXL Refiner Model")
|
||||
@tags("model", "sdxl", "refiner")
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
@ -17,9 +17,9 @@ def create_text_to_image() -> LibraryGraph:
|
||||
description="Converts text to an image",
|
||||
graph=Graph(
|
||||
nodes={
|
||||
"width": IntegerInvocation(id="width", a=512),
|
||||
"height": IntegerInvocation(id="height", a=512),
|
||||
"seed": IntegerInvocation(id="seed", a=-1),
|
||||
"width": IntegerInvocation(id="width", value=512),
|
||||
"height": IntegerInvocation(id="height", value=512),
|
||||
"seed": IntegerInvocation(id="seed", value=-1),
|
||||
"3": NoiseInvocation(id="3"),
|
||||
"4": CompelInvocation(id="4"),
|
||||
"5": CompelInvocation(id="5"),
|
||||
@ -29,15 +29,15 @@ def create_text_to_image() -> LibraryGraph:
|
||||
},
|
||||
edges=[
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="width", field="a"),
|
||||
source=EdgeConnection(node_id="width", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="width"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="height", field="a"),
|
||||
source=EdgeConnection(node_id="height", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="height"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="seed", field="a"),
|
||||
source=EdgeConnection(node_id="seed", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="seed"),
|
||||
),
|
||||
Edge(
|
||||
@ -65,9 +65,9 @@ def create_text_to_image() -> LibraryGraph:
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
|
||||
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
|
||||
ExposedNodeInput(node_path="width", field="a", alias="width"),
|
||||
ExposedNodeInput(node_path="height", field="a", alias="height"),
|
||||
ExposedNodeInput(node_path="seed", field="a", alias="seed"),
|
||||
ExposedNodeInput(node_path="width", field="value", alias="width"),
|
||||
ExposedNodeInput(node_path="height", field="value", alias="height"),
|
||||
ExposedNodeInput(node_path="seed", field="value", alias="seed"),
|
||||
],
|
||||
exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
|
||||
)
|
||||
|
@ -49,9 +49,36 @@ from invokeai.backend.model_management.model_cache import CacheStats
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeStats:
|
||||
"""Class for tracking execution stats of an invocation node"""
|
||||
|
||||
calls: int = 0
|
||||
time_used: float = 0.0 # seconds
|
||||
max_vram: float = 0.0 # GB
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
cache_high_watermark: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeLog:
|
||||
"""Class for tracking node usage"""
|
||||
|
||||
# {node_type => NodeStats}
|
||||
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||
|
||||
|
||||
class InvocationStatsServiceBase(ABC):
|
||||
"Abstract base class for recording node memory/time performance statistics"
|
||||
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
# {graph_id => NodeLog}
|
||||
_stats: Dict[str, NodeLog]
|
||||
_cache_stats: Dict[str, CacheStats]
|
||||
ram_used: float
|
||||
ram_changed: float
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||
"""
|
||||
@ -94,8 +121,6 @@ class InvocationStatsServiceBase(ABC):
|
||||
invocation_type: str,
|
||||
time_used: float,
|
||||
vram_used: float,
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Add timing information on execution of a node. Usually
|
||||
@ -104,8 +129,6 @@ class InvocationStatsServiceBase(ABC):
|
||||
:param invocation_type: String literal type of the node
|
||||
:param time_used: Time used by node's exection (sec)
|
||||
:param vram_used: Maximum VRAM used during exection (GB)
|
||||
:param ram_used: Current RAM available (GB)
|
||||
:param ram_changed: Change in RAM usage over course of the run (GB)
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -116,25 +139,19 @@ class InvocationStatsServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_mem_stats(
|
||||
self,
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Update the collector with RAM memory usage info.
|
||||
|
||||
@dataclass
|
||||
class NodeStats:
|
||||
"""Class for tracking execution stats of an invocation node"""
|
||||
|
||||
calls: int = 0
|
||||
time_used: float = 0.0 # seconds
|
||||
max_vram: float = 0.0 # GB
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
cache_high_watermark: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeLog:
|
||||
"""Class for tracking node usage"""
|
||||
|
||||
# {node_type => NodeStats}
|
||||
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||
:param ram_used: How much RAM is currently in use.
|
||||
:param ram_changed: How much RAM changed since last generation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class InvocationStatsService(InvocationStatsServiceBase):
|
||||
@ -152,12 +169,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
class StatsContext:
|
||||
"""Context manager for collecting statistics."""
|
||||
|
||||
invocation: BaseInvocation = None
|
||||
collector: "InvocationStatsServiceBase" = None
|
||||
graph_id: str = None
|
||||
start_time: int = 0
|
||||
ram_used: int = 0
|
||||
model_manager: ModelManagerService = None
|
||||
invocation: BaseInvocation
|
||||
collector: "InvocationStatsServiceBase"
|
||||
graph_id: str
|
||||
start_time: float
|
||||
ram_used: int
|
||||
model_manager: ModelManagerService
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -170,7 +187,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self.invocation = invocation
|
||||
self.collector = collector
|
||||
self.graph_id = graph_id
|
||||
self.start_time = 0
|
||||
self.start_time = 0.0
|
||||
self.ram_used = 0
|
||||
self.model_manager = model_manager
|
||||
|
||||
@ -191,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
)
|
||||
self.collector.update_invocation_stats(
|
||||
graph_id=self.graph_id,
|
||||
invocation_type=self.invocation.type,
|
||||
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
||||
time_used=time.time() - self.start_time,
|
||||
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
||||
)
|
||||
@ -202,11 +219,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
graph_execution_state_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
) -> StatsContext:
|
||||
"""
|
||||
Return a context object that will capture the statistics.
|
||||
:param invocation: BaseInvocation object from the current graph.
|
||||
:param graph_execution_state: GraphExecutionState object from the current session.
|
||||
"""
|
||||
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||
self._stats[graph_execution_state_id] = NodeLog()
|
||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||
@ -217,7 +229,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self._stats = {}
|
||||
|
||||
def reset_stats(self, graph_execution_id: str):
|
||||
"""Zero the statistics for the indicated graph."""
|
||||
try:
|
||||
self._stats.pop(graph_execution_id)
|
||||
except KeyError:
|
||||
@ -228,12 +239,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Update the collector with RAM memory usage info.
|
||||
|
||||
:param ram_used: How much RAM is currently in use.
|
||||
:param ram_changed: How much RAM changed since last generation.
|
||||
"""
|
||||
self.ram_used = ram_used
|
||||
self.ram_changed = ram_changed
|
||||
|
||||
@ -244,16 +249,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
time_used: float,
|
||||
vram_used: float,
|
||||
):
|
||||
"""
|
||||
Add timing information on execution of a node. Usually
|
||||
used internally.
|
||||
:param graph_id: ID of the graph that is currently executing
|
||||
:param invocation_type: String literal type of the node
|
||||
:param time_used: Time used by node's exection (sec)
|
||||
:param vram_used: Maximum VRAM used during exection (GB)
|
||||
:param ram_used: Current RAM available (GB)
|
||||
:param ram_changed: Change in RAM usage over course of the run (GB)
|
||||
"""
|
||||
if not self._stats[graph_id].nodes.get(invocation_type):
|
||||
self._stats[graph_id].nodes[invocation_type] = NodeStats()
|
||||
stats = self._stats[graph_id].nodes[invocation_type]
|
||||
@ -262,14 +257,15 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
stats.max_vram = max(stats.max_vram, vram_used)
|
||||
|
||||
def log_stats(self):
|
||||
"""
|
||||
Send the statistics to the system logger at the info level.
|
||||
Stats will only be printed when the execution of the graph
|
||||
is complete.
|
||||
"""
|
||||
completed = set()
|
||||
errored = set()
|
||||
for graph_id, node_log in self._stats.items():
|
||||
current_graph_state = self.graph_execution_manager.get(graph_id)
|
||||
try:
|
||||
current_graph_state = self.graph_execution_manager.get(graph_id)
|
||||
except Exception:
|
||||
errored.add(graph_id)
|
||||
continue
|
||||
|
||||
if not current_graph_state.is_complete():
|
||||
continue
|
||||
|
||||
@ -302,3 +298,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
for graph_id in completed:
|
||||
del self._stats[graph_id]
|
||||
del self._cache_stats[graph_id]
|
||||
|
||||
for graph_id in errored:
|
||||
del self._stats[graph_id]
|
||||
del self._cache_stats[graph_id]
|
||||
|
Reference in New Issue
Block a user