mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add rename_model route
This commit is contained in:
parent
e71ce83e9c
commit
2faa7cee37
@ -23,6 +23,7 @@ UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
|||||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||||
@ -79,7 +80,7 @@ async def update_model(
|
|||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.post(
|
||||||
"/",
|
"/import",
|
||||||
operation_id="import_model",
|
operation_id="import_model",
|
||||||
responses= {
|
responses= {
|
||||||
201: {"description" : "The model imported successfully"},
|
201: {"description" : "The model imported successfully"},
|
||||||
@ -95,7 +96,7 @@ async def import_model(
|
|||||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||||
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||||
) -> ImportModelResponse:
|
) -> ImportModelResponse:
|
||||||
""" Add a model using its local path, repo_id, or remote URL """
|
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
|
||||||
|
|
||||||
items_to_import = {location}
|
items_to_import = {location}
|
||||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||||
@ -127,6 +128,90 @@ async def import_model(
|
|||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
|
@models_router.post(
|
||||||
|
"/add",
|
||||||
|
operation_id="add_model",
|
||||||
|
responses= {
|
||||||
|
201: {"description" : "The model added successfully"},
|
||||||
|
404: {"description" : "The model could not be found"},
|
||||||
|
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
|
||||||
|
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
response_model=ImportModelResponse
|
||||||
|
)
|
||||||
|
async def add_model(
|
||||||
|
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||||
|
) -> ImportModelResponse:
|
||||||
|
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||||
|
|
||||||
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
ApiDependencies.invoker.services.model_manager.add_model(
|
||||||
|
info.model_name,
|
||||||
|
info.base_model,
|
||||||
|
info.model_type,
|
||||||
|
model_attributes = info.dict()
|
||||||
|
)
|
||||||
|
logger.info(f'Successfully added {info.model_name}')
|
||||||
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
|
model_name=info.model_name,
|
||||||
|
base_model=info.base_model,
|
||||||
|
model_type=info.model_type
|
||||||
|
)
|
||||||
|
return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
|
except KeyError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
|
@models_router.post(
|
||||||
|
"/rename/{base_model}/{model_type}/{model_name}",
|
||||||
|
operation_id="rename_model",
|
||||||
|
responses= {
|
||||||
|
201: {"description" : "The model was renamed successfully"},
|
||||||
|
404: {"description" : "The model could not be found"},
|
||||||
|
409: {"description" : "There is already a model corresponding to the new name"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
response_model=ImportModelResponse
|
||||||
|
)
|
||||||
|
async def rename_model(
|
||||||
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
|
model_name: str = Path(description="current model name"),
|
||||||
|
new_name: Optional[str] = Query(description="new model name", default=None),
|
||||||
|
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
|
||||||
|
) -> ImportModelResponse:
|
||||||
|
""" Rename a model"""
|
||||||
|
|
||||||
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = ApiDependencies.invoker.services.model_manager.rename_model(
|
||||||
|
base_model = base_model,
|
||||||
|
model_type = model_type,
|
||||||
|
model_name = model_name,
|
||||||
|
new_name = new_name,
|
||||||
|
new_base = new_base,
|
||||||
|
)
|
||||||
|
logger.debug(result)
|
||||||
|
logger.info(f'Successfully renamed {model_name}=>{new_name}')
|
||||||
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
|
model_name=new_name or model_name,
|
||||||
|
base_model=new_base or base_model,
|
||||||
|
model_type=model_type
|
||||||
|
)
|
||||||
|
return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
|
except KeyError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
@models_router.delete(
|
@models_router.delete(
|
||||||
"/{base_model}/{model_type}/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
|
@ -167,6 +167,18 @@ class ModelManagerServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def rename_model(self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
new_name: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Rename the indicated model.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_checkpoint_configs(
|
def list_checkpoint_configs(
|
||||||
self
|
self
|
||||||
@ -615,3 +627,26 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
conf_path = config.legacy_conf_path
|
conf_path = config.legacy_conf_path
|
||||||
root_path = config.root_path
|
root_path = config.root_path
|
||||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
|
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
|
||||||
|
|
||||||
|
def rename_model(self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
new_name: str = None,
|
||||||
|
new_base: BaseModelType = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Rename the indicated model. Can provide a new name and/or a new base.
|
||||||
|
:param model_name: Current name of the model
|
||||||
|
:param base_model: Current base of the model
|
||||||
|
:param model_type: Model type (can't be changed)
|
||||||
|
:param new_name: New name for the model
|
||||||
|
:param new_base: New base for the model
|
||||||
|
"""
|
||||||
|
self.mgr.rename_model(base_model = base_model,
|
||||||
|
model_type = model_type,
|
||||||
|
model_name = model_name,
|
||||||
|
new_name = new_name,
|
||||||
|
new_base = new_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
@ -71,8 +71,6 @@ class ModelInstallList:
|
|||||||
class InstallSelections():
|
class InstallSelections():
|
||||||
install_models: List[str]= field(default_factory=list)
|
install_models: List[str]= field(default_factory=list)
|
||||||
remove_models: List[str]=field(default_factory=list)
|
remove_models: List[str]=field(default_factory=list)
|
||||||
# scan_directory: Path = None
|
|
||||||
# autoscan_on_startup: bool=False
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelLoadInfo():
|
class ModelLoadInfo():
|
||||||
|
@ -671,6 +671,55 @@ class ModelManager(object):
|
|||||||
config = model_config,
|
config = model_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def rename_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
new_name: str = None,
|
||||||
|
new_base: BaseModelType = None,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Rename or rebase a model.
|
||||||
|
'''
|
||||||
|
if new_name is None and new_base is None:
|
||||||
|
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
|
||||||
|
return
|
||||||
|
|
||||||
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
|
model_cfg = self.models.get(model_key, None)
|
||||||
|
if not model_cfg:
|
||||||
|
raise KeyError(f"Unknown model: {model_key}")
|
||||||
|
|
||||||
|
old_path = self.app_config.root_path / model_cfg.path
|
||||||
|
new_name = new_name or model_name
|
||||||
|
new_base = new_base or base_model
|
||||||
|
new_key = self.create_key(new_name, new_base, model_type)
|
||||||
|
if new_key in self.models:
|
||||||
|
raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"')
|
||||||
|
|
||||||
|
# if this is a model file/directory that we manage ourselves, we need to move it
|
||||||
|
if old_path.is_relative_to(self.app_config.models_path):
|
||||||
|
new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name
|
||||||
|
move(old_path, new_path)
|
||||||
|
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
|
||||||
|
|
||||||
|
# clean up caches
|
||||||
|
old_model_cache = self._get_model_cache_path(old_path)
|
||||||
|
if old_model_cache.exists():
|
||||||
|
if old_model_cache.is_dir():
|
||||||
|
rmtree(str(old_model_cache))
|
||||||
|
else:
|
||||||
|
old_model_cache.unlink()
|
||||||
|
|
||||||
|
cache_ids = self.cache_keys.pop(model_key, [])
|
||||||
|
for cache_id in cache_ids:
|
||||||
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
|
self.models.pop(model_key, None) # delete
|
||||||
|
self.models[new_key] = model_cfg
|
||||||
|
self.commit()
|
||||||
|
|
||||||
def convert_model (
|
def convert_model (
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
|
@ -48,7 +48,9 @@ for base_model, models in MODEL_CLASSES.items():
|
|||||||
model_configs.discard(None)
|
model_configs.discard(None)
|
||||||
MODEL_CONFIGS.extend(model_configs)
|
MODEL_CONFIGS.extend(model_configs)
|
||||||
|
|
||||||
for cfg in model_configs:
|
# LS: sort to get the checkpoint configs first, which makes
|
||||||
|
# for a better template in the Swagger docs
|
||||||
|
for cfg in sorted(model_configs, key=lambda x: str(x)):
|
||||||
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||||
openapi_cfg_name = model_name + cfg_name
|
openapi_cfg_name = model_name + cfg_name
|
||||||
if openapi_cfg_name in vars():
|
if openapi_cfg_name in vars():
|
||||||
|
@ -59,7 +59,6 @@ class ModelConfigBase(BaseModel):
|
|||||||
path: str # or Path
|
path: str # or Path
|
||||||
description: Optional[str] = Field(None)
|
description: Optional[str] = Field(None)
|
||||||
model_format: Optional[str] = Field(None)
|
model_format: Optional[str] = Field(None)
|
||||||
# do not save to config
|
|
||||||
error: Optional[ModelError] = Field(None)
|
error: Optional[ModelError] = Field(None)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
@ -38,7 +38,6 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
config: str
|
config: str
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert base_model == BaseModelType.StableDiffusion1
|
assert base_model == BaseModelType.StableDiffusion1
|
||||||
assert model_type == ModelType.Main
|
assert model_type == ModelType.Main
|
||||||
|
Loading…
Reference in New Issue
Block a user