mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
list_models() API call now working
This commit is contained in:
@ -29,12 +29,9 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
# class ModelsList(BaseModel):
|
||||
# models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: List[ModelConfigBase]
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
|
||||
|
||||
@models_router.get(
|
||||
@ -51,9 +48,9 @@ async def list_models(
|
||||
if base_models and len(base_models) > 0:
|
||||
models_raw = list()
|
||||
for base_model in base_models:
|
||||
models_raw.extend(manager.list_models(base_model=base_model, model_type=model_type))
|
||||
models_raw.extend([x.dict() for x in manager.list_models(base_model=base_model, model_type=model_type)])
|
||||
else:
|
||||
models_raw = manager.list_models(model_type=model_type)
|
||||
models_raw = [x.dict() for x in manager.list_models(model_type=model_type)]
|
||||
models = parse_obj_as(ModelsList, {"models": models_raw})
|
||||
return models
|
||||
|
||||
|
@ -10,7 +10,7 @@ Typical usage:
|
||||
base_model='sd-1',
|
||||
model_type='main',
|
||||
config='configs/stable-diffusion/v1-inference.yaml',
|
||||
model_variant='normal',
|
||||
variant='normal',
|
||||
model_format='checkpoint'
|
||||
)
|
||||
config = ModelConfigFactory.make_config(raw)
|
||||
@ -173,12 +173,18 @@ class VaeDiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class ControlNetDiffusersConfig(ModelConfigBase):
|
||||
class ControlNetDiffusersConfig(DiffusersConfig):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class ControlNetCheckpointConfig(CheckpointConfig):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
|
||||
class TextualInversionConfig(ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
@ -189,12 +195,14 @@ class MainConfig(ModelConfigBase):
|
||||
"""Model config for main models."""
|
||||
|
||||
vae: Optional[str] = Field(None)
|
||||
model_variant: ModelVariantType = ModelVariantType.Normal
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
|
||||
|
||||
class MainCheckpointConfig(CheckpointConfig, MainConfig):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
config: str
|
||||
|
||||
|
||||
class MainDiffusersConfig(DiffusersConfig, MainConfig):
|
||||
"""Model config for main diffusers models."""
|
||||
|
@ -7,6 +7,8 @@ from .base import ( # noqa F401
|
||||
UnknownJobIDException,
|
||||
DownloadJobBase,
|
||||
ModelSourceMetadata,
|
||||
REPO_ID_RE,
|
||||
HTTP_RE,
|
||||
)
|
||||
|
||||
from .queue import DownloadQueue # noqa F401
|
||||
|
@ -11,6 +11,10 @@ from typing import List, Optional, Callable, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
# Used to distinguish between repo_id sources and URL sources
|
||||
REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
|
||||
HTTP_RE = r"^https?://"
|
||||
|
||||
|
||||
class DownloadJobStatus(str, Enum):
|
||||
"""State of a download job."""
|
||||
@ -49,7 +53,9 @@ class DownloadJobBase(BaseModel):
|
||||
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a placeholder
|
||||
source: str = Field(description="URL or repo_id to download")
|
||||
destination: Path = Field(description="Destination of URL on local disk")
|
||||
metadata: Optional[ModelSourceMetadata] = Field(description="Model metadata (source-specific)", default=None)
|
||||
metadata: ModelSourceMetadata = Field(
|
||||
description="Model metadata (source-specific)", default_factory=ModelSourceMetadata
|
||||
)
|
||||
access_token: Optional[str] = Field(description="access token needed to access this resource")
|
||||
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
|
||||
bytes: int = Field(default=0, description="Bytes downloaded so far")
|
||||
|
@ -22,6 +22,7 @@ from huggingface_hub import HfApi, hf_hub_url
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from . import REPO_ID_RE, HTTP_RE
|
||||
from .base import (
|
||||
DownloadQueueBase,
|
||||
DownloadJobStatus,
|
||||
@ -59,7 +60,7 @@ class DownloadJobRepoID(DownloadJobBase):
|
||||
@validator("source")
|
||||
@classmethod
|
||||
def _validate_source(cls, v: str) -> str:
|
||||
if not re.match(r"^[\w-]+/[\w-]+$", v):
|
||||
if not re.match(REPO_ID_RE, v):
|
||||
raise ValidationError(f"{v} invalid repo_id")
|
||||
return v
|
||||
|
||||
@ -123,10 +124,10 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
if Path(source).exists():
|
||||
cls = DownloadJobPath
|
||||
elif re.match(r"^[\w-]+/[\w-]+$", str(source)):
|
||||
elif re.match(REPO_ID_RE, str(source)):
|
||||
cls = DownloadJobRepoID
|
||||
kwargs = dict(variant=variant)
|
||||
elif re.match(r"^https?://", str(source)):
|
||||
elif re.match(HTTP_RE, str(source)):
|
||||
cls = DownloadJobURL
|
||||
else:
|
||||
raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
|
||||
|
@ -66,6 +66,8 @@ from .download import (
|
||||
DownloadJobBase,
|
||||
ModelSourceMetadata,
|
||||
DownloadEventHandler,
|
||||
REPO_ID_RE,
|
||||
HTTP_RE,
|
||||
)
|
||||
from .download.queue import DownloadJobURL, DownloadJobRepoID, DownloadJobPath
|
||||
from .hash import FastModelHash
|
||||
@ -481,10 +483,10 @@ class ModelInstall(ModelInstallBase):
|
||||
models_dir = self._config.models_path
|
||||
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
|
||||
|
||||
if re.match(r"^[\w-]+/[\w-]+$", str(source)):
|
||||
if re.match(REPO_ID_RE, str(source)):
|
||||
cls = ModelInstallRepoIDJob
|
||||
kwargs = dict(variant=variant)
|
||||
elif re.match(r"^https?://", str(source)):
|
||||
elif re.match(HTTP_RE, str(source)):
|
||||
cls = ModelInstallURLJob
|
||||
kwargs = {}
|
||||
else:
|
||||
|
@ -7,7 +7,7 @@ import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from ..config import ControlNetDiffusersConfig, ControlNetCheckpointConfig
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
EmptyConfigLoader,
|
||||
@ -32,12 +32,11 @@ class ControlNetModel(ModelBase):
|
||||
# model_class: Type
|
||||
# model_size: int
|
||||
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[ControlNetModelFormat.Diffusers]
|
||||
class DiffusersConfig(ControlNetDiffusersConfig):
|
||||
model_format: Literal[ControlNetModelFormat.Diffusers] = ControlNetModelFormat.Diffusers
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
model_format: Literal[ControlNetModelFormat.Checkpoint]
|
||||
config: str
|
||||
class CheckpointConfig(ControlNetCheckpointConfig):
|
||||
model_format: Literal[ControlNetModelFormat.Checkpoint] = ControlNetModelFormat.Checkpoint
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.ControlNet
|
||||
|
@ -2,11 +2,11 @@ import bisect
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Optional, Union, Literal
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from ..config import LoRAConfig
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
@ -27,8 +27,8 @@ class LoRAModelFormat(str, Enum):
|
||||
class LoRAModel(ModelBase):
|
||||
# model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
model_format: LoRAModelFormat # TODO:
|
||||
class Config(LoRAConfig):
|
||||
model_format: Literal[LoRAModelFormat.LyCORIS] # TODO:
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Lora
|
||||
|
@ -5,7 +5,7 @@ from typing import Literal, Optional
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from pydantic import Field
|
||||
|
||||
from ..config import MainDiffusersConfig, MainCheckpointConfig
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
DiffusersModel,
|
||||
@ -25,16 +25,11 @@ class StableDiffusionXLModelFormat(str, Enum):
|
||||
|
||||
class StableDiffusionXLModel(DiffusersModel):
|
||||
# TODO: check that configs overwriten properly
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
class DiffusersConfig(MainDiffusersConfig):
|
||||
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: str
|
||||
variant: ModelVariantType
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}
|
||||
|
@ -12,6 +12,7 @@ from ..config import SilenceWarnings
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from ..config import MainCheckpointConfig, MainDiffusersConfig
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
DiffusersModel,
|
||||
@ -32,16 +33,11 @@ class StableDiffusion1ModelFormat(str, Enum):
|
||||
|
||||
|
||||
class StableDiffusion1Model(DiffusersModel):
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
class DiffusersConfig(MainDiffusersConfig):
|
||||
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
class CheckpointConfig(MainCheckpointConfig):
|
||||
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: str
|
||||
variant: ModelVariantType
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion1
|
||||
|
@ -2,7 +2,7 @@ from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from diffusers import OnnxRuntimeModel
|
||||
|
||||
from ..config import ONNXSD1Config, ONNXSD2Config
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
DiffusersModel,
|
||||
@ -21,9 +21,8 @@ class StableDiffusionOnnxModelFormat(str, Enum):
|
||||
|
||||
|
||||
class ONNXStableDiffusion1Model(DiffusersModel):
|
||||
class Config(ModelConfigBase):
|
||||
class Config(ONNXSD1Config):
|
||||
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
|
||||
variant: ModelVariantType
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion1
|
||||
@ -80,11 +79,8 @@ class ONNXStableDiffusion1Model(DiffusersModel):
|
||||
|
||||
class ONNXStableDiffusion2Model(DiffusersModel):
|
||||
# TODO: check that configs overwriten properly
|
||||
class Config(ModelConfigBase):
|
||||
class Config(ONNXSD2Config):
|
||||
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
|
||||
variant: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion2
|
||||
|
@ -1,10 +1,11 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Optional, Literal
|
||||
|
||||
import torch
|
||||
|
||||
# TODO: naming
|
||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||
from ..config import ModelFormat, TextualInversionConfig
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
@ -20,8 +21,15 @@ from .base import (
|
||||
class TextualInversionModel(ModelBase):
|
||||
# model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
model_format: None
|
||||
class FolderConfig(TextualInversionConfig):
|
||||
"""Config for embeddings that are represented as a folder containing learned_embeds.bin."""
|
||||
|
||||
model_format: Literal[ModelFormat.EmbeddingFolder]
|
||||
|
||||
class FileConfig(TextualInversionConfig):
|
||||
"""Config for embeddings that are contained in safetensors/checkpoint files."""
|
||||
|
||||
model_format: Literal[ModelFormat.EmbeddingFile]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.TextualInversion
|
||||
|
@ -1,14 +1,14 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Literal
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from ..config import VaeDiffusersConfig, VaeCheckpointConfig
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
EmptyConfigLoader,
|
||||
@ -34,8 +34,11 @@ class VaeModel(ModelBase):
|
||||
# vae_class: Type
|
||||
# model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
model_format: VaeModelFormat
|
||||
class DiffusersConfig(VaeDiffusersConfig):
|
||||
model_format: Literal[VaeModelFormat.Diffusers] = VaeModelFormat.Diffusers
|
||||
|
||||
class CheckpointConfig(VaeCheckpointConfig):
|
||||
model_format: Literal[VaeModelFormat.Checkpoint] = VaeModelFormat.Checkpoint
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Vae
|
||||
|
Reference in New Issue
Block a user