tidy(mm): add default_settings to model config

This commit is contained in:
psychedelicious 2024-03-05 11:24:25 +11:00
parent c953e61294
commit 37b969d339
3 changed files with 27 additions and 10 deletions

View File

@ -11,12 +11,14 @@ from typing import Any, Dict, List, Optional, Set, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.config import ModelDefaultSettings
class DuplicateModelException(Exception): class DuplicateModelException(Exception):
@ -57,6 +59,18 @@ class ModelSummary(BaseModel):
tags: Set[str] = Field(description="tags associated with model") tags: Set[str] = Field(description="tags associated with model")
class ModelRecordChanges(BaseModelExcludeNull, extra="allow"):
"""A set of changes to apply to model metadata.
Only limited changes are valid:
- `default_settings`: the user-configured default settings for this model
"""
default_settings: Optional[ModelDefaultSettings] = Field(
default=None, description="The user-configured default settings for this model"
)
"""The user-configured default settings for this model"""
class ModelRecordServiceBase(ABC): class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs.""" """Abstract base class for storage and retrieval of model configs."""

View File

@ -29,6 +29,7 @@ from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
from ..raw_model import RawModel from ..raw_model import RawModel
@ -131,6 +132,15 @@ class ModelSourceType(str, Enum):
CivitAI = "civitai" CivitAI = "civitai"
class ModelDefaultSettings(BaseModel):
vae: str | None
vae_precision: str | None
scheduler: SCHEDULER_NAME_VALUES | None
steps: int | None
cfg_scale: float | None
cfg_rescale_multiplier: float | None
class ModelConfigBase(BaseModel): class ModelConfigBase(BaseModel):
"""Base class for model configuration information.""" """Base class for model configuration information."""
@ -148,6 +158,9 @@ class ModelConfigBase(BaseModel):
description="The original API response from the source, as stringified JSON.", default=None description="The original API response from the source, as stringified JSON.", default=None
) )
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
model_config = ConfigDict(use_enum_values=False, validate_assignment=True) model_config = ConfigDict(use_enum_values=False, validate_assignment=True)

View File

@ -23,7 +23,6 @@ from pydantic.networks import AnyHttpUrl
from requests.sessions import Session from requests.sessions import Session
from typing_extensions import Annotated from typing_extensions import Annotated
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.backend.model_manager import ModelRepoVariant from invokeai.backend.model_manager import ModelRepoVariant
from ..util import select_hf_files from ..util import select_hf_files
@ -42,15 +41,6 @@ class RemoteModelFile(BaseModel):
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None) sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
class ModelDefaultSettings(BaseModel):
vae: str | None
vae_precision: str | None
scheduler: SCHEDULER_NAME_VALUES | None
steps: int | None
cfg_scale: float | None
cfg_rescale_multiplier: float | None
class ModelMetadataBase(BaseModel): class ModelMetadataBase(BaseModel):
"""Base class for model metadata information.""" """Base class for model metadata information."""