revise config but need to migrate old format to new

This commit is contained in:
Lincoln Stein 2023-08-16 23:30:00 -04:00
parent e373bfca54
commit 503e3bca54

View File

@ -167,7 +167,7 @@ from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig, ListConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
from typing import Any, ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
@ -321,7 +321,7 @@ class InvokeAISettings(BaseSettings):
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
field_type = allowed_types_list[0] if len(allowed_types) == 1 else int_or_float_or_str
argparse_group.add_argument(
f"--{name}",
@ -331,6 +331,15 @@ class InvokeAISettings(BaseSettings):
choices=allowed_values,
help=field.field_info.description,
)
elif get_origin(field_type) == Union:
argparse_group.add_argument(
f"--{name}",
dest=name,
type=int_or_float_or_str,
default=default,
help=field.field_info.description,
)
elif get_origin(field_type) == list:
argparse_group.add_argument(
@ -390,14 +399,15 @@ class InvokeAIAppConfig(InvokeAISettings):
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
ram : Union[float,Literal['auto']] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category='Cache')
vram : Union[float,Literal['auto']] = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category='Cache')
lazy_offload : bool = Field(default=True, description='Keep models in VRAM until their space is needed', category='Cache')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Device')
device : Literal[tuple(['cpu','cuda','mps','cuda','cuda:1','auto'])] = Field(default='auto',description='Generation device', category='Device')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Generation')
attention_type : Literal[tuple(['auto','normal','xformers','sliced','torch-sdp'])] = Field(default='auto', description='Attention type', category='Generation')
attention_slice_size: Literal[tuple(['auto','max',1,2,3,4,5,6,7,8])] = Field(default='auto', description='Slice size, valid when attention_type=="sliced"', category='Generation')
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Generation')
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
@ -541,6 +551,11 @@ class InvokeAIAppConfig(InvokeAISettings):
"""Return true if precision set to float32"""
return self.precision == "float32"
@property
def xformers_enabled(self) -> bool:
"""Return true if attention_type=='xformers'."""
return self.attention_type=='xformers'
@property
def disable_xformers(self) -> bool:
"""Return true if xformers_enabled is false"""
@ -561,6 +576,16 @@ class InvokeAIAppConfig(InvokeAISettings):
"""invisible watermark node is always active and disabled from Web UIe"""
return True
@property
def max_cache_size(self) -> Union[str, float]:
"""return value of ram attribute."""
return self.ram
@property
def max_vram_cache_size(self) -> Union[str, float]:
"""return value of vram attribute."""
return self.vram
@staticmethod
def find_root() -> Path:
"""
@ -569,6 +594,18 @@ class InvokeAIAppConfig(InvokeAISettings):
"""
return _find_root()
# @property
# def attention_slice_size(self) -> Union[str, int]:
# """
# Return one of "auto", "max", or 1-8.
# """
# size = self.attention_slice
# try:
# size= int(size)
# assert size > 0
# except ValueError:
# pass
# return size
class PagingArgumentParser(argparse.ArgumentParser):
"""
@ -586,3 +623,18 @@ def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
Legacy function which returns InvokeAIAppConfig.get_config()
"""
return InvokeAIAppConfig.get_config(**kwargs)
def int_or_float_or_str(value:Any) -> Union[int, float, str]:
"""
Workaround for argparse type checking.
"""
try:
return int(value)
except:
pass
try:
return float(value)
except:
pass
return str(value)