mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
c238a7f18b
Upgrade pydantic and fastapi to latest. - pydantic~=2.4.2 - fastapi~=0.103.2 - fastapi-events~=0.9.1 **Big Changes** There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes. **Invocations** The biggest change relates to invocation creation, instantiation and validation. Because pydantic v2 moves all validation logic into the Rust pydantic-core, we may no longer directly stick our fingers into the validation pie. Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`. With pydantic v2, this is much more involved. Changes to the Python wrapper do not propagate down to the Rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation. This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method. In the end, this implementation is cleaner. **Invocation Fields** In pydantic v2, you can no longer directly add or remove fields from a model. Previously, we did this to add the `type` field to invocations. **Invocation Decorators** With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper. A similar technique is used for `invocation_output()`. **Minor Changes** There are a number of minor changes around the pydantic v2 models API. **Protected `model_` Namespace** All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace.
This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_". Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple. ```py class IPAdapterModelField(BaseModel): model_name: str = Field(description="Name of the IP-Adapter model") base_model: BaseModelType = Field(description="Base model") model_config = ConfigDict(protected_namespaces=()) ``` **Model Serialization** Pydantic models no longer have `Model.dict()` or `Model.json()`. Instead, we use `Model.model_dump()` or `Model.model_dump_json()`. **Model Deserialization** Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions. Instead, you need to create a `TypeAdapter` object to parse Python objects or JSON into a model. ```py adapter_graph = TypeAdapter(Graph) deserialized_graph_from_json = adapter_graph.validate_json(graph_json) deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict) ``` **Field Customisation** Pydantic `Field`s no longer accept arbitrary args. Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field. **Schema Customisation** FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec. This necessitates two changes: - Our schema customization logic has been revised - Schema parsing to build node templates has been revised The specifics aren't important, but this does present additional surface area for bugs. **Performance Improvements** Pydantic v2 is a full rewrite with a Rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task).
We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very, very often - a couple of times per node. I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like those with massive iterators - are much, much faster.
219 lines
7.8 KiB
Python
219 lines
7.8 KiB
Python
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
|
|
|
"""
|
|
Base class for the InvokeAI configuration system.
|
|
It defines a type of pydantic BaseSettings object that
|
|
is able to read and write from an omegaconf-based config file,
|
|
with overriding of settings from environment variables and/or
|
|
the command line.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
from argparse import ArgumentParser
|
|
from pathlib import Path
|
|
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
|
|
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
|
|
|
|
|
|
class InvokeAISettings(BaseSettings):
    """
    Runtime configuration settings in which default values are
    read from an omegaconf .yaml file.

    When ``parse_args()`` is used, each field's value is resolved from these
    sources, lowest priority first: the pydantic field default, the matching
    entry in ``initconf``, an environment variable named ``<PREFIX>_<field>``
    (matched case-insensitively), and finally a command-line option.
    """

    # Parsed omegaconf config shared by the whole class. add_parser_arguments()
    # reads the stanza named by cmd_name() from it to seed argparse defaults.
    # (Presumably assigned by the application before parsing -- not shown here.)
    initconf: ClassVar[Optional[DictConfig]] = None
    # Maps a field "category" (from json_schema_extra) to its argparse argument
    # group; reset on every call to add_parser_arguments().
    argparse_groups: ClassVar[Dict] = {}

    model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)

    def parse_args(self, argv: Optional[list] = None):
        """
        Parse command-line options and assign them onto this settings object.

        :param argv: Arguments to parse. When None (the default), ``sys.argv[1:]``
            is read at call time. The previous default was bound once at import
            time, so later changes to ``sys.argv`` were silently ignored.
        """
        if argv is None:
            argv = sys.argv[1:]
        parser = self.get_parser()
        opt, unknown_opts = parser.parse_known_args(argv)
        if unknown_opts:
            print("Unknown args:", unknown_opts)
        excluded = self._excluded()
        for name in self.model_fields:
            if name in excluded:
                continue
            value = getattr(opt, name)
            # omegaconf containers (from initconf defaults) are converted to plain
            # Python containers before being assigned to the pydantic model.
            if isinstance(value, ListConfig):
                value = list(value)
            elif isinstance(value, DictConfig):
                value = dict(value)
            setattr(self, name, value)

    def to_yaml(self) -> str:
        """
        Return a YAML string representing our settings. This can be used
        as the contents of `invokeai.yaml` to restore settings later.

        The output is nested as ``<stanza>: <category>: <field>: value``. The
        stanza comes from ``cmd_name()``, and fields listed by
        ``_excluded_from_yaml()`` are omitted.
        """
        cls = self.__class__
        excluded = cls._excluded_from_yaml()
        categories: Dict[str, dict] = {}
        for name, field in self.model_fields.items():
            if name in excluded:
                continue
            category = cls._field_category(field, "Uncategorized")
            value = getattr(self, name)
            # keep paths as strings to make it easier to read
            categories.setdefault(category, {})[name] = str(value) if isinstance(value, Path) else value
        conf = OmegaConf.create({cls.cmd_name(): categories})
        return OmegaConf.to_yaml(conf)

    @classmethod
    def add_parser_arguments(cls, parser):
        """
        Register one ``--<field>`` option on *parser* for every non-excluded field.

        Each option's argparse default is, in increasing priority: the field's
        own default, ``initconf[<stanza>][<category>][<field>]``, and then the
        environment variable ``<PREFIX>_<field>`` (matched case-insensitively).
        """
        settings_stanza = cls.cmd_name()

        # model_config is a dict (SettingsConfigDict), so the prefix must be read
        # with .get(); getattr() on a dict always yields None. pydantic-settings
        # defaults env_prefix to "", so fall back to the stanza when it is empty.
        env_prefix = cls.model_config.get("env_prefix") or settings_stanza.upper()

        initconf = (
            cls.initconf.get(settings_stanza)
            if cls.initconf and settings_stanza in cls.initconf
            else OmegaConf.create()
        )

        # create an upcase version of the environment in order to achieve
        # case-insensitive environment variables (the way Windows does)
        upcase_environ = {key.upper(): value for key, value in os.environ.items()}

        cls.argparse_groups = {}
        excluded = cls._excluded()

        for name, field in cls.model_fields.items():
            if name in excluded:
                continue
            current_default = field.default
            category = cls._field_category(field, "Uncategorized")
            env_name = f"{env_prefix}_{name}".upper()
            if category in initconf and name in initconf.get(category):
                field.default = initconf.get(category).get(name)
            if env_name in upcase_environ:
                field.default = upcase_environ[env_name]
            try:
                cls.add_field_argument(parser, name, field)
            finally:
                # `field` is the class-wide FieldInfo; always undo the temporary
                # default override, even if registering the argument fails.
                field.default = current_default

    @classmethod
    def cmd_name(cls, command_field: str = "type") -> str:
        """Return the first argument of the ``Literal`` hint on *command_field*, else "Uncategorized"."""
        hints = get_type_hints(cls)
        if command_field in hints:
            return get_args(hints[command_field])[0]
        else:
            return "Uncategorized"

    @classmethod
    def get_parser(cls) -> ArgumentParser:
        """Build a paging argument parser populated with one option per settings field."""
        parser = PagingArgumentParser(
            prog=cls.cmd_name(),
            description=cls.__doc__,
        )
        cls.add_parser_arguments(parser)
        return parser

    @classmethod
    def _excluded(cls) -> List[str]:
        # internal fields that shouldn't be exposed as command line options
        return ["type", "initconf"]

    @classmethod
    def _excluded_from_yaml(cls) -> List[str]:
        # combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
        return [
            "type",
            "initconf",
            "version",
            "from_file",
            "model",
            "root",
            "max_cache_size",
            "max_vram_cache_size",
            "always_use_cpu",
            "free_gpu_mem",
            "xformers_enabled",
            "tiled_decode",
        ]

    @staticmethod
    def _field_category(field, default: Optional[str]) -> Optional[str]:
        """Return the ``category`` entry of *field*'s json_schema_extra, or *default*."""
        extra = field.json_schema_extra
        # pydantic v2 also permits a callable here; only a dict can carry a category.
        if isinstance(extra, dict) and extra:
            return extra.get("category", default)
        return default

    @classmethod
    def add_field_argument(cls, command_parser, name: str, field, default_override=None):
        """
        Add a ``--<name>`` option describing *field* to *command_parser*.

        The option is placed in an argument group named after the field's
        ``category`` (from json_schema_extra) when one is set.

        :param command_parser: the argparse parser receiving the option.
        :param name: the field name; also used as the option's ``dest``.
        :param field: the pydantic FieldInfo for *name*.
        :param default_override: used as the argparse default instead of the
            field's own default (or default_factory) when not None.
        """
        import types  # only needed to recognise PEP 604 ``X | Y`` unions

        field_type = get_type_hints(cls).get(name)
        if default_override is not None:
            default = default_override
        elif field.default_factory is None:
            default = field.default
        else:
            default = field.default_factory()

        category = cls._field_category(field, None)
        if category:
            if category not in cls.argparse_groups:
                cls.argparse_groups[category] = command_parser.add_argument_group(category)
            argparse_group = cls.argparse_groups[category]
        else:
            argparse_group = command_parser

        origin = get_origin(field_type)
        # Optional[X] / Union[X, Y] report typing.Union; ``X | Y`` reports types.UnionType.
        union_origins = (Union, getattr(types, "UnionType", Union))

        if origin is Literal:
            allowed_values = get_args(field.annotation)
            allowed_types = {type(val) for val in allowed_values}
            # a single value type converts directly; mixed types need the permissive converter
            literal_type = next(iter(allowed_types)) if len(allowed_types) == 1 else int_or_float_or_str
            argparse_group.add_argument(
                f"--{name}",
                dest=name,
                type=literal_type,
                default=default,
                choices=allowed_values,
                help=field.description,
            )

        elif origin in union_origins:
            argparse_group.add_argument(
                f"--{name}",
                dest=name,
                type=int_or_float_or_str,
                default=default,
                help=field.description,
            )

        elif origin is list:
            # Passing typing.List[X] itself as the argparse ``type`` raises TypeError
            # on every value, so convert each item with the element type X instead.
            item_args = get_args(field_type)
            item_type = item_args[0] if item_args and isinstance(item_args[0], type) else str
            argparse_group.add_argument(
                f"--{name}",
                dest=name,
                nargs="*",
                type=item_type,
                default=default,
                action="store",
                help=field.description,
            )

        else:
            argparse_group.add_argument(
                f"--{name}",
                dest=name,
                type=field.annotation,
                default=default,
                action=argparse.BooleanOptionalAction if field.annotation == bool else "store",
                help=field.description,
            )
|