revert(mm): restore convert route

This commit is contained in:
psychedelicious 2024-03-05 18:40:17 +11:00
parent 7c9128b253
commit 48119d9010

View File

@ -2,6 +2,7 @@
"""FastAPI route for model configuration records.""" """FastAPI route for model configuration records."""
import pathlib import pathlib
import shutil
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
@ -15,12 +16,14 @@ from invokeai.app.services.model_records import (
InvalidModelException, InvalidModelException,
UnknownModelException, UnknownModelException,
) )
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
MainCheckpointConfig,
ModelFormat, ModelFormat,
ModelType, ModelType,
SubModelType,
) )
from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.model_manager.search import ModelSearch
@ -465,8 +468,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
installer.cancel_job(job) installer.cancel_job(job)
@model_manager_router.patch( @model_manager_router.delete(
"/install/prune", "/install",
operation_id="prune_model_install_jobs", operation_id="prune_model_install_jobs",
responses={ responses={
204: {"description": "All completed and errored jobs have been pruned"}, 204: {"description": "All completed and errored jobs have been pruned"},
@ -498,80 +501,80 @@ async def sync_models_to_config() -> Response:
return Response(status_code=204) return Response(status_code=204)
# @model_manager_router.put( @model_manager_router.put(
# "/convert/{key}", "/convert/{key}",
# operation_id="convert_model", operation_id="convert_model",
# responses={ responses={
# 200: { 200: {
# "description": "Model converted successfully", "description": "Model converted successfully",
# "content": {"application/json": {"example": example_model_config}}, "content": {"application/json": {"example": example_model_config}},
# }, },
# 400: {"description": "Bad request"}, 400: {"description": "Bad request"},
# 404: {"description": "Model not found"}, 404: {"description": "Model not found"},
# 409: {"description": "There is already a model registered at this location"}, 409: {"description": "There is already a model registered at this location"},
# }, },
# ) )
# async def convert_model( async def convert_model(
# key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."), key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
# ) -> AnyModelConfig: ) -> AnyModelConfig:
# """ """
# Permanently convert a model into diffusers format, replacing the safetensors version. Permanently convert a model into diffusers format, replacing the safetensors version.
# Note that during the conversion process the key and model hash will change. Note that during the conversion process the key and model hash will change.
# The return value is the model configuration for the converted model. The return value is the model configuration for the converted model.
# """ """
# model_manager = ApiDependencies.invoker.services.model_manager model_manager = ApiDependencies.invoker.services.model_manager
# logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
# loader = ApiDependencies.invoker.services.model_manager.load loader = ApiDependencies.invoker.services.model_manager.load
# store = ApiDependencies.invoker.services.model_manager.store store = ApiDependencies.invoker.services.model_manager.store
# installer = ApiDependencies.invoker.services.model_manager.install installer = ApiDependencies.invoker.services.model_manager.install
# try: try:
# model_config = store.get_model(key) model_config = store.get_model(key)
# except UnknownModelException as e: except UnknownModelException as e:
# logger.error(str(e)) logger.error(str(e))
# raise HTTPException(status_code=424, detail=str(e)) raise HTTPException(status_code=424, detail=str(e))
# if not isinstance(model_config, MainCheckpointConfig): if not isinstance(model_config, MainCheckpointConfig):
# logger.error(f"The model with key {key} is not a main checkpoint model.") logger.error(f"The model with key {key} is not a main checkpoint model.")
# raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.") raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
# # loading the model will convert it into a cached diffusers file # loading the model will convert it into a cached diffusers file
# model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler) model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
# # Get the path of the converted model from the loader # Get the path of the converted model from the loader
# cache_path = loader.convert_cache.cache_path(key) cache_path = loader.convert_cache.cache_path(key)
# assert cache_path.exists() assert cache_path.exists()
# # temporarily rename the original safetensors file so that there is no naming conflict # temporarily rename the original safetensors file so that there is no naming conflict
# original_name = model_config.name original_name = model_config.name
# model_config.name = f"{original_name}.DELETE" model_config.name = f"{original_name}.DELETE"
# changes = ModelRecordChanges(name=model_config.name) changes = ModelRecordChanges(name=model_config.name)
# store.update_model(key, changes=changes) store.update_model(key, changes=changes)
# # install the diffusers # install the diffusers
# try: try:
# new_key = installer.install_path( new_key = installer.install_path(
# cache_path, cache_path,
# config={ config={
# "name": original_name, "name": original_name,
# "description": model_config.description, "description": model_config.description,
# "hash": model_config.hash, "hash": model_config.hash,
# "source": model_config.source, "source": model_config.source,
# }, },
# ) )
# except DuplicateModelException as e: except DuplicateModelException as e:
# logger.error(str(e)) logger.error(str(e))
# raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
# # delete the original safetensors file # delete the original safetensors file
# installer.delete(key) installer.delete(key)
# # delete the cached version # delete the cached version
# shutil.rmtree(cache_path) shutil.rmtree(cache_path)
# # return the config record for the new diffusers directory # return the config record for the new diffusers directory
# new_config: AnyModelConfig = store.get_model(new_key) new_config: AnyModelConfig = store.get_model(new_key)
# return new_config return new_config
# @model_manager_router.put( # @model_manager_router.put(