accept @psychedelicious suggestions above

This commit is contained in:
Lincoln Stein 2023-07-05 14:50:57 -04:00
commit 5027d0a603
3 changed files with 108 additions and 53 deletions

View File

@ -2,25 +2,28 @@
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
from fastapi import Query, Body, Path from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter, HTTPException from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS, SchedulerPredictionType from invokeai.backend.model_management.models import (MODEL_CONFIGS,
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType)
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
class CreateModelResponse(BaseModel): class UpdateModelResponse(BaseModel):
model_name: str = Field(description="The name of the new model") model_name: str = Field(description="The name of the new model")
info: Union[tuple(MODEL_CONFIGS)] = Field(description="The model info") info: Union[tuple(MODEL_CONFIGS)] = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ImportModelResponse(BaseModel): class ImportModelResponse(BaseModel):
name: str = Field(description="The name of the imported model") location: str = Field(description="The path, repo_id or URL of the imported model")
info: AddModelResult = Field(description="The model info") info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ConvertModelResponse(BaseModel): class ConvertModelResponse(BaseModel):
name: str = Field(description="The name of the imported model") name: str = Field(description="The name of the imported model")
@ -48,51 +51,65 @@ async def list_models(
models = parse_obj_as(ModelsList, { "models": models_raw }) models = parse_obj_as(ModelsList, { "models": models_raw })
return models return models
@models_router.post( @models_router.patch(
"/{base_model}/{model_type}/{model_name}", "/{base_model}/{model_type}/{model_name}",
operation_id="update_model", operation_id="update_model",
responses={200: {"status": "success"}}, responses={200: {"description" : "The model was updated successfully"},
404: {"description" : "The model could not be found"},
400: {"description" : "Bad request"}
},
status_code = 200,
response_model = UpdateModelResponse,
) )
async def update_model( async def update_model(
base_model: BaseModelType = Path(default='sd-1', description="Base model"), base_model: BaseModelType = Path(default='sd-1', description="Base model"),
model_type: ModelType = Path(default='main', description="The type of model"), model_type: ModelType = Path(default='main', description="The type of model"),
model_name: str = Path(default=None, description="model name"), model_name: str = Path(default=None, description="model name"),
info: Union[tuple(MODEL_CONFIGS)] = Body(description="Model configuration"), info: Union[tuple(MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> CreateModelResponse: ) -> UpdateModelResponse:
""" Add Model """ """ Add Model """
ApiDependencies.invoker.services.model_manager.add_model( try:
model_name=model_name, ApiDependencies.invoker.services.model_manager.update_model(
base_model=base_model, model_name=model_name,
model_type=model_type, base_model=base_model,
model_attributes=info.dict(), model_type=model_type,
clobber=True, model_attributes=info.dict()
) )
model_response = CreateModelResponse( model_response = UpdateModelResponse(
model_name = model_name, model_name = model_name,
info = info, info = ApiDependencies.invoker.services.model_manager.model_info(
status="success") model_name=model_name,
base_model=base_model,
model_type=model_type,
)
)
except KeyError as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return model_response return model_response
@models_router.post( @models_router.post(
"/import", "/",
operation_id="import_model", operation_id="import_model",
responses= { responses= {
201: {"description" : "The model imported successfully"}, 201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"}, 404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"}, 409: {"description" : "There is already a model corresponding to this path or repo_id"},
}, },
status_code=201, status_code=201,
response_model=ImportModelResponse response_model=ImportModelResponse
) )
async def import_model( async def import_model(
name: str = Body(description="A model path, repo_id or URL to import"), location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \ prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse: ) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL """ """ Add a model using its local path, repo_id, or remote URL """
items_to_import = {name} items_to_import = {location}
prediction_types = { x.value: x for x in SchedulerPredictionType } prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
@ -101,12 +118,16 @@ async def import_model(
items_to_import = items_to_import, items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(prediction_type) prediction_type_helper = lambda x: prediction_types.get(prediction_type)
) )
if info := installed_models.get(name): info = installed_models.get(location)
logger.info(f'Successfully imported {name}, got {info}')
return ImportModelResponse( if not info:
name = name, logger.error("Import failed")
raise HTTPException(status_code=424)
logger.info(f'Successfully imported {location}, got {info}')
return ImportModelResponse(
location = location,
info = info, info = info,
status = "success",
) )
except KeyError as e: except KeyError as e:
logger.error(str(e)) logger.error(str(e))
@ -129,10 +150,10 @@ async def import_model(
}, },
) )
async def delete_model( async def delete_model(
base_model: BaseModelType = Path(default='sd-1', description="Base model"), base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(default='main', description="The type of model"), model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(default=None, description="model name"), model_name: str = Path(description="model name"),
) -> None: ) -> Response:
"""Delete Model""" """Delete Model"""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
@ -142,14 +163,10 @@ async def delete_model(
model_type = model_type model_type = model_type
) )
logger.info(f"Deleted model: {model_name}") logger.info(f"Deleted model: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully") return Response(status_code=204)
except KeyError: except KeyError:
logger.error(f"Model not found: {model_name}") logger.error(f"Model not found: {model_name}")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
else:
logger.info(f"Model deleted: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
# @socketio.on("convertToDiffusers") # @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict): # def convert_to_diffusers(model_to_convert: dict):

View File

@ -2,10 +2,10 @@
from __future__ import annotations from __future__ import annotations
import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Optional, Union, Callable, List, Set, Dict, Tuple, types, TYPE_CHECKING from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
from types import ModuleType
from invokeai.backend.model_management.model_manager import ( from invokeai.backend.model_management.model_manager import (
ModelManager, ModelManager,
@ -16,9 +16,11 @@ from invokeai.backend.model_management.model_manager import (
AddModelResult, AddModelResult,
SchedulerPredictionType, SchedulerPredictionType,
) )
import torch
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
from .config import InvokeAIAppConfig
from ...backend.util import choose_precision, choose_torch_device from ...backend.util import choose_precision, choose_torch_device
from .config import InvokeAIAppConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext from ..invocations.baseinvocation import BaseInvocation, InvocationContext
@ -31,7 +33,7 @@ class ModelManagerServiceBase(ABC):
def __init__( def __init__(
self, self,
config: InvokeAIAppConfig, config: InvokeAIAppConfig,
logger: types.ModuleType, logger: ModuleType,
): ):
""" """
Initialize with the path to the models.yaml config file. Initialize with the path to the models.yaml config file.
@ -122,6 +124,24 @@ class ModelManagerServiceBase(ABC):
""" """
pass pass
@abstractmethod
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
KeyErrorException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass
@abstractmethod @abstractmethod
def del_model( def del_model(
self, self,
@ -159,9 +179,9 @@ class ModelManagerServiceBase(ABC):
@abstractmethod @abstractmethod
def heuristic_import(self, def heuristic_import(self,
items_to_import: Set[str], items_to_import: set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
)->Dict[str, AddModelResult]: )->dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of '''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items. successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported. :param items_to_import: Set of strings corresponding to models to be imported.
@ -181,7 +201,7 @@ class ModelManagerServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def commit(self, conf_file: Path = None) -> None: def commit(self, conf_file: Optional[Path] = None) -> None:
""" """
Write current configuration out to the indicated file. Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the If no conf_file is provided, then replaces the
@ -195,7 +215,7 @@ class ModelManagerService(ModelManagerServiceBase):
def __init__( def __init__(
self, self,
config: InvokeAIAppConfig, config: InvokeAIAppConfig,
logger: types.ModuleType, logger: ModuleType,
): ):
""" """
Initialize with the path to the models.yaml config file. Initialize with the path to the models.yaml config file.
@ -343,7 +363,25 @@ class ModelManagerService(ModelManagerServiceBase):
self.logger.debug(f'add/update model {model_name}') self.logger.debug(f'add/update model {model_name}')
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
KeyError exception if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f'update model {model_name}')
if not self.model_exists(model_name, base_model, model_type):
raise KeyError(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
def del_model( def del_model(
self, self,
model_name: str, model_name: str,
@ -429,9 +467,9 @@ class ModelManagerService(ModelManagerServiceBase):
return self.mgr.logger return self.mgr.logger
def heuristic_import(self, def heuristic_import(self,
items_to_import: Set[str], items_to_import: set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
)->Dict[str, AddModelResult]: )->dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of '''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items. successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported. :param items_to_import: Set of strings corresponding to models to be imported.

View File

@ -480,7 +480,7 @@ class ModelManager(object):
""" """
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
if model_key in self.models: if model_key in self.models:
return self.models[model_key].dict(exclude_defaults=True) return self.models[model_key].dict(exclude_defaults=True, exclude={"error"})
else: else:
return None # TODO: None or empty dict on not found return None # TODO: None or empty dict on not found