mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
accept @psychedelicious suggestions above
This commit is contained in:
commit
5027d0a603
@ -2,25 +2,28 @@
|
|||||||
|
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
from fastapi import Query, Body, Path
|
from fastapi import Body, Path, Query, Response
|
||||||
from fastapi.routing import APIRouter, HTTPException
|
from fastapi.routing import APIRouter
|
||||||
from pydantic import BaseModel, Field, parse_obj_as
|
from pydantic import BaseModel, Field, parse_obj_as
|
||||||
from ..dependencies import ApiDependencies
|
from starlette.exceptions import HTTPException
|
||||||
|
|
||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
from invokeai.backend.model_management import AddModelResult
|
from invokeai.backend.model_management import AddModelResult
|
||||||
from invokeai.backend.model_management.models import MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
from invokeai.backend.model_management.models import (MODEL_CONFIGS,
|
||||||
|
OPENAPI_MODEL_CONFIGS,
|
||||||
|
SchedulerPredictionType)
|
||||||
|
|
||||||
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
class CreateModelResponse(BaseModel):
|
class UpdateModelResponse(BaseModel):
|
||||||
model_name: str = Field(description="The name of the new model")
|
model_name: str = Field(description="The name of the new model")
|
||||||
info: Union[tuple(MODEL_CONFIGS)] = Field(description="The model info")
|
info: Union[tuple(MODEL_CONFIGS)] = Field(description="The model info")
|
||||||
status: str = Field(description="The status of the API response")
|
|
||||||
|
|
||||||
class ImportModelResponse(BaseModel):
|
class ImportModelResponse(BaseModel):
|
||||||
name: str = Field(description="The name of the imported model")
|
location: str = Field(description="The path, repo_id or URL of the imported model")
|
||||||
info: AddModelResult = Field(description="The model info")
|
info: AddModelResult = Field(description="The model info")
|
||||||
status: str = Field(description="The status of the API response")
|
|
||||||
|
|
||||||
class ConvertModelResponse(BaseModel):
|
class ConvertModelResponse(BaseModel):
|
||||||
name: str = Field(description="The name of the imported model")
|
name: str = Field(description="The name of the imported model")
|
||||||
@ -48,51 +51,65 @@ async def list_models(
|
|||||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||||
return models
|
return models
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.patch(
|
||||||
"/{base_model}/{model_type}/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="update_model",
|
operation_id="update_model",
|
||||||
responses={200: {"status": "success"}},
|
responses={200: {"description" : "The model was updated successfully"},
|
||||||
|
404: {"description" : "The model could not be found"},
|
||||||
|
400: {"description" : "Bad request"}
|
||||||
|
},
|
||||||
|
status_code = 200,
|
||||||
|
response_model = UpdateModelResponse,
|
||||||
)
|
)
|
||||||
async def update_model(
|
async def update_model(
|
||||||
base_model: BaseModelType = Path(default='sd-1', description="Base model"),
|
base_model: BaseModelType = Path(default='sd-1', description="Base model"),
|
||||||
model_type: ModelType = Path(default='main', description="The type of model"),
|
model_type: ModelType = Path(default='main', description="The type of model"),
|
||||||
model_name: str = Path(default=None, description="model name"),
|
model_name: str = Path(default=None, description="model name"),
|
||||||
info: Union[tuple(MODEL_CONFIGS)] = Body(description="Model configuration"),
|
info: Union[tuple(MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||||
) -> CreateModelResponse:
|
) -> UpdateModelResponse:
|
||||||
""" Add Model """
|
""" Add Model """
|
||||||
ApiDependencies.invoker.services.model_manager.add_model(
|
try:
|
||||||
model_name=model_name,
|
ApiDependencies.invoker.services.model_manager.update_model(
|
||||||
base_model=base_model,
|
model_name=model_name,
|
||||||
model_type=model_type,
|
base_model=base_model,
|
||||||
model_attributes=info.dict(),
|
model_type=model_type,
|
||||||
clobber=True,
|
model_attributes=info.dict()
|
||||||
)
|
)
|
||||||
model_response = CreateModelResponse(
|
model_response = UpdateModelResponse(
|
||||||
model_name = model_name,
|
model_name = model_name,
|
||||||
info = info,
|
info = ApiDependencies.invoker.services.model_manager.model_info(
|
||||||
status="success")
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except KeyError as e:
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.post(
|
||||||
"/import",
|
"/",
|
||||||
operation_id="import_model",
|
operation_id="import_model",
|
||||||
responses= {
|
responses= {
|
||||||
201: {"description" : "The model imported successfully"},
|
201: {"description" : "The model imported successfully"},
|
||||||
404: {"description" : "The model could not be found"},
|
404: {"description" : "The model could not be found"},
|
||||||
|
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
|
||||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
response_model=ImportModelResponse
|
response_model=ImportModelResponse
|
||||||
)
|
)
|
||||||
async def import_model(
|
async def import_model(
|
||||||
name: str = Body(description="A model path, repo_id or URL to import"),
|
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||||
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||||
) -> ImportModelResponse:
|
) -> ImportModelResponse:
|
||||||
""" Add a model using its local path, repo_id, or remote URL """
|
""" Add a model using its local path, repo_id, or remote URL """
|
||||||
|
|
||||||
items_to_import = {name}
|
items_to_import = {location}
|
||||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
@ -101,12 +118,16 @@ async def import_model(
|
|||||||
items_to_import = items_to_import,
|
items_to_import = items_to_import,
|
||||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||||
)
|
)
|
||||||
if info := installed_models.get(name):
|
info = installed_models.get(location)
|
||||||
logger.info(f'Successfully imported {name}, got {info}')
|
|
||||||
return ImportModelResponse(
|
if not info:
|
||||||
name = name,
|
logger.error("Import failed")
|
||||||
|
raise HTTPException(status_code=424)
|
||||||
|
|
||||||
|
logger.info(f'Successfully imported {location}, got {info}')
|
||||||
|
return ImportModelResponse(
|
||||||
|
location = location,
|
||||||
info = info,
|
info = info,
|
||||||
status = "success",
|
|
||||||
)
|
)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
@ -129,10 +150,10 @@ async def import_model(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def delete_model(
|
async def delete_model(
|
||||||
base_model: BaseModelType = Path(default='sd-1', description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_type: ModelType = Path(default='main', description="The type of model"),
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
model_name: str = Path(default=None, description="model name"),
|
model_name: str = Path(description="model name"),
|
||||||
) -> None:
|
) -> Response:
|
||||||
"""Delete Model"""
|
"""Delete Model"""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
@ -142,14 +163,10 @@ async def delete_model(
|
|||||||
model_type = model_type
|
model_type = model_type
|
||||||
)
|
)
|
||||||
logger.info(f"Deleted model: {model_name}")
|
logger.info(f"Deleted model: {model_name}")
|
||||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
return Response(status_code=204)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.error(f"Model not found: {model_name}")
|
logger.error(f"Model not found: {model_name}")
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||||
else:
|
|
||||||
logger.info(f"Model deleted: {model_name}")
|
|
||||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
|
||||||
|
|
||||||
|
|
||||||
# @socketio.on("convertToDiffusers")
|
# @socketio.on("convertToDiffusers")
|
||||||
# def convert_to_diffusers(model_to_convert: dict):
|
# def convert_to_diffusers(model_to_convert: dict):
|
||||||
|
@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import torch
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union, Callable, List, Set, Dict, Tuple, types, TYPE_CHECKING
|
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
from invokeai.backend.model_management.model_manager import (
|
from invokeai.backend.model_management.model_manager import (
|
||||||
ModelManager,
|
ModelManager,
|
||||||
@ -16,9 +16,11 @@ from invokeai.backend.model_management.model_manager import (
|
|||||||
AddModelResult,
|
AddModelResult,
|
||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import torch
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
from .config import InvokeAIAppConfig
|
|
||||||
from ...backend.util import choose_precision, choose_torch_device
|
from ...backend.util import choose_precision, choose_torch_device
|
||||||
|
from .config import InvokeAIAppConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||||
@ -31,7 +33,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: InvokeAIAppConfig,
|
config: InvokeAIAppConfig,
|
||||||
logger: types.ModuleType,
|
logger: ModuleType,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize with the path to the models.yaml config file.
|
Initialize with the path to the models.yaml config file.
|
||||||
@ -122,6 +124,24 @@ class ModelManagerServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
model_attributes: dict,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Update the named model with a dictionary of attributes. Will fail with a
|
||||||
|
KeyErrorException if the name does not already exist.
|
||||||
|
|
||||||
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
|
with an assertion error if provided attributes are incorrect or
|
||||||
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def del_model(
|
def del_model(
|
||||||
self,
|
self,
|
||||||
@ -159,9 +179,9 @@ class ModelManagerServiceBase(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def heuristic_import(self,
|
def heuristic_import(self,
|
||||||
items_to_import: Set[str],
|
items_to_import: set[str],
|
||||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||||
)->Dict[str, AddModelResult]:
|
)->dict[str, AddModelResult]:
|
||||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
successfully imported items.
|
successfully imported items.
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
@ -181,7 +201,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def commit(self, conf_file: Path = None) -> None:
|
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Write current configuration out to the indicated file.
|
Write current configuration out to the indicated file.
|
||||||
If no conf_file is provided, then replaces the
|
If no conf_file is provided, then replaces the
|
||||||
@ -195,7 +215,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: InvokeAIAppConfig,
|
config: InvokeAIAppConfig,
|
||||||
logger: types.ModuleType,
|
logger: ModuleType,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize with the path to the models.yaml config file.
|
Initialize with the path to the models.yaml config file.
|
||||||
@ -343,7 +363,25 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
self.logger.debug(f'add/update model {model_name}')
|
self.logger.debug(f'add/update model {model_name}')
|
||||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||||
|
|
||||||
|
def update_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
model_attributes: dict,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Update the named model with a dictionary of attributes. Will fail with a
|
||||||
|
KeyError exception if the name does not already exist.
|
||||||
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
|
with an assertion error if provided attributes are incorrect or
|
||||||
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
|
"""
|
||||||
|
self.logger.debug(f'update model {model_name}')
|
||||||
|
if not self.model_exists(model_name, base_model, model_type):
|
||||||
|
raise KeyError(f"Unknown model {model_name}")
|
||||||
|
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||||
|
|
||||||
def del_model(
|
def del_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -429,9 +467,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
return self.mgr.logger
|
return self.mgr.logger
|
||||||
|
|
||||||
def heuristic_import(self,
|
def heuristic_import(self,
|
||||||
items_to_import: Set[str],
|
items_to_import: set[str],
|
||||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||||
)->Dict[str, AddModelResult]:
|
)->dict[str, AddModelResult]:
|
||||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
successfully imported items.
|
successfully imported items.
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
|
@ -480,7 +480,7 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
if model_key in self.models:
|
if model_key in self.models:
|
||||||
return self.models[model_key].dict(exclude_defaults=True)
|
return self.models[model_key].dict(exclude_defaults=True, exclude={"error"})
|
||||||
else:
|
else:
|
||||||
return None # TODO: None or empty dict on not found
|
return None # TODO: None or empty dict on not found
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user