mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix a number of typechecking errors
This commit is contained in:
parent
433eb73d8e
commit
bd802d1e7a
@ -36,7 +36,7 @@ async def list_downloads() -> List[DownloadJob]:
|
|||||||
400: {"description": "Bad request"},
|
400: {"description": "Bad request"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def prune_downloads():
|
async def prune_downloads() -> Response:
|
||||||
"""Prune completed and errored jobs."""
|
"""Prune completed and errored jobs."""
|
||||||
queue = ApiDependencies.invoker.services.download_queue
|
queue = ApiDependencies.invoker.services.download_queue
|
||||||
queue.prune_jobs()
|
queue.prune_jobs()
|
||||||
@ -87,7 +87,7 @@ async def get_download_job(
|
|||||||
)
|
)
|
||||||
async def cancel_download_job(
|
async def cancel_download_job(
|
||||||
id: int = Path(description="ID of the download job to cancel."),
|
id: int = Path(description="ID of the download job to cancel."),
|
||||||
):
|
) -> Response:
|
||||||
"""Cancel a download job using its ID."""
|
"""Cancel a download job using its ID."""
|
||||||
try:
|
try:
|
||||||
queue = ApiDependencies.invoker.services.download_queue
|
queue = ApiDependencies.invoker.services.download_queue
|
||||||
@ -105,7 +105,7 @@ async def cancel_download_job(
|
|||||||
204: {"description": "Download jobs have been cancelled"},
|
204: {"description": "Download jobs have been cancelled"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def cancel_all_download_jobs():
|
async def cancel_all_download_jobs() -> Response:
|
||||||
"""Cancel all download jobs."""
|
"""Cancel all download jobs."""
|
||||||
ApiDependencies.invoker.services.download_queue.cancel_all_jobs()
|
ApiDependencies.invoker.services.download_queue.cancel_all_jobs()
|
||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
|
@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Set
|
|||||||
|
|
||||||
from fastapi import Body, Path, Query, Response
|
from fastapi import Body, Path, Query, Response
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
@ -37,6 +37,35 @@ from ..dependencies import ApiDependencies
|
|||||||
|
|
||||||
model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
|
model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
|
||||||
|
|
||||||
|
example_model_output = {
|
||||||
|
"path": "sd-1/main/openjourney",
|
||||||
|
"name": "openjourney",
|
||||||
|
"base": "sd-1",
|
||||||
|
"type": "main",
|
||||||
|
"format": "diffusers",
|
||||||
|
"key": "3a0e45ff858926fd4a63da630688b1e1",
|
||||||
|
"original_hash": "1c12f18fb6e403baef26fb9d720fbd2f",
|
||||||
|
"current_hash": "1c12f18fb6e403baef26fb9d720fbd2f",
|
||||||
|
"description": "sd-1 main model openjourney",
|
||||||
|
"source": "/opt/invokeai/models/sd-1/main/openjourney",
|
||||||
|
"last_modified": 1707794711,
|
||||||
|
"vae": "/opt/invokeai/models/sd-1/vae/vae-ft-mse-840000-ema-pruned_fp16.safetensors",
|
||||||
|
"variant": "normal",
|
||||||
|
"prediction_type": "epsilon",
|
||||||
|
"repo_variant": "fp16",
|
||||||
|
}
|
||||||
|
|
||||||
|
example_model_input = {
|
||||||
|
"path": "base/type/name",
|
||||||
|
"name": "model_name",
|
||||||
|
"base": "sd-1",
|
||||||
|
"type": "main",
|
||||||
|
"format": "diffusers",
|
||||||
|
"description": "Model description",
|
||||||
|
"vae": None,
|
||||||
|
"variant": "normal",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
"""Return list of configs."""
|
"""Return list of configs."""
|
||||||
@ -88,7 +117,10 @@ async def list_model_records(
|
|||||||
"/i/{key}",
|
"/i/{key}",
|
||||||
operation_id="get_model_record",
|
operation_id="get_model_record",
|
||||||
responses={
|
responses={
|
||||||
200: {"description": "Success"},
|
200: {
|
||||||
|
"description": "The model configuration was retrieved successfully",
|
||||||
|
"content": {"application/json": {"example": example_model_output}},
|
||||||
|
},
|
||||||
400: {"description": "Bad request"},
|
400: {"description": "Bad request"},
|
||||||
404: {"description": "The model could not be found"},
|
404: {"description": "The model could not be found"},
|
||||||
},
|
},
|
||||||
@ -165,18 +197,22 @@ async def search_by_metadata_tags(
|
|||||||
"/i/{key}",
|
"/i/{key}",
|
||||||
operation_id="update_model_record",
|
operation_id="update_model_record",
|
||||||
responses={
|
responses={
|
||||||
200: {"description": "The model was updated successfully"},
|
200: {
|
||||||
|
"description": "The model was updated successfully",
|
||||||
|
"content": {"application/json": {"example": example_model_output}},
|
||||||
|
},
|
||||||
400: {"description": "Bad request"},
|
400: {"description": "Bad request"},
|
||||||
404: {"description": "The model could not be found"},
|
404: {"description": "The model could not be found"},
|
||||||
409: {"description": "There is already a model corresponding to the new name"},
|
409: {"description": "There is already a model corresponding to the new name"},
|
||||||
},
|
},
|
||||||
status_code=200,
|
status_code=200,
|
||||||
response_model=AnyModelConfig,
|
|
||||||
)
|
)
|
||||||
async def update_model_record(
|
async def update_model_record(
|
||||||
key: Annotated[str, Path(description="Unique key of model")],
|
key: Annotated[str, Path(description="Unique key of model")],
|
||||||
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
|
info: Annotated[
|
||||||
) -> AnyModelConfig:
|
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||||
|
],
|
||||||
|
) -> Annotated[AnyModelConfig, Field(example="this is neat")]:
|
||||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||||
@ -225,7 +261,10 @@ async def del_model_record(
|
|||||||
"/i/",
|
"/i/",
|
||||||
operation_id="add_model_record",
|
operation_id="add_model_record",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "The model added successfully"},
|
201: {
|
||||||
|
"description": "The model added successfully",
|
||||||
|
"content": {"application/json": {"example": example_model_output}},
|
||||||
|
},
|
||||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||||
415: {"description": "Unrecognized file/folder format"},
|
415: {"description": "Unrecognized file/folder format"},
|
||||||
},
|
},
|
||||||
@ -270,6 +309,7 @@ async def heuristic_import(
|
|||||||
config: Optional[Dict[str, Any]] = Body(
|
config: Optional[Dict[str, Any]] = Body(
|
||||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||||
default=None,
|
default=None,
|
||||||
|
example={"name": "modelT", "description": "antique cars"},
|
||||||
),
|
),
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
@ -497,7 +537,10 @@ async def sync_models_to_config() -> Response:
|
|||||||
"/convert/{key}",
|
"/convert/{key}",
|
||||||
operation_id="convert_model",
|
operation_id="convert_model",
|
||||||
responses={
|
responses={
|
||||||
200: {"description": "Model converted successfully"},
|
200: {
|
||||||
|
"description": "Model converted successfully",
|
||||||
|
"content": {"application/json": {"example": example_model_output}},
|
||||||
|
},
|
||||||
400: {"description": "Bad request"},
|
400: {"description": "Bad request"},
|
||||||
404: {"description": "Model not found"},
|
404: {"description": "Model not found"},
|
||||||
409: {"description": "There is already a model registered at this location"},
|
409: {"description": "There is already a model registered at this location"},
|
||||||
@ -571,6 +614,15 @@ async def convert_model(
|
|||||||
@model_manager_v2_router.put(
|
@model_manager_v2_router.put(
|
||||||
"/merge",
|
"/merge",
|
||||||
operation_id="merge",
|
operation_id="merge",
|
||||||
|
responses={
|
||||||
|
200: {
|
||||||
|
"description": "Model converted successfully",
|
||||||
|
"content": {"application/json": {"example": example_model_output}},
|
||||||
|
},
|
||||||
|
400: {"description": "Bad request"},
|
||||||
|
404: {"description": "Model not found"},
|
||||||
|
409: {"description": "There is already a model registered at this location"},
|
||||||
|
},
|
||||||
)
|
)
|
||||||
async def merge(
|
async def merge(
|
||||||
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||||
@ -596,7 +648,6 @@ async def merge(
|
|||||||
interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||||
merge_dest_directory: Specify a directory to store the merged model in [models directory]
|
merge_dest_directory: Specify a directory to store the merged model in [models directory]
|
||||||
"""
|
"""
|
||||||
print(f"here i am, keys={keys}")
|
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
try:
|
try:
|
||||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||||
|
@ -92,10 +92,10 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||||
ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key)
|
ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key)
|
||||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||||
image_encoder_models = context.services.model_records.search_by_attr(
|
image_encoder_models = context.services.model_manager.store.search_by_attr(
|
||||||
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
|
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
|
||||||
)
|
)
|
||||||
assert len(image_encoder_models) == 1
|
assert len(image_encoder_models) == 1
|
||||||
|
@ -106,7 +106,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
|||||||
key = self.model.key
|
key = self.model.key
|
||||||
|
|
||||||
# TODO: not found exceptions
|
# TODO: not found exceptions
|
||||||
if not context.services.model_records.exists(key):
|
if not context.services.model_manager.store.exists(key):
|
||||||
raise Exception(f"Unknown model {key}")
|
raise Exception(f"Unknown model {key}")
|
||||||
|
|
||||||
return ModelLoaderOutput(
|
return ModelLoaderOutput(
|
||||||
@ -175,7 +175,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
lora_key = self.lora.key
|
lora_key = self.lora.key
|
||||||
|
|
||||||
if not context.services.model_records.exists(lora_key):
|
if not context.services.model_manager.store.exists(lora_key):
|
||||||
raise Exception(f"Unkown lora: {lora_key}!")
|
raise Exception(f"Unkown lora: {lora_key}!")
|
||||||
|
|
||||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||||
@ -255,7 +255,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
lora_key = self.lora.key
|
lora_key = self.lora.key
|
||||||
|
|
||||||
if not context.services.model_records.exists(lora_key):
|
if not context.services.model_manager.store.exists(lora_key):
|
||||||
raise Exception(f"Unknown lora: {lora_key}!")
|
raise Exception(f"Unknown lora: {lora_key}!")
|
||||||
|
|
||||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||||
@ -321,7 +321,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||||
key = self.vae_model.key
|
key = self.vae_model.key
|
||||||
|
|
||||||
if not context.services.model_records.exists(key):
|
if not context.services.model_manager.store.exists(key):
|
||||||
raise Exception(f"Unkown vae: {key}!")
|
raise Exception(f"Unkown vae: {key}!")
|
||||||
|
|
||||||
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
|
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
|
||||||
|
@ -27,11 +27,11 @@ class InvokeAISettings(BaseSettings):
|
|||||||
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
|
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
|
||||||
|
|
||||||
initconf: ClassVar[Optional[DictConfig]] = None
|
initconf: ClassVar[Optional[DictConfig]] = None
|
||||||
argparse_groups: ClassVar[Dict] = {}
|
argparse_groups: ClassVar[Dict[str, Any]] = {}
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
|
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
|
||||||
|
|
||||||
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None:
|
||||||
"""Call to parse command-line arguments."""
|
"""Call to parse command-line arguments."""
|
||||||
parser = self.get_parser()
|
parser = self.get_parser()
|
||||||
opt, unknown_opts = parser.parse_known_args(argv)
|
opt, unknown_opts = parser.parse_known_args(argv)
|
||||||
@ -68,7 +68,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
return OmegaConf.to_yaml(conf)
|
return OmegaConf.to_yaml(conf)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_parser_arguments(cls, parser):
|
def add_parser_arguments(cls, parser) -> None:
|
||||||
"""Dynamically create arguments for a settings parser."""
|
"""Dynamically create arguments for a settings parser."""
|
||||||
if "type" in get_type_hints(cls):
|
if "type" in get_type_hints(cls):
|
||||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||||
@ -117,7 +117,8 @@ class InvokeAISettings(BaseSettings):
|
|||||||
"""Return the category of a setting."""
|
"""Return the category of a setting."""
|
||||||
hints = get_type_hints(cls)
|
hints = get_type_hints(cls)
|
||||||
if command_field in hints:
|
if command_field in hints:
|
||||||
return get_args(hints[command_field])[0]
|
result: str = get_args(hints[command_field])[0]
|
||||||
|
return result
|
||||||
else:
|
else:
|
||||||
return "Uncategorized"
|
return "Uncategorized"
|
||||||
|
|
||||||
@ -158,7 +159,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None:
|
||||||
"""Add the argparse arguments for a setting parser."""
|
"""Add the argparse arguments for a setting parser."""
|
||||||
field_type = get_type_hints(cls).get(name)
|
field_type = get_type_hints(cls).get(name)
|
||||||
default = (
|
default = (
|
||||||
|
@ -21,7 +21,7 @@ class PagingArgumentParser(argparse.ArgumentParser):
|
|||||||
It also supports reading defaults from an init file.
|
It also supports reading defaults from an init file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def print_help(self, file=None):
|
def print_help(self, file=None) -> None:
|
||||||
text = self.format_help()
|
text = self.format_help()
|
||||||
pydoc.pager(text)
|
pydoc.pager(text)
|
||||||
|
|
||||||
|
@ -8,12 +8,12 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Empty, PriorityQueue
|
from queue import Empty, PriorityQueue
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm, std
|
||||||
|
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
@ -49,12 +49,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
|
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
|
||||||
:param requests_session: Optional requests.sessions.Session object, for unit tests.
|
:param requests_session: Optional requests.sessions.Session object, for unit tests.
|
||||||
"""
|
"""
|
||||||
self._jobs = {}
|
self._jobs: Dict[int, DownloadJob] = {}
|
||||||
self._next_job_id = 0
|
self._next_job_id = 0
|
||||||
self._queue = PriorityQueue()
|
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
|
||||||
self._stop_event = threading.Event()
|
self._stop_event = threading.Event()
|
||||||
self._job_completed_event = threading.Event()
|
self._job_completed_event = threading.Event()
|
||||||
self._worker_pool = set()
|
self._worker_pool: Set[threading.Thread] = set()
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
|
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
|
||||||
self._event_bus = event_bus
|
self._event_bus = event_bus
|
||||||
@ -424,7 +424,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
class TqdmProgress(object):
|
class TqdmProgress(object):
|
||||||
"""TQDM-based progress bar object to use in on_progress handlers."""
|
"""TQDM-based progress bar object to use in on_progress handlers."""
|
||||||
|
|
||||||
_bars: Dict[int, tqdm] # the tqdm object
|
_bars: Dict[int, tqdm] # type: ignore
|
||||||
_last: Dict[int, int] # last bytes downloaded
|
_last: Dict[int, int] # last bytes downloaded
|
||||||
|
|
||||||
def __init__(self) -> None: # noqa D107
|
def __init__(self) -> None: # noqa D107
|
||||||
|
@ -5,7 +5,7 @@ import uuid
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def get_timestamp():
|
def get_timestamp() -> int:
|
||||||
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||||
|
|
||||||
|
|
||||||
@ -20,16 +20,16 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
|
|||||||
SEED_MAX = np.iinfo(np.uint32).max
|
SEED_MAX = np.iinfo(np.uint32).max
|
||||||
|
|
||||||
|
|
||||||
def get_random_seed():
|
def get_random_seed() -> int:
|
||||||
rng = np.random.default_rng(seed=None)
|
rng = np.random.default_rng(seed=None)
|
||||||
return int(rng.integers(0, SEED_MAX))
|
return int(rng.integers(0, SEED_MAX))
|
||||||
|
|
||||||
|
|
||||||
def uuid_string():
|
def uuid_string() -> str:
|
||||||
res = uuid.uuid4()
|
res = uuid.uuid4()
|
||||||
return str(res)
|
return str(res)
|
||||||
|
|
||||||
|
|
||||||
def is_optional(value: typing.Any):
|
def is_optional(value: typing.Any) -> bool:
|
||||||
"""Checks if a value is typed as Optional. Note that Optional is sugar for Union[x, None]."""
|
"""Checks if a value is typed as Optional. Note that Optional is sugar for Union[x, None]."""
|
||||||
return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value)
|
return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value)
|
||||||
|
@ -22,6 +22,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
AnyModel,
|
AnyModel,
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
|
ModelConfigBase,
|
||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
@ -70,7 +71,7 @@ class ModelLoaderBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
"""
|
"""
|
||||||
Return a model given its confguration.
|
Return a model given its confguration.
|
||||||
|
|
||||||
@ -122,7 +123,7 @@ class AnyModelLoader:
|
|||||||
"""Return the convert cache associated used by the loaders."""
|
"""Return the convert cache associated used by the loaders."""
|
||||||
return self._convert_cache
|
return self._convert_cache
|
||||||
|
|
||||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
"""
|
"""
|
||||||
Return a model given its configuration.
|
Return a model given its configuration.
|
||||||
|
|
||||||
@ -144,8 +145,8 @@ class AnyModelLoader:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_implementation(
|
def get_implementation(
|
||||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
|
||||||
) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]:
|
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||||
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
||||||
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
||||||
@ -161,8 +162,8 @@ class AnyModelLoader:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _handle_subtype_overrides(
|
def _handle_subtype_overrides(
|
||||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
|
||||||
) -> Tuple[AnyModelConfig, Optional[SubModelType]]:
|
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
|
||||||
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
||||||
model_path = Path(config.vae)
|
model_path = Path(config.vae)
|
||||||
config_class = (
|
config_class = (
|
||||||
|
@ -34,8 +34,8 @@ from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot,
|
|||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase
|
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
|
||||||
from .model_locker import ModelLocker, ModelLockerBase
|
from .model_locker import ModelLocker
|
||||||
|
|
||||||
if choose_torch_device() == torch.device("mps"):
|
if choose_torch_device() == torch.device("mps"):
|
||||||
from torch import mps
|
from torch import mps
|
||||||
|
@ -20,7 +20,7 @@ from requests.sessions import Session
|
|||||||
|
|
||||||
from invokeai.backend.model_manager import ModelRepoVariant
|
from invokeai.backend.model_manager import ModelRepoVariant
|
||||||
|
|
||||||
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
|
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, BaseMetadata
|
||||||
|
|
||||||
|
|
||||||
class ModelMetadataFetchBase(ABC):
|
class ModelMetadataFetchBase(ABC):
|
||||||
@ -62,5 +62,5 @@ class ModelMetadataFetchBase(ABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_json(cls, json: str) -> AnyModelRepoMetadata:
|
def from_json(cls, json: str) -> AnyModelRepoMetadata:
|
||||||
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
||||||
metadata = AnyModelRepoMetadataValidator.validate_json(json)
|
metadata: BaseMetadata = AnyModelRepoMetadataValidator.validate_json(json) # type: ignore
|
||||||
return metadata
|
return metadata
|
||||||
|
@ -166,7 +166,7 @@ class ModelProbe(object):
|
|||||||
fields["original_hash"] = fields.get("original_hash") or hash
|
fields["original_hash"] = fields.get("original_hash") or hash
|
||||||
fields["current_hash"] = fields.get("current_hash") or hash
|
fields["current_hash"] = fields.get("current_hash") or hash
|
||||||
|
|
||||||
if format_type == ModelFormat.Diffusers:
|
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
||||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||||
|
|
||||||
# additional fields needed for main and controlnet models
|
# additional fields needed for main and controlnet models
|
||||||
|
@ -116,9 +116,9 @@ class ModelSearch(ModelSearchBase):
|
|||||||
# returns all models that have 'anime' in the path
|
# returns all models that have 'anime' in the path
|
||||||
"""
|
"""
|
||||||
|
|
||||||
models_found: Set[Path] = Field(default=None)
|
models_found: Optional[Set[Path]] = Field(default=None)
|
||||||
scanned_dirs: Set[Path] = Field(default=None)
|
scanned_dirs: Optional[Set[Path]] = Field(default=None)
|
||||||
pruned_paths: Set[Path] = Field(default=None)
|
pruned_paths: Optional[Set[Path]] = Field(default=None)
|
||||||
|
|
||||||
def search_started(self) -> None:
|
def search_started(self) -> None:
|
||||||
self.models_found = set()
|
self.models_found = set()
|
||||||
|
Loading…
Reference in New Issue
Block a user