Allow lists of basemodel objects in omegaconf

This commit is contained in:
Brandon Rising 2024-03-11 14:20:57 -04:00 committed by Mary Hipp Rogers
parent 96730107d1
commit 97afa6e2a6

View File

@ -15,6 +15,7 @@ import os
import sys
from argparse import ArgumentParser
from pathlib import Path
from pydantic import BaseModel
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
from omegaconf import DictConfig, ListConfig, OmegaConf
@ -62,6 +63,18 @@ class InvokeAISettings(BaseSettings):
assert isinstance(category, str)
if category not in field_dict[type]:
field_dict[type][category] = {}
if isinstance(value, BaseModel):
dump = value.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True)
field_dict[type][category][name] = dump
continue
if isinstance(value, list):
val_list: List[Dict[str, Any]] = []
for list_val in value:
if isinstance(list_val, BaseModel):
dump = list_val.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True)
val_list.append(dump)
field_dict[type][category][name] = val_list
continue
# keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
conf = OmegaConf.create(field_dict)