list_models() API call now working

Lincoln Stein
2023-09-15 21:58:28 -04:00
parent 3529925234
commit b7789bb7bb
13 changed files with 66 additions and 53 deletions

View File

@@ -29,12 +29,9 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]

-# class ModelsList(BaseModel):
-#     models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
 class ModelsList(BaseModel):
-    models: List[ModelConfigBase]
+    models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]

 @models_router.get(
@@ -51,9 +48,9 @@ async def list_models(
     if base_models and len(base_models) > 0:
         models_raw = list()
         for base_model in base_models:
-            models_raw.extend(manager.list_models(base_model=base_model, model_type=model_type))
+            models_raw.extend([x.dict() for x in manager.list_models(base_model=base_model, model_type=model_type)])
     else:
-        models_raw = manager.list_models(model_type=model_type)
+        models_raw = [x.dict() for x in manager.list_models(model_type=model_type)]
     models = parse_obj_as(ModelsList, {"models": models_raw})
     return models
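
Why the .dict() conversion matters here: parse_obj_as validates each entry of ModelsList.models against the tagged union, and pydantic re-parses plain dicts into whichever member of OPENAPI_MODEL_CONFIGS matches its Literal discriminator. A minimal self-contained sketch of that mechanism (simplified stand-in classes, not the repo's actual ones):

    from typing import List, Literal, Union
    from pydantic import BaseModel, parse_obj_as

    class DiffusersConfig(BaseModel):
        name: str
        model_format: Literal["diffusers"] = "diffusers"

    class CheckpointConfig(BaseModel):
        name: str
        model_format: Literal["checkpoint"] = "checkpoint"

    class ModelsList(BaseModel):
        # stand-in for list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
        models: List[Union[DiffusersConfig, CheckpointConfig]]

    # Plain dicts re-validate into the matching subclass via the Literal field;
    # this is why list_models() now converts each result with .dict() first.
    raw = [{"name": "sd-1.5", "model_format": "checkpoint"}]
    parsed = parse_obj_as(ModelsList, {"models": raw})
    print(type(parsed.models[0]).__name__)  # -> CheckpointConfig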

View File

@@ -10,7 +10,7 @@ Typical usage:
     base_model='sd-1',
     model_type='main',
     config='configs/stable-diffusion/v1-inference.yaml',
-    model_variant='normal',
+    variant='normal',
     model_format='checkpoint'
 )
 config = ModelConfigFactory.make_config(raw)
@@ -173,12 +173,18 @@ class VaeDiffusersConfig(ModelConfigBase):
     model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

-class ControlNetDiffusersConfig(ModelConfigBase):
+class ControlNetDiffusersConfig(DiffusersConfig):
     """Model config for ControlNet models (diffusers version)."""

     model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

+class ControlNetCheckpointConfig(CheckpointConfig):
+    """Model config for ControlNet models (checkpoint version)."""
+
+    model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
+
 class TextualInversionConfig(ModelConfigBase):
     """Model config for textual inversion embeddings."""
@@ -189,12 +195,14 @@ class MainConfig(ModelConfigBase):
     """Model config for main models."""

     vae: Optional[str] = Field(None)
-    model_variant: ModelVariantType = ModelVariantType.Normal
+    variant: ModelVariantType = ModelVariantType.Normal

 class MainCheckpointConfig(CheckpointConfig, MainConfig):
     """Model config for main checkpoint models."""

     config: str

 class MainDiffusersConfig(DiffusersConfig, MainConfig):
     """Model config for main diffusers models."""

View File

@@ -7,6 +7,8 @@ from .base import ( # noqa F401
     UnknownJobIDException,
     DownloadJobBase,
     ModelSourceMetadata,
+    REPO_ID_RE,
+    HTTP_RE,
 )
 from .queue import DownloadQueue # noqa F401

View File

@@ -11,6 +11,10 @@ from typing import List, Optional, Callable, Union
 from pydantic import BaseModel, Field
 from pydantic.networks import AnyHttpUrl

+# Used to distinguish between repo_id sources and URL sources
+REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
+HTTP_RE = r"^https?://"

 class DownloadJobStatus(str, Enum):
     """State of a download job."""
@@ -49,7 +53,9 @@ class DownloadJobBase(BaseModel):
     id: int = Field(description="Numeric ID of this job", default=-1)  # default id is a placeholder
     source: str = Field(description="URL or repo_id to download")
     destination: Path = Field(description="Destination of URL on local disk")
-    metadata: Optional[ModelSourceMetadata] = Field(description="Model metadata (source-specific)", default=None)
+    metadata: ModelSourceMetadata = Field(
+        description="Model metadata (source-specific)", default_factory=ModelSourceMetadata
+    )
     access_token: Optional[str] = Field(description="access token needed to access this resource")
     status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
     bytes: int = Field(default=0, description="Bytes downloaded so far")
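
The two module-level patterns centralize source classification that was previously duplicated as inline regexes (see the queue and install diffs below). Note that REPO_ID_RE also accepts a dot in the repo name, which the old inline ^[\w-]+/[\w-]+$ rejected. A quick sketch of the dispatch they enable (simplified; the real code picks a job class rather than returning a string):

    import re

    REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
    HTTP_RE = r"^https?://"

    def classify(source: str) -> str:
        # Mirrors the elif chain in DownloadQueue / ModelInstall (simplified).
        if re.match(REPO_ID_RE, source):
            return "repo_id"
        if re.match(HTTP_RE, source):
            return "url"
        return "unknown"

    print(classify("stabilityai/sd-vae-ft-mse"))              # repo_id
    print(classify("org/model-v2.1"))                         # repo_id (dot now allowed)
    print(classify("https://example.com/model.safetensors"))  # url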

View File

@@ -22,6 +22,7 @@ from huggingface_hub import HfApi, hf_hub_url
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.app.services.config import InvokeAIAppConfig

+from . import REPO_ID_RE, HTTP_RE
 from .base import (
     DownloadQueueBase,
     DownloadJobStatus,
@@ -59,7 +60,7 @@ class DownloadJobRepoID(DownloadJobBase):
     @validator("source")
     @classmethod
     def _validate_source(cls, v: str) -> str:
-        if not re.match(r"^[\w-]+/[\w-]+$", v):
+        if not re.match(REPO_ID_RE, v):
             raise ValidationError(f"{v} invalid repo_id")
         return v
@@ -123,10 +124,10 @@ class DownloadQueue(DownloadQueueBase):
         if Path(source).exists():
             cls = DownloadJobPath
-        elif re.match(r"^[\w-]+/[\w-]+$", str(source)):
+        elif re.match(REPO_ID_RE, str(source)):
             cls = DownloadJobRepoID
             kwargs = dict(variant=variant)
-        elif re.match(r"^https?://", str(source)):
+        elif re.match(HTTP_RE, str(source)):
             cls = DownloadJobURL
         else:
             raise NotImplementedError(f"Don't know what to do with this type of source: {source}")

View File

@@ -66,6 +66,8 @@ from .download import (
     DownloadJobBase,
     ModelSourceMetadata,
     DownloadEventHandler,
+    REPO_ID_RE,
+    HTTP_RE,
 )
 from .download.queue import DownloadJobURL, DownloadJobRepoID, DownloadJobPath
 from .hash import FastModelHash
@@ -481,10 +483,10 @@ class ModelInstall(ModelInstallBase):
         models_dir = self._config.models_path
         self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)

-        if re.match(r"^[\w-]+/[\w-]+$", str(source)):
+        if re.match(REPO_ID_RE, str(source)):
             cls = ModelInstallRepoIDJob
             kwargs = dict(variant=variant)
-        elif re.match(r"^https?://", str(source)):
+        elif re.match(HTTP_RE, str(source)):
             cls = ModelInstallURLJob
             kwargs = {}
         else:

View File

@@ -7,7 +7,7 @@ import torch
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig

+from ..config import ControlNetDiffusersConfig, ControlNetCheckpointConfig
 from .base import (
     BaseModelType,
     EmptyConfigLoader,
@@ -32,12 +32,11 @@ class ControlNetModel(ModelBase):
     # model_class: Type
     # model_size: int

-    class DiffusersConfig(ModelConfigBase):
-        model_format: Literal[ControlNetModelFormat.Diffusers]
+    class DiffusersConfig(ControlNetDiffusersConfig):
+        model_format: Literal[ControlNetModelFormat.Diffusers] = ControlNetModelFormat.Diffusers

-    class CheckpointConfig(ModelConfigBase):
-        model_format: Literal[ControlNetModelFormat.Checkpoint]
-        config: str
+    class CheckpointConfig(ControlNetCheckpointConfig):
+        model_format: Literal[ControlNetModelFormat.Checkpoint] = ControlNetModelFormat.Checkpoint

     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         assert model_type == ModelType.ControlNet
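
Giving the Literal discriminator a default (model_format: Literal[X] = X) means callers no longer have to pass the format explicitly, while pydantic still rejects any other value. A small sketch of that behavior (assumed rationale; simplified, self-contained enum):

    from enum import Enum
    from typing import Literal
    from pydantic import BaseModel, ValidationError

    class ControlNetModelFormat(str, Enum):
        Diffusers = "diffusers"
        Checkpoint = "checkpoint"

    class DiffusersConfig(BaseModel):
        model_format: Literal[ControlNetModelFormat.Diffusers] = ControlNetModelFormat.Diffusers

    print(DiffusersConfig().model_format)  # ControlNetModelFormat.Diffusers
    try:
        DiffusersConfig(model_format="checkpoint")
    except ValidationError:
        print("wrong format rejected")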

View File

@@ -2,11 +2,11 @@ import bisect
 import os
 from enum import Enum
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Union, Literal

 import torch
 from safetensors.torch import load_file

+from ..config import LoRAConfig
 from .base import (
     BaseModelType,
     InvalidModelException,
@@ -27,8 +27,8 @@ class LoRAModelFormat(str, Enum):
 class LoRAModel(ModelBase):
     # model_size: int

-    class Config(ModelConfigBase):
-        model_format: LoRAModelFormat  # TODO:
+    class Config(LoRAConfig):
+        model_format: Literal[LoRAModelFormat.LyCORIS]  # TODO:

     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         assert model_type == ModelType.Lora

View File

@@ -5,7 +5,7 @@ from typing import Literal, Optional
 from omegaconf import OmegaConf
 from pydantic import Field

+from ..config import MainDiffusersConfig, MainCheckpointConfig
 from .base import (
     BaseModelType,
     DiffusersModel,
@@ -25,16 +25,11 @@ class StableDiffusionXLModelFormat(str, Enum):
 class StableDiffusionXLModel(DiffusersModel):
     # TODO: check that configs overwriten properly
-    class DiffusersConfig(ModelConfigBase):
+    class DiffusersConfig(MainDiffusersConfig):
         model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
-        vae: Optional[str] = Field(None)
-        variant: ModelVariantType

-    class CheckpointConfig(ModelConfigBase):
+    class CheckpointConfig(MainCheckpointConfig):
         model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
-        vae: Optional[str] = Field(None)
-        config: str
-        variant: ModelVariantType

     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}

View File

@@ -12,6 +12,7 @@ from ..config import SilenceWarnings
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig

+from ..config import MainCheckpointConfig, MainDiffusersConfig
 from .base import (
     BaseModelType,
     DiffusersModel,
@@ -32,16 +33,11 @@ class StableDiffusion1ModelFormat(str, Enum):
 class StableDiffusion1Model(DiffusersModel):
-    class DiffusersConfig(ModelConfigBase):
+    class DiffusersConfig(MainDiffusersConfig):
         model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
-        vae: Optional[str] = Field(None)
-        variant: ModelVariantType

-    class CheckpointConfig(ModelConfigBase):
+    class CheckpointConfig(MainCheckpointConfig):
         model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
-        vae: Optional[str] = Field(None)
-        config: str
-        variant: ModelVariantType

     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         assert base_model == BaseModelType.StableDiffusion1

View File

@@ -2,7 +2,7 @@ from enum import Enum
 from typing import Literal

 from diffusers import OnnxRuntimeModel

+from ..config import ONNXSD1Config, ONNXSD2Config
 from .base import (
     BaseModelType,
     DiffusersModel,
@@ -21,9 +21,8 @@ class StableDiffusionOnnxModelFormat(str, Enum):
 class ONNXStableDiffusion1Model(DiffusersModel):
-    class Config(ModelConfigBase):
+    class Config(ONNXSD1Config):
         model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
-        variant: ModelVariantType

     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         assert base_model == BaseModelType.StableDiffusion1
@@ -80,11 +79,8 @@ class ONNXStableDiffusion1Model(DiffusersModel):
 class ONNXStableDiffusion2Model(DiffusersModel):
     # TODO: check that configs overwriten properly
-    class Config(ModelConfigBase):
+    class Config(ONNXSD2Config):
         model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
-        variant: ModelVariantType
-        prediction_type: SchedulerPredictionType
-        upcast_attention: bool

     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         assert base_model == BaseModelType.StableDiffusion2

View File

@@ -1,10 +1,11 @@
 import os
-from typing import Optional
+from typing import Optional, Literal

 import torch

 # TODO: naming
 from ..lora import TextualInversionModel as TextualInversionModelRaw
+from ..config import ModelFormat, TextualInversionConfig
 from .base import (
     BaseModelType,
     InvalidModelException,
@@ -20,8 +21,15 @@ from .base import (
 class TextualInversionModel(ModelBase):
     # model_size: int

-    class Config(ModelConfigBase):
-        model_format: None
+    class FolderConfig(TextualInversionConfig):
+        """Config for embeddings that are represented as a folder containing learned_embeds.bin."""
+
+        model_format: Literal[ModelFormat.EmbeddingFolder]
+
+    class FileConfig(TextualInversionConfig):
+        """Config for embeddings that are contained in safetensors/checkpoint files."""
+
+        model_format: Literal[ModelFormat.EmbeddingFile]

     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         assert model_type == ModelType.TextualInversion
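
Splitting the single Config into FolderConfig and FileConfig mirrors the two on-disk layouts for embeddings. A hypothetical sketch of how a probe could pick between them (illustrative only; the actual detection code is not part of this diff):

    from pathlib import Path

    def embedding_format(path: Path) -> str:
        """Guess which TextualInversionModel config applies (illustrative only)."""
        if path.is_dir() and (path / "learned_embeds.bin").exists():
            return "embedding_folder"  # -> FolderConfig
        return "embedding_file"        # -> FileConfig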

View File

@@ -1,14 +1,14 @@
 import os
 from enum import Enum
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Literal

 import safetensors
 import torch
 from omegaconf import OmegaConf

 from invokeai.app.services.config import InvokeAIAppConfig
+from ..config import VaeDiffusersConfig, VaeCheckpointConfig
 from .base import (
     BaseModelType,
     EmptyConfigLoader,
@@ -34,8 +34,11 @@ class VaeModel(ModelBase):
     # vae_class: Type
     # model_size: int

-    class Config(ModelConfigBase):
-        model_format: VaeModelFormat
+    class DiffusersConfig(VaeDiffusersConfig):
+        model_format: Literal[VaeModelFormat.Diffusers] = VaeModelFormat.Diffusers
+
+    class CheckpointConfig(VaeCheckpointConfig):
+        model_format: Literal[VaeModelFormat.Checkpoint] = VaeModelFormat.Checkpoint

     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         assert model_type == ModelType.Vae