Merge branch 'main' into sdxl-support

This commit is contained in:
blessedcoolant
2023-07-18 13:34:07 +12:00
100 changed files with 3322 additions and 10176 deletions

View File

@ -11,6 +11,7 @@ from invokeai.app.services.board_images import (
)
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
@ -20,7 +21,6 @@ from invokeai.version.invokeai_version import __version__
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_file_storage import DiskImageFileStorage
from ..services.invocation_queue import MemoryInvocationQueue
@ -57,8 +57,8 @@ class ApiDependencies:
invoker: Invoker = None
@staticmethod
def initialize(config, event_handler_id: int, logger: Logger = logger):
logger.debug(f'InvokeAI version {__version__}')
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
logger.debug(f"InvokeAI version {__version__}")
logger.debug(f"Internet connectivity is {config.internet_available}")
events = FastAPIEventService(event_handler_id)
@ -117,7 +117,7 @@ class ApiDependencies:
)
services = InvocationServices(
model_manager=ModelManagerService(config,logger),
model_manager=ModelManagerService(config, logger),
events=events,
latents=latents,
images=images,
@ -129,7 +129,6 @@ class ApiDependencies:
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config, logger),
configuration=config,
logger=logger,
)

View File

@ -13,8 +13,10 @@ from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType,
ModelNotFoundException,
)
from invokeai.backend.model_management import MergeInterpolationMethod
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -51,8 +53,9 @@ async def list_models(
"/{base_model}/{model_type}/{model_name}",
operation_id="update_model",
responses={200: {"description" : "The model was updated successfully"},
400: {"description" : "Bad request"},
404: {"description" : "The model could not be found"},
400: {"description" : "Bad request"}
409: {"description" : "There is already a model corresponding to the new name"},
},
status_code = 200,
response_model = UpdateModelResponse,
@ -63,23 +66,58 @@ async def update_model(
model_name: str = Path(description="model name"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
""" Add Model """
""" Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
logger = ApiDependencies.invoker.services.logger
try:
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
# rename operation requested
if info.model_name != model_name or info.base_model != base_model:
ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = info.model_name,
new_base = info.base_model,
)
logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
# update information to support an update of attributes
model_name = info.model_name
base_model = info.base_model
new_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it
info.path = new_info.get('path')
ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_attributes=info.dict()
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
model_response = parse_obj_as(UpdateModelResponse, model_raw)
except KeyError as e:
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
return model_response
@ -126,7 +164,7 @@ async def import_model(
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
@ -166,57 +204,13 @@ async def add_model(
model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/rename/{base_model}/{model_type}/{model_name}",
operation_id="rename_model",
responses= {
201: {"description" : "The model was renamed successfully"},
404: {"description" : "The model could not be found"},
409: {"description" : "There is already a model corresponding to the new name"},
},
status_code=201,
response_model=ImportModelResponse
)
async def rename_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="current model name"),
new_name: Optional[str] = Query(description="new model name", default=None),
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
) -> ImportModelResponse:
""" Rename a model"""
logger = ApiDependencies.invoker.services.logger
try:
result = ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
)
logger.debug(result)
logger.info(f'Successfully renamed {model_name}=>{new_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=new_name or model_name,
base_model=new_base or base_model,
model_type=model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.delete(
"/{base_model}/{model_type}/{model_name}",
@ -243,9 +237,9 @@ async def delete_model(
)
logger.info(f"Deleted model: {model_name}")
return Response(status_code=204)
except KeyError:
logger.error(f"Model not found: {model_name}")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@models_router.put(
"/convert/{base_model}/{model_type}/{model_name}",
@ -278,8 +272,8 @@ async def convert_model(
base_model = base_model,
model_type = model_type)
response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
@ -369,8 +363,55 @@ async def merge_models(
model_type = ModelType.Main,
)
response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError:
except ModelNotFoundException:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
# The rename operation is now supported by update_model and no longer needs to be
# a standalone route.
# @models_router.post(
# "/rename/{base_model}/{model_type}/{model_name}",
# operation_id="rename_model",
# responses= {
# 201: {"description" : "The model was renamed successfully"},
# 404: {"description" : "The model could not be found"},
# 409: {"description" : "There is already a model corresponding to the new name"},
# },
# status_code=201,
# response_model=ImportModelResponse
# )
# async def rename_model(
# base_model: BaseModelType = Path(description="Base model"),
# model_type: ModelType = Path(description="The type of model"),
# model_name: str = Path(description="current model name"),
# new_name: Optional[str] = Query(description="new model name", default=None),
# new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
# ) -> ImportModelResponse:
# """ Rename a model"""
# logger = ApiDependencies.invoker.services.logger
# try:
# result = ApiDependencies.invoker.services.model_manager.rename_model(
# base_model = base_model,
# model_type = model_type,
# model_name = model_name,
# new_name = new_name,
# new_base = new_base,
# )
# logger.debug(result)
# logger.info(f'Successfully renamed {model_name}=>{new_name}')
# model_raw = ApiDependencies.invoker.services.model_manager.list_model(
# model_name=new_name or model_name,
# base_model=new_base or base_model,
# model_type=model_type
# )
# return parse_obj_as(ImportModelResponse, model_raw)
# except ModelNotFoundException as e:
# logger.error(str(e))
# raise HTTPException(status_code=404, detail=str(e))
# except ValueError as e:
# logger.error(str(e))
# raise HTTPException(status_code=409, detail=str(e))