Allow lists of basemodel objects in omegaconf

This commit is contained in:
Brandon Rising 2024-03-11 14:20:57 -04:00 committed by Mary Hipp Rogers
parent 96730107d1
commit 97afa6e2a6

View File

@ -15,6 +15,7 @@ import os
import sys import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from pydantic import BaseModel
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
from omegaconf import DictConfig, ListConfig, OmegaConf from omegaconf import DictConfig, ListConfig, OmegaConf
@ -62,6 +63,18 @@ class InvokeAISettings(BaseSettings):
assert isinstance(category, str) assert isinstance(category, str)
if category not in field_dict[type]: if category not in field_dict[type]:
field_dict[type][category] = {} field_dict[type][category] = {}
if isinstance(value, BaseModel):
dump = value.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True)
field_dict[type][category][name] = dump
continue
if isinstance(value, list):
val_list: List[Dict[str, Any]] = []
for list_val in value:
if isinstance(list_val, BaseModel):
dump = list_val.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True)
val_list.append(dump)
field_dict[type][category][name] = val_list
continue
# keep paths as strings to make it easier to read # keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
conf = OmegaConf.create(field_dict) conf = OmegaConf.create(field_dict)