mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
use dependency injector pattern in the models route
This commit is contained in:
parent
f8b54930f0
commit
e877198de0
@ -7,13 +7,15 @@ from hashlib import sha1
|
|||||||
from random import randbytes
|
from random import randbytes
|
||||||
from typing import Any, Dict, List, Optional, Set
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response
|
from fastapi import Body, Depends, Path, Query, Response
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.model_install import ModelInstallJob
|
from invokeai.app.services.model_install import ModelInstallJob
|
||||||
|
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||||
from invokeai.app.services.model_records import (
|
from invokeai.app.services.model_records import (
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
InvalidModelException,
|
InvalidModelException,
|
||||||
@ -39,6 +41,22 @@ from ..dependencies import ApiDependencies
|
|||||||
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_services() -> InvocationServices:
|
||||||
|
"""DI magic to return services from the ApiDependencies global."""
|
||||||
|
return ApiDependencies.invoker.services
|
||||||
|
|
||||||
|
|
||||||
|
Services = Annotated[InvocationServices, Depends(get_services)]
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_manager(services: Services) -> ModelManagerServiceBase:
|
||||||
|
"""DI magic to return the model manager from the ApiDependencies global."""
|
||||||
|
return services.model_manager
|
||||||
|
|
||||||
|
|
||||||
|
ModelManager = Annotated[ModelManagerServiceBase, Depends(get_model_manager)]
|
||||||
|
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
"""Return list of configs."""
|
"""Return list of configs."""
|
||||||
|
|
||||||
@ -141,6 +159,7 @@ example_model_metadata = {
|
|||||||
operation_id="list_model_records",
|
operation_id="list_model_records",
|
||||||
)
|
)
|
||||||
async def list_model_records(
|
async def list_model_records(
|
||||||
|
model_manager: ModelManager,
|
||||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||||
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
|
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
|
||||||
@ -149,7 +168,7 @@ async def list_model_records(
|
|||||||
),
|
),
|
||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
"""Get a list of models."""
|
"""Get a list of models."""
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
record_store = model_manager.store
|
||||||
found_models: list[AnyModelConfig] = []
|
found_models: list[AnyModelConfig] = []
|
||||||
if base_models:
|
if base_models:
|
||||||
for base_model in base_models:
|
for base_model in base_models:
|
||||||
@ -171,15 +190,14 @@ async def list_model_records(
|
|||||||
response_model=AnyModelConfig,
|
response_model=AnyModelConfig,
|
||||||
)
|
)
|
||||||
async def get_model_records_by_attrs(
|
async def get_model_records_by_attrs(
|
||||||
|
model_manager: ModelManager,
|
||||||
name: str = Query(description="The name of the model"),
|
name: str = Query(description="The name of the model"),
|
||||||
type: ModelType = Query(description="The type of the model"),
|
type: ModelType = Query(description="The type of the model"),
|
||||||
base: BaseModelType = Query(description="The base model of the model"),
|
base: BaseModelType = Query(description="The base model of the model"),
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""Gets a model by its attributes. The main use of this route is to provide backwards compatibility with the old
|
"""Gets a model by its attributes. The main use of this route is to provide backwards compatibility with the old
|
||||||
model manager, which identified models by a combination of name, base and type."""
|
model manager, which identified models by a combination of name, base and type."""
|
||||||
configs = ApiDependencies.invoker.services.model_manager.store.search_by_attr(
|
configs = model_manager.store.search_by_attr(base_model=base, model_type=type, model_name=name)
|
||||||
base_model=base, model_type=type, model_name=name
|
|
||||||
)
|
|
||||||
if not configs:
|
if not configs:
|
||||||
raise HTTPException(status_code=404, detail="No model found with these attributes")
|
raise HTTPException(status_code=404, detail="No model found with these attributes")
|
||||||
|
|
||||||
@ -199,10 +217,11 @@ async def get_model_records_by_attrs(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_model_record(
|
async def get_model_record(
|
||||||
|
model_manager: ModelManager,
|
||||||
key: str = Path(description="Key of the model record to fetch."),
|
key: str = Path(description="Key of the model record to fetch."),
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""Get a model record"""
|
"""Get a model record"""
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
record_store = model_manager.store
|
||||||
try:
|
try:
|
||||||
config: AnyModelConfig = record_store.get_model(key)
|
config: AnyModelConfig = record_store.get_model(key)
|
||||||
return config
|
return config
|
||||||
@ -212,13 +231,15 @@ async def get_model_record(
|
|||||||
|
|
||||||
@model_manager_router.get("/summary", operation_id="list_model_summary")
|
@model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||||
async def list_model_summary(
|
async def list_model_summary(
|
||||||
|
model_manager: ModelManager,
|
||||||
page: int = Query(default=0, description="The page to get"),
|
page: int = Query(default=0, description="The page to get"),
|
||||||
per_page: int = Query(default=10, description="The number of models per page"),
|
per_page: int = Query(default=10, description="The number of models per page"),
|
||||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||||
) -> PaginatedResults[ModelSummary]:
|
) -> PaginatedResults[ModelSummary]:
|
||||||
"""Gets a page of model summary data."""
|
"""Gets a page of model summary data."""
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
results: PaginatedResults[ModelSummary] = model_manager.store.list_models(
|
||||||
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
page=page, per_page=per_page, order_by=order_by
|
||||||
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -234,11 +255,11 @@ async def list_model_summary(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_model_metadata(
|
async def get_model_metadata(
|
||||||
|
model_manager: ModelManager,
|
||||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||||
) -> Optional[AnyModelRepoMetadata]:
|
) -> Optional[AnyModelRepoMetadata]:
|
||||||
"""Get a model metadata object."""
|
"""Get a model metadata object."""
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
result: Optional[AnyModelRepoMetadata] = model_manager.store.get_metadata(key)
|
||||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -247,10 +268,9 @@ async def get_model_metadata(
|
|||||||
"/tags",
|
"/tags",
|
||||||
operation_id="list_tags",
|
operation_id="list_tags",
|
||||||
)
|
)
|
||||||
async def list_tags() -> Set[str]:
|
async def list_tags(model_manager: ModelManager) -> Set[str]:
|
||||||
"""Get a unique set of all the model tags."""
|
"""Get a unique set of all the model tags."""
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
result: Set[str] = model_manager.store.list_tags()
|
||||||
result: Set[str] = record_store.list_tags()
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@ -270,6 +290,8 @@ class FoundModel(BaseModel):
|
|||||||
response_model=List[FoundModel],
|
response_model=List[FoundModel],
|
||||||
)
|
)
|
||||||
async def scan_for_models(
|
async def scan_for_models(
|
||||||
|
model_manager: ModelManager,
|
||||||
|
services: Services,
|
||||||
scan_path: str = Query(description="Directory path to search for models", default=None),
|
scan_path: str = Query(description="Directory path to search for models", default=None),
|
||||||
) -> List[FoundModel]:
|
) -> List[FoundModel]:
|
||||||
path = pathlib.Path(scan_path)
|
path = pathlib.Path(scan_path)
|
||||||
@ -282,7 +304,7 @@ async def scan_for_models(
|
|||||||
search = ModelSearch()
|
search = ModelSearch()
|
||||||
try:
|
try:
|
||||||
found_model_paths = search.search(path)
|
found_model_paths = search.search(path)
|
||||||
models_path = ApiDependencies.invoker.services.configuration.models_path
|
models_path = services.configuration.models_path
|
||||||
|
|
||||||
# If the search path includes the main models directory, we need to exclude core models from the list.
|
# If the search path includes the main models directory, we need to exclude core models from the list.
|
||||||
# TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed
|
# TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed
|
||||||
@ -290,7 +312,7 @@ async def scan_for_models(
|
|||||||
core_models_path = pathlib.Path(models_path, "core").resolve()
|
core_models_path = pathlib.Path(models_path, "core").resolve()
|
||||||
non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
|
non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
|
||||||
|
|
||||||
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
|
installed_models = model_manager.store.search_by_attr()
|
||||||
resolved_installed_model_paths: list[str] = []
|
resolved_installed_model_paths: list[str] = []
|
||||||
installed_model_sources: list[str] = []
|
installed_model_sources: list[str] = []
|
||||||
|
|
||||||
@ -328,10 +350,11 @@ async def scan_for_models(
|
|||||||
operation_id="search_by_metadata_tags",
|
operation_id="search_by_metadata_tags",
|
||||||
)
|
)
|
||||||
async def search_by_metadata_tags(
|
async def search_by_metadata_tags(
|
||||||
|
model_manager: ModelManager,
|
||||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
"""Get a list of models."""
|
"""Get a list of models."""
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
record_store = model_manager.store
|
||||||
results = record_store.search_by_metadata_tag(tags)
|
results = record_store.search_by_metadata_tag(tags)
|
||||||
return ModelsList(models=results)
|
return ModelsList(models=results)
|
||||||
|
|
||||||
@ -351,14 +374,15 @@ async def search_by_metadata_tags(
|
|||||||
status_code=200,
|
status_code=200,
|
||||||
)
|
)
|
||||||
async def update_model_record(
|
async def update_model_record(
|
||||||
|
services: Services,
|
||||||
key: Annotated[str, Path(description="Unique key of model")],
|
key: Annotated[str, Path(description="Unique key of model")],
|
||||||
info: Annotated[
|
info: Annotated[
|
||||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||||
],
|
],
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = services.logger
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
record_store = services.model_manager.store
|
||||||
try:
|
try:
|
||||||
model_response: AnyModelConfig = record_store.update_model(key, config=info)
|
model_response: AnyModelConfig = record_store.update_model(key, config=info)
|
||||||
logger.info(f"Updated model: {key}")
|
logger.info(f"Updated model: {key}")
|
||||||
@ -380,6 +404,7 @@ async def update_model_record(
|
|||||||
status_code=204,
|
status_code=204,
|
||||||
)
|
)
|
||||||
async def del_model_record(
|
async def del_model_record(
|
||||||
|
services: Services,
|
||||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
@ -388,10 +413,10 @@ async def del_model_record(
|
|||||||
The configuration record will be removed. The corresponding weights files will be
|
The configuration record will be removed. The corresponding weights files will be
|
||||||
deleted as well if they reside within the InvokeAI "models" directory.
|
deleted as well if they reside within the InvokeAI "models" directory.
|
||||||
"""
|
"""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = services.logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
installer = ApiDependencies.invoker.services.model_manager.install
|
installer = services.model_manager.install
|
||||||
installer.delete(key)
|
installer.delete(key)
|
||||||
logger.info(f"Deleted model: {key}")
|
logger.info(f"Deleted model: {key}")
|
||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
@ -414,13 +439,14 @@ async def del_model_record(
|
|||||||
status_code=201,
|
status_code=201,
|
||||||
)
|
)
|
||||||
async def add_model_record(
|
async def add_model_record(
|
||||||
|
services: Services,
|
||||||
config: Annotated[
|
config: Annotated[
|
||||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||||
],
|
],
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""Add a model using the configuration information appropriate for its type."""
|
"""Add a model using the configuration information appropriate for its type."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = services.logger
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
record_store = services.model_manager.store
|
||||||
if config.key == "<NOKEY>":
|
if config.key == "<NOKEY>":
|
||||||
config.key = sha1(randbytes(100)).hexdigest()
|
config.key = sha1(randbytes(100)).hexdigest()
|
||||||
logger.info(f"Created model {config.key} for {config.name}")
|
logger.info(f"Created model {config.key} for {config.name}")
|
||||||
@ -450,6 +476,7 @@ async def add_model_record(
|
|||||||
status_code=201,
|
status_code=201,
|
||||||
)
|
)
|
||||||
async def install_model(
|
async def install_model(
|
||||||
|
services: Services,
|
||||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||||
# TODO(MM2): Can we type this?
|
# TODO(MM2): Can we type this?
|
||||||
config: Optional[Dict[str, Any]] = Body(
|
config: Optional[Dict[str, Any]] = Body(
|
||||||
@ -485,10 +512,10 @@ async def install_model(
|
|||||||
See the documentation for `import_model_record` for more information on
|
See the documentation for `import_model_record` for more information on
|
||||||
interpreting the job information returned by this route.
|
interpreting the job information returned by this route.
|
||||||
"""
|
"""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = services.logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
installer = ApiDependencies.invoker.services.model_manager.install
|
installer = services.model_manager.install
|
||||||
result: ModelInstallJob = installer.heuristic_import(
|
result: ModelInstallJob = installer.heuristic_import(
|
||||||
source=source,
|
source=source,
|
||||||
config=config,
|
config=config,
|
||||||
@ -511,7 +538,7 @@ async def install_model(
|
|||||||
"/import",
|
"/import",
|
||||||
operation_id="list_model_install_jobs",
|
operation_id="list_model_install_jobs",
|
||||||
)
|
)
|
||||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
async def list_model_install_jobs(services: Services) -> List[ModelInstallJob]:
|
||||||
"""Return the list of model install jobs.
|
"""Return the list of model install jobs.
|
||||||
|
|
||||||
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
|
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
|
||||||
@ -531,7 +558,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
|
|||||||
|
|
||||||
See the example and schema below for more information.
|
See the example and schema below for more information.
|
||||||
"""
|
"""
|
||||||
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
|
jobs: List[ModelInstallJob] = services.model_manager.install.list_jobs()
|
||||||
return jobs
|
return jobs
|
||||||
|
|
||||||
|
|
||||||
@ -543,13 +570,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
|
|||||||
404: {"description": "No such job"},
|
404: {"description": "No such job"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
|
async def get_model_install_job(services: Services, id: int = Path(description="Model install id")) -> ModelInstallJob:
|
||||||
"""
|
"""
|
||||||
Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
|
Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
|
||||||
for information on the format of the return value.
|
for information on the format of the return value.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
|
result: ModelInstallJob = services.model_manager.install.get_job_by_id(id)
|
||||||
return result
|
return result
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
@ -564,9 +591,9 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
|
|||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
)
|
)
|
||||||
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
|
async def cancel_model_install_job(services: Services, id: int = Path(description="Model install job ID")) -> None:
|
||||||
"""Cancel the model install job(s) corresponding to the given job ID."""
|
"""Cancel the model install job(s) corresponding to the given job ID."""
|
||||||
installer = ApiDependencies.invoker.services.model_manager.install
|
installer = services.model_manager.install
|
||||||
try:
|
try:
|
||||||
job = installer.get_job_by_id(id)
|
job = installer.get_job_by_id(id)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@ -582,9 +609,9 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
|||||||
400: {"description": "Bad request"},
|
400: {"description": "Bad request"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def prune_model_install_jobs() -> Response:
|
async def prune_model_install_jobs(model_manager: ModelManager) -> Response:
|
||||||
"""Prune all completed and errored jobs from the install job list."""
|
"""Prune all completed and errored jobs from the install job list."""
|
||||||
ApiDependencies.invoker.services.model_manager.install.prune_jobs()
|
model_manager.install.prune_jobs()
|
||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
@ -596,14 +623,14 @@ async def prune_model_install_jobs() -> Response:
|
|||||||
400: {"description": "Bad request"},
|
400: {"description": "Bad request"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def sync_models_to_config() -> Response:
|
async def sync_models_to_config(model_manager: ModelManager) -> Response:
|
||||||
"""
|
"""
|
||||||
Traverse the models and autoimport directories.
|
Traverse the models and autoimport directories.
|
||||||
|
|
||||||
Model files without a corresponding
|
Model files without a corresponding
|
||||||
record in the database are added. Orphan records without a models file are deleted.
|
record in the database are added. Orphan records without a models file are deleted.
|
||||||
"""
|
"""
|
||||||
ApiDependencies.invoker.services.model_manager.install.sync_to_config()
|
model_manager.install.sync_to_config()
|
||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
@ -621,6 +648,7 @@ async def sync_models_to_config() -> Response:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def convert_model(
|
async def convert_model(
|
||||||
|
services: Services,
|
||||||
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
|
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
@ -628,11 +656,11 @@ async def convert_model(
|
|||||||
Note that during the conversion process the key and model hash will change.
|
Note that during the conversion process the key and model hash will change.
|
||||||
The return value is the model configuration for the converted model.
|
The return value is the model configuration for the converted model.
|
||||||
"""
|
"""
|
||||||
model_manager = ApiDependencies.invoker.services.model_manager
|
model_manager = services.model_manager
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = services.logger
|
||||||
loader = ApiDependencies.invoker.services.model_manager.load
|
loader = services.model_manager.load
|
||||||
store = ApiDependencies.invoker.services.model_manager.store
|
store = services.model_manager.store
|
||||||
installer = ApiDependencies.invoker.services.model_manager.install
|
installer = services.model_manager.install
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_config = store.get_model(key)
|
model_config = store.get_model(key)
|
||||||
@ -700,6 +728,7 @@ async def convert_model(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def merge(
|
async def merge(
|
||||||
|
services: Services,
|
||||||
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||||
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||||
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||||
@ -726,11 +755,11 @@ async def merge(
|
|||||||
merge_dest_directory Specify a directory to store the merged model in [models directory]
|
merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = services.logger
|
||||||
try:
|
try:
|
||||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||||
installer = ApiDependencies.invoker.services.model_manager.install
|
installer = services.model_manager.install
|
||||||
merger = ModelMerger(installer)
|
merger = ModelMerger(installer)
|
||||||
model_names = [installer.record_store.get_model(x).name for x in keys]
|
model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||||
response = merger.merge_diffusion_models_and_save(
|
response = merger.merge_diffusion_models_and_save(
|
||||||
|
Loading…
Reference in New Issue
Block a user