diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 1c1ed45705..1961a79c47 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -29,12 +29,9 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] -# class ModelsList(BaseModel): -# models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] - class ModelsList(BaseModel): - models: List[ModelConfigBase] + models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] @models_router.get( @@ -51,9 +48,9 @@ async def list_models( if base_models and len(base_models) > 0: models_raw = list() for base_model in base_models: - models_raw.extend(manager.list_models(base_model=base_model, model_type=model_type)) + models_raw.extend([x.dict() for x in manager.list_models(base_model=base_model, model_type=model_type)]) else: - models_raw = manager.list_models(model_type=model_type) + models_raw = [x.dict() for x in manager.list_models(model_type=model_type)] models = parse_obj_as(ModelsList, {"models": models_raw}) return models diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 378e87306d..a148c2b93a 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -10,7 +10,7 @@ Typical usage: base_model='sd-1', model_type='main', config='configs/stable-diffusion/v1-inference.yaml', - model_variant='normal', + variant='normal', model_format='checkpoint' ) config = ModelConfigFactory.make_config(raw) @@ -173,12 +173,18 @@ class VaeDiffusersConfig(ModelConfigBase): model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers -class ControlNetDiffusersConfig(ModelConfigBase): +class ControlNetDiffusersConfig(DiffusersConfig): """Model config for ControlNet models (diffusers version).""" model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers +class ControlNetCheckpointConfig(CheckpointConfig): + """Model config for ControlNet models (diffusers version).""" + + model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + + class TextualInversionConfig(ModelConfigBase): """Model config for textual inversion embeddings.""" @@ -189,12 +195,14 @@ class MainConfig(ModelConfigBase): """Model config for main models.""" vae: Optional[str] = Field(None) - model_variant: ModelVariantType = ModelVariantType.Normal + variant: ModelVariantType = ModelVariantType.Normal class MainCheckpointConfig(CheckpointConfig, MainConfig): """Model config for main checkpoint models.""" + config: str + class MainDiffusersConfig(DiffusersConfig, MainConfig): """Model config for main diffusers models.""" diff --git a/invokeai/backend/model_manager/download/__init__.py b/invokeai/backend/model_manager/download/__init__.py index 59bd617102..85f60b8d1d 100644 --- a/invokeai/backend/model_manager/download/__init__.py +++ b/invokeai/backend/model_manager/download/__init__.py @@ -7,6 +7,8 @@ from .base import ( # noqa F401 UnknownJobIDException, DownloadJobBase, ModelSourceMetadata, + REPO_ID_RE, + HTTP_RE, ) from .queue import DownloadQueue # noqa F401 diff --git a/invokeai/backend/model_manager/download/base.py b/invokeai/backend/model_manager/download/base.py index 1d864c289b..167053d441 100644 --- a/invokeai/backend/model_manager/download/base.py +++ b/invokeai/backend/model_manager/download/base.py @@ -11,6 +11,10 @@ from typing import List, Optional, Callable, Union from pydantic import BaseModel, Field from pydantic.networks import AnyHttpUrl +# Used to distinguish between repo_id sources and URL sources +REPO_ID_RE = r"^[\w-]+/[.\w-]+$" +HTTP_RE = r"^https?://" + class DownloadJobStatus(str, Enum): """State of a download job.""" @@ -49,7 +53,9 @@ class DownloadJobBase(BaseModel): id: int = Field(description="Numeric ID of this job", default=-1) # default id is a placeholder source: str = Field(description="URL or repo_id to download") destination: Path = Field(description="Destination of URL on local disk") - metadata: Optional[ModelSourceMetadata] = Field(description="Model metadata (source-specific)", default=None) + metadata: ModelSourceMetadata = Field( + description="Model metadata (source-specific)", default_factory=ModelSourceMetadata + ) access_token: Optional[str] = Field(description="access token needed to access this resource") status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download") bytes: int = Field(default=0, description="Bytes downloaded so far") diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py index b89f9d5291..d54ddb38f3 100644 --- a/invokeai/backend/model_manager/download/queue.py +++ b/invokeai/backend/model_manager/download/queue.py @@ -22,6 +22,7 @@ from huggingface_hub import HfApi, hf_hub_url from invokeai.backend.util.logging import InvokeAILogger from invokeai.app.services.config import InvokeAIAppConfig +from . import REPO_ID_RE, HTTP_RE from .base import ( DownloadQueueBase, DownloadJobStatus, @@ -59,7 +60,7 @@ class DownloadJobRepoID(DownloadJobBase): @validator("source") @classmethod def _validate_source(cls, v: str) -> str: - if not re.match(r"^[\w-]+/[\w-]+$", v): + if not re.match(REPO_ID_RE, v): raise ValidationError(f"{v} invalid repo_id") return v @@ -123,10 +124,10 @@ class DownloadQueue(DownloadQueueBase): if Path(source).exists(): cls = DownloadJobPath - elif re.match(r"^[\w-]+/[\w-]+$", str(source)): + elif re.match(REPO_ID_RE, str(source)): cls = DownloadJobRepoID kwargs = dict(variant=variant) - elif re.match(r"^https?://", str(source)): + elif re.match(HTTP_RE, str(source)): cls = DownloadJobURL else: raise NotImplementedError(f"Don't know what to do with this type of source: {source}") diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py index 1c19f65b0b..4c9f27f0cf 100644 --- a/invokeai/backend/model_manager/install.py +++ b/invokeai/backend/model_manager/install.py @@ -66,6 +66,8 @@ from .download import ( DownloadJobBase, ModelSourceMetadata, DownloadEventHandler, + REPO_ID_RE, + HTTP_RE, ) from .download.queue import DownloadJobURL, DownloadJobRepoID, DownloadJobPath from .hash import FastModelHash @@ -481,10 +483,10 @@ class ModelInstall(ModelInstallBase): models_dir = self._config.models_path self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir) - if re.match(r"^[\w-]+/[\w-]+$", str(source)): + if re.match(REPO_ID_RE, str(source)): cls = ModelInstallRepoIDJob kwargs = dict(variant=variant) - elif re.match(r"^https?://", str(source)): + elif re.match(HTTP_RE, str(source)): cls = ModelInstallURLJob kwargs = {} else: diff --git a/invokeai/backend/model_manager/models/controlnet.py b/invokeai/backend/model_manager/models/controlnet.py index 359df91a82..426c3ec712 100644 --- a/invokeai/backend/model_manager/models/controlnet.py +++ b/invokeai/backend/model_manager/models/controlnet.py @@ -7,7 +7,7 @@ import torch import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig - +from ..config import ControlNetDiffusersConfig, ControlNetCheckpointConfig from .base import ( BaseModelType, EmptyConfigLoader, @@ -32,12 +32,11 @@ class ControlNetModel(ModelBase): # model_class: Type # model_size: int - class DiffusersConfig(ModelConfigBase): - model_format: Literal[ControlNetModelFormat.Diffusers] + class DiffusersConfig(ControlNetDiffusersConfig): + model_format: Literal[ControlNetModelFormat.Diffusers] = ControlNetModelFormat.Diffusers - class CheckpointConfig(ModelConfigBase): - model_format: Literal[ControlNetModelFormat.Checkpoint] - config: str + class CheckpointConfig(ControlNetCheckpointConfig): + model_format: Literal[ControlNetModelFormat.Checkpoint] = ControlNetModelFormat.Checkpoint def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.ControlNet diff --git a/invokeai/backend/model_manager/models/lora.py b/invokeai/backend/model_manager/models/lora.py index b6f321d60b..a9bd95645a 100644 --- a/invokeai/backend/model_manager/models/lora.py +++ b/invokeai/backend/model_manager/models/lora.py @@ -2,11 +2,11 @@ import bisect import os from enum import Enum from pathlib import Path -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, Literal import torch from safetensors.torch import load_file - +from ..config import LoRAConfig from .base import ( BaseModelType, InvalidModelException, @@ -27,8 +27,8 @@ class LoRAModelFormat(str, Enum): class LoRAModel(ModelBase): # model_size: int - class Config(ModelConfigBase): - model_format: LoRAModelFormat # TODO: + class Config(LoRAConfig): + model_format: Literal[LoRAModelFormat.LyCORIS] # TODO: def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.Lora diff --git a/invokeai/backend/model_manager/models/sdxl.py b/invokeai/backend/model_manager/models/sdxl.py index 41586e35b9..b054972f62 100644 --- a/invokeai/backend/model_manager/models/sdxl.py +++ b/invokeai/backend/model_manager/models/sdxl.py @@ -5,7 +5,7 @@ from typing import Literal, Optional from omegaconf import OmegaConf from pydantic import Field - +from ..config import MainDiffusersConfig, MainCheckpointConfig from .base import ( BaseModelType, DiffusersModel, @@ -25,16 +25,11 @@ class StableDiffusionXLModelFormat(str, Enum): class StableDiffusionXLModel(DiffusersModel): # TODO: check that configs overwriten properly - class DiffusersConfig(ModelConfigBase): + class DiffusersConfig(MainDiffusersConfig): model_format: Literal[StableDiffusionXLModelFormat.Diffusers] - vae: Optional[str] = Field(None) - variant: ModelVariantType class CheckpointConfig(ModelConfigBase): model_format: Literal[StableDiffusionXLModelFormat.Checkpoint] - vae: Optional[str] = Field(None) - config: str - variant: ModelVariantType def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner} diff --git a/invokeai/backend/model_manager/models/stable_diffusion.py b/invokeai/backend/model_manager/models/stable_diffusion.py index 607a8f8fbe..57f0124cfa 100644 --- a/invokeai/backend/model_manager/models/stable_diffusion.py +++ b/invokeai/backend/model_manager/models/stable_diffusion.py @@ -12,6 +12,7 @@ from ..config import SilenceWarnings import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig +from ..config import MainCheckpointConfig, MainDiffusersConfig from .base import ( BaseModelType, DiffusersModel, @@ -32,16 +33,11 @@ class StableDiffusion1ModelFormat(str, Enum): class StableDiffusion1Model(DiffusersModel): - class DiffusersConfig(ModelConfigBase): + class DiffusersConfig(MainDiffusersConfig): model_format: Literal[StableDiffusion1ModelFormat.Diffusers] - vae: Optional[str] = Field(None) - variant: ModelVariantType - class CheckpointConfig(ModelConfigBase): + class CheckpointConfig(MainCheckpointConfig): model_format: Literal[StableDiffusion1ModelFormat.Checkpoint] - vae: Optional[str] = Field(None) - config: str - variant: ModelVariantType def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model == BaseModelType.StableDiffusion1 diff --git a/invokeai/backend/model_manager/models/stable_diffusion_onnx.py b/invokeai/backend/model_manager/models/stable_diffusion_onnx.py index 2d0dd22c43..0d1150788a 100644 --- a/invokeai/backend/model_manager/models/stable_diffusion_onnx.py +++ b/invokeai/backend/model_manager/models/stable_diffusion_onnx.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Literal from diffusers import OnnxRuntimeModel - +from ..config import ONNXSD1Config, ONNXSD2Config from .base import ( BaseModelType, DiffusersModel, @@ -21,9 +21,8 @@ class StableDiffusionOnnxModelFormat(str, Enum): class ONNXStableDiffusion1Model(DiffusersModel): - class Config(ModelConfigBase): + class Config(ONNXSD1Config): model_format: Literal[StableDiffusionOnnxModelFormat.Onnx] - variant: ModelVariantType def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model == BaseModelType.StableDiffusion1 @@ -80,11 +79,8 @@ class ONNXStableDiffusion1Model(DiffusersModel): class ONNXStableDiffusion2Model(DiffusersModel): # TODO: check that configs overwriten properly - class Config(ModelConfigBase): + class Config(ONNXSD2Config): model_format: Literal[StableDiffusionOnnxModelFormat.Onnx] - variant: ModelVariantType - prediction_type: SchedulerPredictionType - upcast_attention: bool def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model == BaseModelType.StableDiffusion2 diff --git a/invokeai/backend/model_manager/models/textual_inversion.py b/invokeai/backend/model_manager/models/textual_inversion.py index b59e635045..b70ae0b9da 100644 --- a/invokeai/backend/model_manager/models/textual_inversion.py +++ b/invokeai/backend/model_manager/models/textual_inversion.py @@ -1,10 +1,11 @@ import os -from typing import Optional +from typing import Optional, Literal import torch # TODO: naming from ..lora import TextualInversionModel as TextualInversionModelRaw +from ..config import ModelFormat, TextualInversionConfig from .base import ( BaseModelType, InvalidModelException, @@ -20,8 +21,15 @@ from .base import ( class TextualInversionModel(ModelBase): # model_size: int - class Config(ModelConfigBase): - model_format: None + class FolderConfig(TextualInversionConfig): + """Config for embeddings that are represented as a folder containing learned_embeds.bin.""" + + model_format: Literal[ModelFormat.EmbeddingFolder] + + class FileConfig(TextualInversionConfig): + """Config for embeddings that are contained in safetensors/checkpoint files.""" + + model_format: Literal[ModelFormat.EmbeddingFile] def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.TextualInversion diff --git a/invokeai/backend/model_manager/models/vae.py b/invokeai/backend/model_manager/models/vae.py index 637160c69b..ab7d5ba94d 100644 --- a/invokeai/backend/model_manager/models/vae.py +++ b/invokeai/backend/model_manager/models/vae.py @@ -1,14 +1,14 @@ import os from enum import Enum from pathlib import Path -from typing import Optional +from typing import Optional, Literal import safetensors import torch from omegaconf import OmegaConf from invokeai.app.services.config import InvokeAIAppConfig - +from ..config import VaeDiffusersConfig, VaeCheckpointConfig from .base import ( BaseModelType, EmptyConfigLoader, @@ -34,8 +34,11 @@ class VaeModel(ModelBase): # vae_class: Type # model_size: int - class Config(ModelConfigBase): - model_format: VaeModelFormat + class DiffusersConfig(VaeDiffusersConfig): + model_format: Literal[VaeModelFormat.Diffusers] = VaeModelFormat.Diffusers + + class CheckpointConfig(VaeCheckpointConfig): + model_format: Literal[VaeModelFormat.Checkpoint] = VaeModelFormat.Checkpoint def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.Vae