accept @psychedelicious's suggestions above

Lincoln Stein 2023-07-05 14:50:57 -04:00
commit 5027d0a603
3 changed files with 108 additions and 53 deletions

View File

@@ -2,25 +2,28 @@
from typing import Literal, Optional, Union
from fastapi import Query, Body, Path
from fastapi.routing import APIRouter, HTTPException
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
from invokeai.backend.model_management.models import (MODEL_CONFIGS,
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType)
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
class CreateModelResponse(BaseModel):
class UpdateModelResponse(BaseModel):
model_name: str = Field(description="The name of the new model")
info: Union[tuple(MODEL_CONFIGS)] = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ImportModelResponse(BaseModel):
name: str = Field(description="The name of the imported model")
location: str = Field(description="The path, repo_id or URL of the imported model")
info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ConvertModelResponse(BaseModel):
name: str = Field(description="The name of the imported model")
@@ -48,51 +51,65 @@ async def list_models(
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
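
As background for the Union[tuple(MODEL_CONFIGS)] annotation used in the response models above, and for the parse_obj_as(ModelsList, ...) call here: the router builds a Union over whatever config classes the backend registers and lets pydantic pick the matching one. A minimal self-contained sketch of that pattern, using stand-in config classes rather than the real InvokeAI ones (pydantic v1 API):

# Stand-in config classes; DiffusersConfig and CheckpointConfig are hypothetical,
# only the Union-of-tuple + parse_obj_as mechanics mirror the router code.
from typing import List, Union
from pydantic import BaseModel, Field, parse_obj_as

class DiffusersConfig(BaseModel):
    format: str = Field("diffusers", const=True)
    path: str

class CheckpointConfig(BaseModel):
    format: str = Field("checkpoint", const=True)
    path: str
    config: str

MODEL_CONFIGS = (DiffusersConfig, CheckpointConfig)

class ModelsList(BaseModel):
    models: List[Union[MODEL_CONFIGS]]  # same trick as Union[tuple(MODEL_CONFIGS)]

models_raw = [
    {"format": "diffusers", "path": "/models/sd-1/main/my-model"},
    {"format": "checkpoint", "path": "/models/sd-1/main/old.ckpt", "config": "v1-inference.yaml"},
]
models = parse_obj_as(ModelsList, {"models": models_raw})
print(type(models.models[1]).__name__)  # CheckpointConfig
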
@models_router.post(
@models_router.patch(
"/{base_model}/{model_type}/{model_name}",
operation_id="update_model",
responses={200: {"status": "success"}},
responses={200: {"description" : "The model was updated successfully"},
404: {"description" : "The model could not be found"},
400: {"description" : "Bad request"}
},
status_code = 200,
response_model = UpdateModelResponse,
)
async def update_model(
base_model: BaseModelType = Path(default='sd-1', description="Base model"),
model_type: ModelType = Path(default='main', description="The type of model"),
model_name: str = Path(default=None, description="model name"),
info: Union[tuple(MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> CreateModelResponse:
) -> UpdateModelResponse:
""" Add Model """
ApiDependencies.invoker.services.model_manager.add_model(
try:
ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_attributes=info.dict(),
clobber=True,
model_attributes=info.dict()
)
model_response = CreateModelResponse(
model_response = UpdateModelResponse(
model_name = model_name,
info = info,
status="success")
info = ApiDependencies.invoker.services.model_manager.model_info(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
)
except KeyError as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return model_response
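
A hedged sketch of how a client might call the reworked PATCH endpoint. The host, port and /api prefix are assumptions about a typical InvokeAI deployment, and the payload keys are illustrative; the real body must be a complete config matching one of the MODEL_CONFIGS schemas for that model type:

# Hypothetical client for PATCH /v1/models/{base_model}/{model_type}/{model_name}
import requests

BASE = "http://localhost:9090/api/v1/models"  # assumed host and mount prefix
payload = {
    # illustrative keys; the required fields depend on the model's config schema
    "path": "/models/sd-1/main/my-model",
    "model_format": "diffusers",
    "description": "updated description",
}
resp = requests.patch(f"{BASE}/sd-1/main/my-model", json=payload)
if resp.status_code == 200:
    body = resp.json()
    print(body["model_name"], body["status"])   # fields of UpdateModelResponse
elif resp.status_code == 404:
    print("model not found:", resp.json()["detail"])
else:  # 400 and anything else
    print(resp.status_code, resp.json().get("detail"))
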
@models_router.post(
"/import",
"/",
operation_id="import_model",
responses= {
201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse
)
async def import_model(
name: str = Body(description="A model path, repo_id or URL to import"),
location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL """
items_to_import = {name}
items_to_import = {location}
prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger
@@ -101,12 +118,16 @@ async def import_model(
items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
)
if info := installed_models.get(name):
logger.info(f'Successfully imported {name}, got {info}')
info = installed_models.get(location)
if not info:
logger.error("Import failed")
raise HTTPException(status_code=424)
logger.info(f'Successfully imported {location}, got {info}')
return ImportModelResponse(
name = name,
location = location,
info = info,
status = "success",
)
except KeyError as e:
logger.error(str(e))
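
A hedged client-side sketch of the import endpoint, now reachable as a POST to the router root rather than /import. The host and /api prefix are assumptions; the repo_id is a placeholder:

# Hypothetical client for POST /v1/models/ ("import_model")
import requests

resp = requests.post(
    "http://localhost:9090/api/v1/models/",            # assumed host and mount prefix
    json={
        "location": "runwayml/stable-diffusion-v1-5",  # local path, repo_id or URL
        "prediction_type": "epsilon",                  # only matters for SDv2 checkpoint files
    },
)
if resp.status_code == 201:
    body = resp.json()
    print("imported", body["name"], "from", body["location"])
elif resp.status_code == 424:
    print("install reported success, but the model manager cannot find the model")
else:  # 404, 409, ...
    print(resp.status_code, resp.json().get("detail"))
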
@@ -129,10 +150,10 @@ async def import_model(
},
)
async def delete_model(
base_model: BaseModelType = Path(default='sd-1', description="Base model"),
model_type: ModelType = Path(default='main', description="The type of model"),
model_name: str = Path(default=None, description="model name"),
) -> None:
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
) -> Response:
"""Delete Model"""
logger = ApiDependencies.invoker.services.logger
@@ -142,14 +163,10 @@ async def delete_model(
model_type = model_type
)
logger.info(f"Deleted model: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
return Response(status_code=204)
except KeyError:
logger.error(f"Model not found: {model_name}")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
else:
logger.info(f"Model deleted: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
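
Returning a bare Response(status_code=204) on success (instead of raising HTTPException as the old code did) means the reply carries no body, which is what a 204 requires. A hedged client sketch for the delete endpoint, with the same host/prefix assumptions as above:

# Hypothetical client for DELETE /v1/models/{base_model}/{model_type}/{model_name}
import requests

resp = requests.delete("http://localhost:9090/api/v1/models/sd-1/main/my-model")
if resp.status_code == 204:
    print("model deleted; 204 replies have no body to parse")
elif resp.status_code == 404:
    print("no such model:", resp.json()["detail"])
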

View File

@@ -2,10 +2,10 @@
from __future__ import annotations
import torch
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union, Callable, List, Set, Dict, Tuple, types, TYPE_CHECKING
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
from types import ModuleType
from invokeai.backend.model_management.model_manager import (
ModelManager,
@@ -16,9 +16,11 @@ from invokeai.backend.model_management.model_manager import (
AddModelResult,
SchedulerPredictionType,
)
import torch
from invokeai.app.models.exceptions import CanceledException
from .config import InvokeAIAppConfig
from ...backend.util import choose_precision, choose_torch_device
from .config import InvokeAIAppConfig
if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
@@ -31,7 +33,7 @@ class ModelManagerServiceBase(ABC):
def __init__(
self,
config: InvokeAIAppConfig,
logger: types.ModuleType,
logger: ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
@@ -122,6 +124,24 @@ class ModelManagerServiceBase(ABC):
"""
pass
@abstractmethod
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
KeyError exception if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass
@abstractmethod
def del_model(
self,
@@ -159,9 +179,9 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Dict[str, AddModelResult]:
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
)->dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
@@ -181,7 +201,7 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def commit(self, conf_file: Path = None) -> None:
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
@@ -195,7 +215,7 @@ class ModelManagerService(ModelManagerServiceBase):
def __init__(
self,
config: InvokeAIAppConfig,
logger: types.ModuleType,
logger: ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
@@ -343,6 +363,24 @@ class ModelManagerService(ModelManagerServiceBase):
self.logger.debug(f'add/update model {model_name}')
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
KeyError exception if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f'update model {model_name}')
if not self.model_exists(model_name, base_model, model_type):
raise KeyError(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
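
A sketch of how calling code might use the new service method together with model_info() and commit(). Here mm stands for an already constructed ModelManagerService; the "description" key is an assumed attribute, and model_info() returning the stored config as a plain dict is inferred from the model_manager change in the third file:

from invokeai.backend import BaseModelType, ModelType

def set_description(mm, model_name: str, text: str) -> None:
    """Change one attribute of an existing model and persist it."""
    base, kind = BaseModelType("sd-1"), ModelType("main")  # enum values taken from the router defaults
    attributes = mm.model_info(model_name=model_name, base_model=base, model_type=kind)
    attributes["description"] = text                       # assumed attribute key
    # raises KeyError if the model does not already exist
    mm.update_model(model_name=model_name, base_model=base,
                    model_type=kind, model_attributes=attributes)
    mm.commit()                                            # write the in-memory change to models.yaml
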
def del_model(
self,
@@ -429,9 +467,9 @@ class ModelManagerService(ModelManagerServiceBase):
return self.mgr.logger
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Dict[str, AddModelResult]:
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
)->dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
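
A sketch of driving heuristic_import() on a ModelManagerService instance mm; the paths and repo_id are placeholders, and the helper simply hard-codes a prediction type (it is only consulted for SDv2 checkpoints whose type cannot be inferred):

from invokeai.backend.model_management.models import SchedulerPredictionType

def import_some_models(mm):
    items = {
        "/downloads/analog-diffusion-1.0.safetensors",   # placeholder local path
        "stabilityai/stable-diffusion-2-1",              # placeholder repo_id
        "https://example.com/some-model.safetensors",    # placeholder URL
    }
    def guess_prediction_type(checkpoint_path):
        # value-based construction; "v_prediction" comes from the router's Literal
        return SchedulerPredictionType("v_prediction")

    results = mm.heuristic_import(items, prediction_type_helper=guess_prediction_type)
    for location, add_result in results.items():         # dict[str, AddModelResult]
        print(f"installed {location}: {add_result}")
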

View File

@@ -480,7 +480,7 @@ class ModelManager(object):
"""
model_key = self.create_key(model_name, base_model, model_type)
if model_key in self.models:
return self.models[model_key].dict(exclude_defaults=True)
return self.models[model_key].dict(exclude_defaults=True, exclude={"error"})
else:
return None # TODO: None or empty dict on not found
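
The added exclude={"error"} keeps a transient error field out of the config dict returned to callers. A toy illustration of how pydantic v1's .dict(exclude_defaults=True, exclude=...) behaves, with a stand-in model rather than the real config class:

from typing import Optional
from pydantic import BaseModel

class DemoConfig(BaseModel):
    path: str
    description: str = ""        # equal to its default -> dropped by exclude_defaults
    error: Optional[str] = None  # transient field we never want to expose

cfg = DemoConfig(path="/models/foo", error="failed to load last time")
print(cfg.dict(exclude_defaults=True))                     # {'path': '/models/foo', 'error': 'failed to load last time'}
print(cfg.dict(exclude_defaults=True, exclude={"error"}))  # {'path': '/models/foo'}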