mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
more type mismatch fixes
@@ -3,13 +3,14 @@
 import pathlib
 from enum import Enum
-from typing import List, Literal, Optional, Union, Tuple
+from typing import Any, List, Literal, Optional, Union

 from fastapi import Body, Path, Query, Response
 from fastapi.routing import APIRouter
 from pydantic import BaseModel, parse_obj_as
 from starlette.exceptions import HTTPException

+from invokeai.app.api.dependencies import ApiDependencies
 from invokeai.backend import BaseModelType, ModelType
 from invokeai.backend.model_manager import (
     OPENAPI_MODEL_CONFIGS,
@@ -22,8 +23,6 @@ from invokeai.backend.model_manager import (
 from invokeai.backend.model_manager.download import DownloadJobStatus, UnknownJobIDException
 from invokeai.backend.model_manager.merge import MergeInterpolationMethod

-from invokeai.app.api.dependencies import ApiDependencies
-
 models_router = APIRouter(prefix="/v1/models", tags=["models"])

 # NOTE: The generic configuration classes defined in invokeai.backend.model_manager.config
@@ -33,15 +32,11 @@ models_router = APIRouter(prefix="/v1/models", tags=["models"])

 # There are still numerous mypy errors here because it does not seem to like this
 # way of dynamically generating the typing hints below.
-UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
-ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
-ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
-MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
-ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
+InvokeAIModelConfig: Any = Union[tuple(OPENAPI_MODEL_CONFIGS)]


 class ModelsList(BaseModel):
-    models: List[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
+    models: List[InvokeAIModelConfig]


 class ModelImportStatus(BaseModel):
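For context on the NOTE above: Union[tuple(OPENAPI_MODEL_CONFIGS)] assembles a typing union at runtime from a list of config classes. Pydantic can dispatch on such a union, but mypy cannot follow the dynamic construction, which is why the commit collapses the five identical aliases into a single InvokeAIModelConfig annotated as Any. A minimal sketch of the pattern, assuming pydantic v1 (the demo config classes are hypothetical, not InvokeAI's real ones):

from typing import Any, Literal, Union

from pydantic import BaseModel, parse_obj_as


class DiffusersDemoConfig(BaseModel):
    format: Literal["diffusers"] = "diffusers"
    path: str


class CheckpointDemoConfig(BaseModel):
    format: Literal["checkpoint"] = "checkpoint"
    path: str


DEMO_CONFIGS = [DiffusersDemoConfig, CheckpointDemoConfig]

# Union[...] accepts a tuple of types, so the alias can be built at runtime.
# mypy cannot type this precisely, hence the Any annotation in the commit.
DemoModelConfig: Any = Union[tuple(DEMO_CONFIGS)]

# parse_obj_as validates against each union member until one fits.
config = parse_obj_as(DemoModelConfig, {"format": "checkpoint", "path": "v1-5.ckpt"})
assert type(config) is CheckpointDemoConfig

Since the five aliases were identically defined, collapsing them also means every endpoint below advertises the same response model instead of five interchangeable ones.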
@@ -93,12 +88,12 @@ async def list_models(
         409: {"description": "There is already a model corresponding to the new name"},
     },
     status_code=200,
-    response_model=UpdateModelResponse,
+    response_model=InvokeAIModelConfig,
 )
 async def update_model(
     key: str = Path(description="Unique key of model"),
-    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
-) -> UpdateModelResponse:
+    info: InvokeAIModelConfig = Body(description="Model configuration"),
+) -> InvokeAIModelConfig:
     """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
     logger = ApiDependencies.invoker.services.logger

@@ -106,7 +101,7 @@ async def update_model(
         info_dict = info.dict()
         info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
         new_config = ApiDependencies.invoker.services.model_manager.update_model(key, new_config=info_dict)
-        model_response = parse_obj_as(UpdateModelResponse, new_config.dict())
+        model_response = parse_obj_as(InvokeAIModelConfig, new_config.dict())
     except UnknownModelException as e:
         raise HTTPException(status_code=404, detail=str(e))
     except ValueError as e:
@@ -198,11 +193,11 @@ async def import_model(
         409: {"description": "There is already a model corresponding to this path or repo_id"},
     },
     status_code=201,
-    response_model=ImportModelResponse,
+    response_model=InvokeAIModelConfig,
 )
 async def add_model(
-    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
-) -> ImportModelResponse:
+    info: InvokeAIModelConfig = Body(description="Model configuration"),
+) -> InvokeAIModelConfig:
     """
     Add a model using the configuration information appropriate for its type. Only local models can be added by path.
     This call will block until the model is installed.
@@ -220,7 +215,7 @@ async def add_model(
         info_dict = info.dict()
         info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
         new_config = ApiDependencies.invoker.services.model_manager.update_model(key, new_config=info_dict)
-        return parse_obj_as(ImportModelResponse, new_config.dict())
+        return parse_obj_as(InvokeAIModelConfig, new_config.dict())
     except UnknownModelException as e:
         logger.error(str(e))
         raise HTTPException(status_code=404, detail=str(e))
@@ -261,20 +256,20 @@ async def delete_model(
         404: {"description": "Model not found"},
     },
     status_code=200,
-    response_model=ConvertModelResponse,
+    response_model=InvokeAIModelConfig,
 )
 async def convert_model(
     key: str = Path(description="Unique key of model to remove from model registry."),
     convert_dest_directory: Optional[str] = Query(
         default=None, description="Save the converted model to the designated directory"
     ),
-) -> ConvertModelResponse:
+) -> InvokeAIModelConfig:
     """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
     try:
         dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
         ApiDependencies.invoker.services.model_manager.convert_model(key, convert_dest_directory=dest)
         model_raw = ApiDependencies.invoker.services.model_manager.model_info(key).dict()
-        response = parse_obj_as(ConvertModelResponse, model_raw)
+        response = parse_obj_as(InvokeAIModelConfig, model_raw)
     except UnknownModelException as e:
         raise HTTPException(status_code=404, detail=f"Model '{key}' not found: {str(e)}")
     except ValueError as e:
@@ -347,7 +342,7 @@ async def sync_to_config() -> bool:
         409: {"description": "An identical merged model is already installed"},
     },
     status_code=200,
-    response_model=MergeModelResponse,
+    response_model=InvokeAIModelConfig,
 )
 async def merge_models(
     keys: List[str] = Body(description="model name", min_items=2, max_items=3),
@@ -361,7 +356,7 @@ async def merge_models(
         description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
         default=None,
     ),
-) -> MergeModelResponse:
+) -> InvokeAIModelConfig:
     """Merge the indicated diffusers model."""
     logger = ApiDependencies.invoker.services.logger
     try:
@@ -375,7 +370,7 @@ async def merge_models(
             force=force,
             merge_dest_directory=dest,
         )
-        response = parse_obj_as(ConvertModelResponse, result.dict())
+        response = parse_obj_as(InvokeAIModelConfig, result.dict())
     except DuplicateModelException as e:
         raise HTTPException(status_code=409, detail=str(e))
     except UnknownModelException:
@@ -26,9 +26,9 @@ from invokeai.backend.model_manager import (
 from invokeai.backend.model_manager.cache import CacheStats
 from invokeai.backend.model_manager.download import DownloadJobBase
 from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
-from .events import EventServiceBase

 from .config import InvokeAIAppConfig
+from .events import EventServiceBase

 if TYPE_CHECKING:
     from ..invocations.baseinvocation import InvocationContext
@@ -208,7 +208,7 @@ class ModelCache(object):
                 self.stats.hits += 1

         if self.stats:
-            self.stats.cache_size = self.max_cache_size * GIG
+            self.stats.cache_size = int(self.max_cache_size * GIG)
             self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
             self.stats.in_cache = len(self._cached_models)
             self.stats.loaded_model_sizes[key] = max(
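The int() cast is the actual type fix here: max_cache_size is a size in gigabytes and may be a float, so self.max_cache_size * GIG is a float, while CacheStats.cache_size is an integer field. A standalone illustration (variable names assumed to mirror the fields involved):

GIG = 2 ** 30
max_cache_size = 7.5  # cache limit in GB; a float setting
cache_size: int = int(max_cache_size * GIG)  # float * int yields float; without int() mypy reports a mismatch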
@@ -1,6 +1,6 @@
 import inspect
 from enum import Enum
-from typing import Literal, get_origin
+from typing import Any, Literal, get_origin

 from pydantic import BaseModel

@@ -89,8 +89,8 @@ MODEL_CLASSES = {
     # },
 }

-MODEL_CONFIGS = list()
-OPENAPI_MODEL_CONFIGS = list()
+MODEL_CONFIGS: Any = list()
+OPENAPI_MODEL_CONFIGS: Any = list()


 class OpenAPIModelInfoBase(BaseModel):
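Both lists are filled with dynamically generated config classes, so mypy cannot infer a useful element type; Any simply opts them out of checking. If the generated classes all derive from a common pydantic base, a tighter annotation would be possible, sketched below, though it would not help the Union[tuple(...)] call sites, which is presumably why the commit settles for Any:

from typing import List, Type

from pydantic import BaseModel

# Tighter alternative to `OPENAPI_MODEL_CONFIGS: Any = list()`:
OPENAPI_MODEL_CONFIGS: List[Type[BaseModel]] = []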
@@ -2,14 +2,11 @@ import base64
 import importlib
 import io
 import math
-import multiprocessing as mp
 import os
 import re
-from collections import abc
 from inspect import isfunction
 from pathlib import Path
-from queue import Queue
-from threading import Thread
+from typing import Optional

 import numpy as np
 import requests
@@ -166,7 +163,7 @@ def ask_user(question: str, answers: list):


 # -------------------------------------
-def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
+def download_with_resume(url: str, dest: Path, access_token: str = None) -> Optional[Path]:
     """
     Download a model file.
     :param url: https, http or ftp URL
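Since download_with_resume can fail and return nothing, Optional[Path] is the honest signature; callers are now expected to test for None rather than rely on exceptions. A caller-side sketch (the URL is hypothetical):

from pathlib import Path

path = download_with_resume("https://example.com/v1-5.ckpt", Path("/tmp/models"))  # hypothetical URL
if path is None:
    print("download failed")
else:
    print(f"saved to {path}")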
@@ -183,10 +180,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
     content_length = int(resp.headers.get("content-length", 0))

     if dest.is_dir():
-        try:
-            file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
-        except AttributeError:
-            file_name = os.path.basename(url)
+        file_name = response_attachment(resp) or os.path.basename(url)
         dest = dest / file_name
     else:
         dest.parent.mkdir(parents=True, exist_ok=True)
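The replaced block had two defects: resp.headers.get("Content-Disposition") is Optional[str], so mypy rejects passing it to re.search, and the old code leaned on except AttributeError to cover the missing-header case at runtime. The walrus-operator guards in the new response_attachment helper (next hunk) make both None paths explicit. A self-contained sketch of the same narrowing:

import re
from typing import Mapping, Optional


def attachment_filename(headers: Mapping[str, str]) -> Optional[str]:
    # Each := guard narrows an Optional value, so mypy accepts the chained calls.
    if disposition := headers.get("Content-Disposition"):
        if match := re.search('filename="(.+)"', disposition):
            return match.group(1)
    return None


assert attachment_filename({"Content-Disposition": 'attachment; filename="v1-5.ckpt"'}) == "v1-5.ckpt"
assert attachment_filename({}) is None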
@@ -235,15 +229,24 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
     return dest


-def url_attachment_name(url: str) -> dict:
+def response_attachment(response: requests.Response) -> Optional[str]:
     try:
-        resp = requests.get(url, stream=True)
-        match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
-        return match.group(1)
+        if disposition := response.headers.get("Content-Disposition"):
+            if match := re.search('filename="(.+)"', disposition):
+                return match.group(1)
+        return None
     except Exception:
         return None


+def url_attachment_name(url: str) -> Optional[str]:
+    resp = requests.get(url)
+    if resp.ok:
+        return response_attachment(resp)
+    else:
+        return None
+
+
 def download_with_progress_bar(url: str, dest: Path) -> bool:
     result = download_with_resume(url, dest, access_token=None)
     return result is not None
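The refactor also corrects the old -> dict annotation, which never matched what the function returned, and splits header parsing (response_attachment) from fetching (url_attachment_name), so download_with_resume can reuse the parser on its own response. A usage sketch (hypothetical URL):

name = url_attachment_name("https://example.com/download/model")
print(name if name else "server did not advertise a filename")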