final tidying before marking PR as ready for review

- Replace AnyModelLoader with ModelLoaderRegistry
- Fix type check errors in multiple files
- Remove apparently unneeded `get_model_config_enum()` method from model manager
- Remove last vestiges of old model manager
- Updated tests and documentation

resolve conflict with seamless.py
This commit is contained in:
psychedelicious 2024-02-18 17:27:42 +11:00
parent 2ad0752582
commit be8b99eed5
74 changed files with 672 additions and 10362 deletions

View File

@ -1531,23 +1531,29 @@ Here is a typical initialization pattern:
``` ```
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegistry
from invokeai.app.services.model_load import ModelLoadService
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
store = ModelRecordServiceBase.open(config) ram_cache = ModelCache(
loader = ModelLoadService(config, store) max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
)
loader = ModelLoadService(
app_config=config,
ram_cache=ram_cache,
convert_cache=convert_cache,
registry=ModelLoaderRegistry
)
``` ```
Note that we are relying on the contents of the application ### load_model(model_config, [submodel_type], [context]) -> LoadedModel
configuration to choose the implementation of
`ModelRecordServiceBase`.
### load_model_by_key(key, [submodel_type], [context]) -> LoadedModel The `load_model()` method takes an `AnyModelConfig` returned by
`ModelRecordService.get_model()` and returns the corresponding loaded
The `load_model_by_key()` method receives the unique key that model. It loads the model into memory, gets the model ready for use,
identifies the model. It loads the model into memory, gets the model and returns a `LoadedModel` object.
ready for use, and returns a `LoadedModel` object.
The optional second argument, `subtype` is a `SubModelType` string The optional second argument, `subtype` is a `SubModelType` string
enum, such as "vae". It is mandatory when used with a main model, and enum, such as "vae". It is mandatory when used with a main model, and
@ -1593,25 +1599,6 @@ with model_info as vae:
- `ModelNotFoundException` -- key in database but model not found at path - `ModelNotFoundException` -- key in database but model not found at path
- `NotImplementedException` -- the loader doesn't know how to load this type of model - `NotImplementedException` -- the loader doesn't know how to load this type of model
### load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel
This is similar to `load_model_by_key`, but instead it accepts the
combination of the model's name, type and base, which it passes to the
model record config store for retrieval. If successful, this method
returns a `LoadedModel`. It can raise the following exceptions:
```
UnknownModelException -- model with these attributes not known
NotImplementedException -- the loader doesn't know how to load this type of model
ValueError -- more than one model matches this combination of base/type/name
```
### load_model_by_config(config, [submodel], [context]) -> LoadedModel
This method takes an `AnyModelConfig` returned by
ModelRecordService.get_model() and returns the corresponding loaded
model. It may raise a `NotImplementedException`.
### Emitting model loading events ### Emitting model loading events
When the `context` argument is passed to `load_model_*()`, it will When the `context` argument is passed to `load_model_*()`, it will
@ -1656,7 +1643,7 @@ onnx models.
To install a new loader, place it in To install a new loader, place it in
`invokeai/backend/model_manager/load/model_loaders`. Inherit from `invokeai/backend/model_manager/load/model_loaders`. Inherit from
`ModelLoader` and use the `@AnyModelLoader.register()` decorator to `ModelLoader` and use the `@ModelLoaderRegistry.register()` decorator to
indicate what type of models the loader can handle. indicate what type of models the loader can handle.
Here is a complete example from `generic_diffusers.py`, which is able Here is a complete example from `generic_diffusers.py`, which is able
@ -1674,12 +1661,11 @@ from invokeai.backend.model_manager import (
ModelType, ModelType,
SubModelType, SubModelType,
) )
from ..load_base import AnyModelLoader from .. import ModelLoader, ModelLoaderRegistry
from ..load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
class GenericDiffusersLoader(ModelLoader): class GenericDiffusersLoader(ModelLoader):
"""Class to load simple diffusers models.""" """Class to load simple diffusers models."""
@ -1728,3 +1714,74 @@ model. It does whatever it needs to do to get the model into diffusers
format, and returns the Path of the resulting model. (The path should format, and returns the Path of the resulting model. (The path should
ordinarily be the same as `output_path`.) ordinarily be the same as `output_path`.)
## The ModelManagerService object
For convenience, the API provides a `ModelManagerService` object which
gives a single point of access to the major model manager
services. This object is created at initialization time and can be
found in the global `ApiDependencies.invoker.services.model_manager`
object, or in `context.services.model_manager` from within an
invocation.
In the examples below, we have retrieved the manager using:
```
mm = ApiDependencies.invoker.services.model_manager
```
The following properties and methods will be available:
### mm.store
This retrieves the `ModelRecordService` associated with the
manager. Example:
```
configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5')
```
### mm.install
This retrieves the `ModelInstallService` associated with the manager.
Example:
```
job = mm.install.heuristic_import(`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
```
### mm.load
This retrieves the `ModelLoaderService` associated with the manager. Example:
```
configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5')
assert len(configs) > 0
loaded_model = mm.load.load_model(configs[0])
```
The model manager also offers a few convenience shortcuts for loading
models:
### mm.load_model_by_config(model_config, [submodel], [context]) -> LoadedModel
Same as `mm.load.load_model()`.
### mm.load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel
This accepts the combination of the model's name, type and base, which
it passes to the model record config store for retrieval. If a unique
model config is found, this method returns a `LoadedModel`. It can
raise the following exceptions:
```
UnknownModelException -- model with these attributes not known
NotImplementedException -- the loader doesn't know how to load this type of model
ValueError -- more than one model matches this combination of base/type/name
```
### mm.load_model_by_key(key, [submodel], [context]) -> LoadedModel
This method takes a model key, looks it up using the
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
model configuration to `load_model_by_config()`. It may raise a
`NotImplementedException`.

View File

@ -35,7 +35,7 @@ from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"]) model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
class ModelsList(BaseModel): class ModelsList(BaseModel):
@ -135,7 +135,7 @@ example_model_metadata = {
############################################################################## ##############################################################################
@model_manager_v2_router.get( @model_manager_router.get(
"/", "/",
operation_id="list_model_records", operation_id="list_model_records",
) )
@ -164,7 +164,7 @@ async def list_model_records(
return ModelsList(models=found_models) return ModelsList(models=found_models)
@model_manager_v2_router.get( @model_manager_router.get(
"/i/{key}", "/i/{key}",
operation_id="get_model_record", operation_id="get_model_record",
responses={ responses={
@ -188,7 +188,7 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@model_manager_v2_router.get("/summary", operation_id="list_model_summary") @model_manager_router.get("/summary", operation_id="list_model_summary")
async def list_model_summary( async def list_model_summary(
page: int = Query(default=0, description="The page to get"), page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"), per_page: int = Query(default=10, description="The number of models per page"),
@ -200,7 +200,7 @@ async def list_model_summary(
return results return results
@model_manager_v2_router.get( @model_manager_router.get(
"/meta/i/{key}", "/meta/i/{key}",
operation_id="get_model_metadata", operation_id="get_model_metadata",
responses={ responses={
@ -223,7 +223,7 @@ async def get_model_metadata(
return result return result
@model_manager_v2_router.get( @model_manager_router.get(
"/tags", "/tags",
operation_id="list_tags", operation_id="list_tags",
) )
@ -234,7 +234,7 @@ async def list_tags() -> Set[str]:
return result return result
@model_manager_v2_router.get( @model_manager_router.get(
"/tags/search", "/tags/search",
operation_id="search_by_metadata_tags", operation_id="search_by_metadata_tags",
) )
@ -247,7 +247,7 @@ async def search_by_metadata_tags(
return ModelsList(models=results) return ModelsList(models=results)
@model_manager_v2_router.patch( @model_manager_router.patch(
"/i/{key}", "/i/{key}",
operation_id="update_model_record", operation_id="update_model_record",
responses={ responses={
@ -281,7 +281,7 @@ async def update_model_record(
return model_response return model_response
@model_manager_v2_router.delete( @model_manager_router.delete(
"/i/{key}", "/i/{key}",
operation_id="del_model_record", operation_id="del_model_record",
responses={ responses={
@ -311,7 +311,7 @@ async def del_model_record(
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@model_manager_v2_router.post( @model_manager_router.post(
"/i/", "/i/",
operation_id="add_model_record", operation_id="add_model_record",
responses={ responses={
@ -349,7 +349,7 @@ async def add_model_record(
return result return result
@model_manager_v2_router.post( @model_manager_router.post(
"/heuristic_import", "/heuristic_import",
operation_id="heuristic_import_model", operation_id="heuristic_import_model",
responses={ responses={
@ -416,7 +416,7 @@ async def heuristic_import(
return result return result
@model_manager_v2_router.post( @model_manager_router.post(
"/install", "/install",
operation_id="import_model", operation_id="import_model",
responses={ responses={
@ -516,7 +516,7 @@ async def import_model(
return result return result
@model_manager_v2_router.get( @model_manager_router.get(
"/import", "/import",
operation_id="list_model_install_jobs", operation_id="list_model_install_jobs",
) )
@ -544,7 +544,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
return jobs return jobs
@model_manager_v2_router.get( @model_manager_router.get(
"/import/{id}", "/import/{id}",
operation_id="get_model_install_job", operation_id="get_model_install_job",
responses={ responses={
@ -564,7 +564,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@model_manager_v2_router.delete( @model_manager_router.delete(
"/import/{id}", "/import/{id}",
operation_id="cancel_model_install_job", operation_id="cancel_model_install_job",
responses={ responses={
@ -583,7 +583,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
installer.cancel_job(job) installer.cancel_job(job)
@model_manager_v2_router.patch( @model_manager_router.patch(
"/import", "/import",
operation_id="prune_model_install_jobs", operation_id="prune_model_install_jobs",
responses={ responses={
@ -597,7 +597,7 @@ async def prune_model_install_jobs() -> Response:
return Response(status_code=204) return Response(status_code=204)
@model_manager_v2_router.patch( @model_manager_router.patch(
"/sync", "/sync",
operation_id="sync_models_to_config", operation_id="sync_models_to_config",
responses={ responses={
@ -616,7 +616,7 @@ async def sync_models_to_config() -> Response:
return Response(status_code=204) return Response(status_code=204)
@model_manager_v2_router.put( @model_manager_router.put(
"/convert/{key}", "/convert/{key}",
operation_id="convert_model", operation_id="convert_model",
responses={ responses={
@ -694,7 +694,7 @@ async def convert_model(
return new_config return new_config
@model_manager_v2_router.put( @model_manager_router.put(
"/merge", "/merge",
operation_id="merge", operation_id="merge",
responses={ responses={

View File

@ -1,426 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
import pathlib
from typing import Annotated, List, Literal, Optional, Union
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from starlette.exceptions import HTTPException
from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType
from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS,
InvalidModelException,
ModelNotFoundException,
SchedulerPredictionType,
)
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
UpdateModelResponseValidator = TypeAdapter(UpdateModelResponse)
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponseValidator = TypeAdapter(ImportModelResponse)
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponseValidator = TypeAdapter(ConvertModelResponse)
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
model_config = ConfigDict(use_enum_values=True)
ModelsListValidator = TypeAdapter(ModelsList)
@models_router.get(
"/",
operation_id="list_models",
responses={200: {"model": ModelsList}},
)
async def list_models(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Gets a list of models"""
if base_models and len(base_models) > 0:
models_raw = []
for base_model in base_models:
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
else:
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
models = ModelsListValidator.validate_python({"models": models_raw})
return models
@models_router.patch(
"/{base_model}/{model_type}/{model_name}",
operation_id="update_model",
responses={
200: {"description": "The model was updated successfully"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to the new name"},
},
status_code=200,
response_model=UpdateModelResponse,
)
async def update_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
try:
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
# rename operation requested
if info.model_name != model_name or info.base_model != base_model:
ApiDependencies.invoker.services.model_manager.rename_model(
base_model=base_model,
model_type=model_type,
model_name=model_name,
new_name=info.model_name,
new_base=info.base_model,
)
logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
# update information to support an update of attributes
model_name = info.model_name
base_model = info.base_model
new_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
if new_info.get("path") != previous_info.get(
"path"
): # model manager moved model path during rename - don't overwrite it
info.path = new_info.get("path")
# replace empty string values with None/null to avoid phenomenon of vae: ''
info_dict = info.model_dump()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_attributes=info_dict,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
model_response = UpdateModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
return model_response
@models_router.post(
"/import",
operation_id="import_model",
responses={
201: {"description": "The model imported successfully"},
404: {"description": "The model could not be found"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse,
)
async def import_model(
location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
description="Prediction type for SDv2 checkpoints and rare SDv1 checkpoints",
default=None,
),
) -> ImportModelResponse:
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
location = location.strip("\"' ")
items_to_import = {location}
prediction_types = {x.value: x for x in SchedulerPredictionType}
logger = ApiDependencies.invoker.services.logger
try:
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import=items_to_import,
prediction_type_helper=lambda x: prediction_types.get(prediction_type),
)
info = installed_models.get(location)
if not info:
logger.error("Import failed")
raise HTTPException(status_code=415)
logger.info(f"Successfully imported {location}, got {info}")
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.name, base_model=info.base_model, model_type=info.model_type
)
return ImportModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/add",
operation_id="add_model",
responses={
201: {"description": "The model added successfully"},
404: {"description": "The model could not be found"},
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse,
)
async def add_model(
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> ImportModelResponse:
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
logger = ApiDependencies.invoker.services.logger
try:
ApiDependencies.invoker.services.model_manager.add_model(
info.model_name,
info.base_model,
info.model_type,
model_attributes=info.model_dump(),
)
logger.info(f"Successfully added {info.model_name}")
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.model_name,
base_model=info.base_model,
model_type=info.model_type,
)
return ImportModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.delete(
"/{base_model}/{model_type}/{model_name}",
operation_id="del_model",
responses={
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"},
},
status_code=204,
response_model=None,
)
async def delete_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
) -> Response:
"""Delete Model"""
logger = ApiDependencies.invoker.services.logger
try:
ApiDependencies.invoker.services.model_manager.del_model(
model_name, base_model=base_model, model_type=model_type
)
logger.info(f"Deleted model: {model_name}")
return Response(status_code=204)
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@models_router.put(
"/convert/{base_model}/{model_type}/{model_name}",
operation_id="convert_model",
responses={
200: {"description": "Model converted successfully"},
400: {"description": "Bad request"},
404: {"description": "Model not found"},
},
status_code=200,
response_model=ConvertModelResponse,
)
async def convert_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
convert_dest_directory: Optional[str] = Query(
default=None, description="Save the converted model to the designated directory"
),
) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Converting model: {model_name}")
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
ApiDependencies.invoker.services.model_manager.convert_model(
model_name,
base_model=base_model,
model_type=model_type,
convert_dest_directory=dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name, base_model=base_model, model_type=model_type
)
response = ConvertModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
@models_router.get(
"/search",
operation_id="search_for_models",
responses={
200: {"description": "Directory searched successfully"},
404: {"description": "Invalid directory path"},
},
status_code=200,
response_model=List[pathlib.Path],
)
async def search_for_models(
search_path: pathlib.Path = Query(description="Directory path to search for models"),
) -> List[pathlib.Path]:
if not search_path.is_dir():
raise HTTPException(
status_code=404,
detail=f"The search path '{search_path}' does not exist or is not directory",
)
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
@models_router.get(
"/ckpt_confs",
operation_id="list_ckpt_configs",
responses={
200: {"description": "paths retrieved successfully"},
},
status_code=200,
response_model=List[pathlib.Path],
)
async def list_ckpt_configs() -> List[pathlib.Path]:
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
@models_router.post(
"/sync",
operation_id="sync_to_config",
responses={
201: {"description": "synchronization successful"},
},
status_code=201,
response_model=bool,
)
async def sync_to_config() -> bool:
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
in-memory data structures with disk data structures."""
ApiDependencies.invoker.services.model_manager.sync_to_config()
return True
# There's some weird pydantic-fastapi behaviour that requires this to be a separate class
# TODO: After a few updates, see if it works inside the route operation handler?
class MergeModelsBody(BaseModel):
model_names: List[str] = Field(description="model name", min_length=2, max_length=3)
merged_model_name: Optional[str] = Field(description="Name of destination model")
alpha: Optional[float] = Field(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5)
interp: Optional[MergeInterpolationMethod] = Field(description="Interpolation method")
force: Optional[bool] = Field(
description="Force merging of models created with different versions of diffusers",
default=False,
)
merge_dest_directory: Optional[str] = Field(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None,
)
model_config = ConfigDict(protected_namespaces=())
@models_router.put(
"/merge/{base_model}",
operation_id="merge_models",
responses={
200: {"description": "Model converted successfully"},
400: {"description": "Incompatible models"},
404: {"description": "One or more models not found"},
},
status_code=200,
response_model=MergeModelResponse,
)
async def merge_models(
body: Annotated[MergeModelsBody, Body(description="Model configuration", embed=True)],
base_model: BaseModelType = Path(description="Base model"),
) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(
f"Merging models: {body.model_names} into {body.merge_dest_directory or '<MODELS>'}/{body.merged_model_name}"
)
dest = pathlib.Path(body.merge_dest_directory) if body.merge_dest_directory else None
result = ApiDependencies.invoker.services.model_manager.merge_models(
model_names=body.model_names,
base_model=base_model,
merged_model_name=body.merged_model_name or "+".join(body.model_names),
alpha=body.alpha,
interp=body.interp,
force=body.force,
merge_dest_directory=dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
result.name,
base_model=base_model,
model_type=ModelType.Main,
)
response = ConvertModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException:
raise HTTPException(
status_code=404,
detail=f"One or more of the models '{body.model_names}' not found",
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response

View File

@ -48,7 +48,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
boards, boards,
download_queue, download_queue,
images, images,
model_manager_v2, model_manager,
session_queue, session_queue,
sessions, sessions,
utilities, utilities,
@ -113,7 +113,7 @@ async def shutdown_event() -> None:
app.include_router(sessions.session_router, prefix="/api") app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api") app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(model_manager_v2.model_manager_v2_router, prefix="/api") app.include_router(model_manager.model_manager_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api") app.include_router(download_queue.download_queue_router, prefix="/api")
app.include_router(images.images_router, prefix="/api") app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api") app.include_router(boards.boards_router, prefix="/api")
@ -175,21 +175,23 @@ def custom_openapi() -> dict[str, Any]:
invoker_schema["class"] = "invocation" invoker_schema["class"] = "invocation"
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output" openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
from invokeai.backend.model_management.models import get_model_config_enums # This code no longer seems to be necessary?
# Leave it here just in case
#
# from invokeai.backend.model_manager import get_model_config_formats
# formats = get_model_config_formats()
# for model_config_name, enum_set in formats.items():
for model_config_format_enum in set(get_model_config_enums()): # if model_config_name in openapi_schema["components"]["schemas"]:
name = model_config_format_enum.__qualname__ # # print(f"Config with name {name} already defined")
# continue
if name in openapi_schema["components"]["schemas"]: # openapi_schema["components"]["schemas"][model_config_name] = {
# print(f"Config with name {name} already defined") # "title": model_config_name,
continue # "description": "An enumeration.",
# "type": "string",
openapi_schema["components"]["schemas"][name] = { # "enum": [v.value for v in enum_set],
"title": name, # }
"description": "An enumeration.",
"type": "string",
"enum": [v.value for v in model_config_format_enum],
}
app.openapi_schema = openapi_schema app.openapi_schema = openapi_schema
return app.openapi_schema return app.openapi_schema

View File

@ -18,15 +18,15 @@ from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.model_manager import ModelType from invokeai.backend.model_manager import ModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo, BasicConditioningInfo,
ConditioningFieldData, ConditioningFieldData,
ExtraConditioningInfo, ExtraConditioningInfo,
SDXLConditioningInfo, SDXLConditioningInfo,
) )
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.util.devices import torch_dtype from invokeai.backend.util.devices import torch_dtype
from .baseinvocation import ( from .baseinvocation import (

View File

@ -68,7 +68,7 @@ class InvokeAISettings(BaseSettings):
return OmegaConf.to_yaml(conf) return OmegaConf.to_yaml(conf)
@classmethod @classmethod
def add_parser_arguments(cls, parser) -> None: def add_parser_arguments(cls, parser: ArgumentParser) -> None:
"""Dynamically create arguments for a settings parser.""" """Dynamically create arguments for a settings parser."""
if "type" in get_type_hints(cls): if "type" in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)["type"])[0] settings_stanza = get_args(get_type_hints(cls)["type"])[0]

View File

@ -29,8 +29,8 @@ writes to the system log is stored in InvocationServices.performance_statistics.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from pathlib import Path from pathlib import Path
from typing import Iterator
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary
@ -40,18 +40,17 @@ class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics" "Abstract base class for recording node memory/time performance statistics"
@abstractmethod @abstractmethod
def __init__(self): def __init__(self) -> None:
""" """
Initialize the InvocationStatsService and reset counters to zero Initialize the InvocationStatsService and reset counters to zero
""" """
pass
@abstractmethod @abstractmethod
def collect_stats( def collect_stats(
self, self,
invocation: BaseInvocation, invocation: BaseInvocation,
graph_execution_state_id: str, graph_execution_state_id: str,
) -> AbstractContextManager: ) -> Iterator[None]:
""" """
Return a context object that will capture the statistics on the execution Return a context object that will capture the statistics on the execution
of invocaation. Use with: to place around the part of the code that executes the invocation. of invocaation. Use with: to place around the part of the code that executes the invocation.
@ -61,7 +60,7 @@ class InvocationStatsServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def reset_stats(self, graph_execution_state_id: str): def reset_stats(self, graph_execution_state_id: str) -> None:
""" """
Reset all statistics for the indicated graph. Reset all statistics for the indicated graph.
:param graph_execution_state_id: The id of the session whose stats to reset. :param graph_execution_state_id: The id of the session whose stats to reset.
@ -70,7 +69,7 @@ class InvocationStatsServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def log_stats(self, graph_execution_state_id: str): def log_stats(self, graph_execution_state_id: str) -> None:
""" """
Write out the accumulated statistics to the log or somewhere else. Write out the accumulated statistics to the log or somewhere else.
:param graph_execution_state_id: The id of the session whose stats to log. :param graph_execution_state_id: The id of the session whose stats to log.

View File

@ -14,7 +14,7 @@ from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
@ -15,23 +15,7 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader.""" """Wrapper around AnyModelLoader."""
@abstractmethod @abstractmethod
def load_model_by_key( def load_model(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's key, load it and return the LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch.
:param context_data: Invocation context data used for event reporting
"""
pass
@abstractmethod
def load_model_by_config(
self, self,
model_config: AnyModelConfig, model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
@ -44,34 +28,6 @@ class ModelLoadServiceBase(ABC):
:param submodel: For main (pipeline models), the submodel to fetch. :param submodel: For main (pipeline models), the submodel to fetch.
:param context_data: Invocation context data used for event reporting :param context_data: Invocation context data used for event reporting
""" """
pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context_data: The invocation context data.
Exceptions: UnknownModelException -- model with these attributes not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
@property @property
@abstractmethod @abstractmethod

View File

@ -1,15 +1,18 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Implementation of model loader service.""" """Implementation of model loader service."""
from typing import Optional from typing import Optional, Type
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache from invokeai.backend.model_manager.load import (
LoadedModel,
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -18,25 +21,23 @@ from .model_load_base import ModelLoadServiceBase
class ModelLoadService(ModelLoadServiceBase): class ModelLoadService(ModelLoadServiceBase):
"""Wrapper around AnyModelLoader.""" """Wrapper around ModelLoaderRegistry."""
def __init__( def __init__(
self, self,
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
ram_cache: ModelCacheBase[AnyModel], ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase, convert_cache: ModelConvertCacheBase,
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
): ):
"""Initialize the model load service.""" """Initialize the model load service."""
logger = InvokeAILogger.get_logger(self.__class__.__name__) logger = InvokeAILogger.get_logger(self.__class__.__name__)
logger.setLevel(app_config.log_level.upper()) logger.setLevel(app_config.log_level.upper())
self._store = record_store self._logger = logger
self._any_loader = AnyModelLoader( self._app_config = app_config
app_config=app_config, self._ram_cache = ram_cache
logger=logger, self._convert_cache = convert_cache
ram_cache=ram_cache, self._registry = registry
convert_cache=convert_cache,
)
def start(self, invoker: Invoker) -> None: def start(self, invoker: Invoker) -> None:
self._invoker = invoker self._invoker = invoker
@ -44,63 +45,14 @@ class ModelLoadService(ModelLoadServiceBase):
@property @property
def ram_cache(self) -> ModelCacheBase[AnyModel]: def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader.""" """Return the RAM cache used by this loader."""
return self._any_loader.ram_cache return self._ram_cache
@property @property
def convert_cache(self) -> ModelConvertCacheBase: def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader.""" """Return the checkpoint convert cache used by this loader."""
return self._any_loader.convert_cache return self._convert_cache
def load_model_by_key( def load_model(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's key, load it and return the LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
"""
config = self._store.get_model(key)
return self.load_model_by_config(config, submodel_type, context_data)
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self._store.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load_model_by_key(configs[0].key, submodel)
def load_model_by_config(
self, self,
model_config: AnyModelConfig, model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
@ -118,7 +70,15 @@ class ModelLoadService(ModelLoadServiceBase):
context_data=context_data, context_data=context_data,
model_config=model_config, model_config=model_config,
) )
loaded_model = self._any_loader.load_model(model_config, submodel_type)
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model: LoadedModel = implementation(
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
convert_cache=self._convert_cache,
).load_model(model_config, submodel_type)
if context_data: if context_data:
self._emit_load_event( self._emit_load_event(
context_data=context_data, context_data=context_data,

View File

@ -3,7 +3,7 @@
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load import LoadedModel
from .model_manager_default import ModelManagerServiceBase, ModelManagerService from .model_manager_default import ModelManagerService, ModelManagerServiceBase
__all__ = [ __all__ = [
"ModelManagerServiceBase", "ModelManagerServiceBase",

View File

@ -1,10 +1,14 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional
from typing_extensions import Self from typing_extensions import Self
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from ..config import InvokeAIAppConfig from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase from ..download import DownloadQueueServiceBase
@ -65,3 +69,32 @@ class ModelManagerServiceBase(ABC):
@abstractmethod @abstractmethod
def stop(self, invoker: Invoker) -> None: def stop(self, invoker: Invoker) -> None:
pass pass
@abstractmethod
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@abstractmethod
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass

View File

@ -1,10 +1,14 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase.""" """Implementation of ModelManagerServiceBase."""
from typing import Optional
from typing_extensions import Self from typing_extensions import Self
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig from ..config import InvokeAIAppConfig
@ -12,7 +16,7 @@ from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService, ModelInstallServiceBase from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService, ModelLoadServiceBase from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase from ..model_records import ModelRecordServiceBase, UnknownModelException
from .model_manager_base import ModelManagerServiceBase from .model_manager_base import ModelManagerServiceBase
@ -58,6 +62,56 @@ class ModelManagerService(ModelManagerServiceBase):
if hasattr(service, "stop"): if hasattr(service, "stop"):
service.stop(invoker) service.stop(invoker)
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
return self.load.load_model(model_config, submodel_type, context_data)
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
config = self.store.get_model(key)
return self.load.load_model(config, submodel_type, context_data)
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self.store.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load.load_model(configs[0], submodel, context_data)
@classmethod @classmethod
def build_model_manager( def build_model_manager(
cls, cls,
@ -82,9 +136,9 @@ class ModelManagerService(ModelManagerServiceBase):
) )
loader = ModelLoadService( loader = ModelLoadService(
app_config=app_config, app_config=app_config,
record_store=model_record_service,
ram_cache=ram_cache, ram_cache=ram_cache,
convert_cache=convert_cache, convert_cache=convert_cache,
registry=ModelLoaderRegistry,
) )
installer = ModelInstallService( installer = ModelInstallService(
app_config=app_config, app_config=app_config,

View File

@ -281,7 +281,7 @@ class ModelsInterface(InvocationContextInterface):
# The model manager emits events as it loads the model. It needs the context data to build # The model manager emits events as it loads the model. It needs the context data to build
# the event payloads. # the event payloads.
return self._services.model_manager.load.load_model_by_key( return self._services.model_manager.load_model_by_key(
key=key, submodel_type=submodel_type, context_data=self._context_data key=key, submodel_type=submodel_type, context_data=self._context_data
) )
@ -296,7 +296,7 @@ class ModelsInterface(InvocationContextInterface):
:param model_type: Type of the model :param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch :param submodel: For main (pipeline models), the submodel to fetch
""" """
return self._services.model_manager.load.load_model_by_attr( return self._services.model_manager.load_model_by_attr(
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,

View File

@ -1,591 +0,0 @@
"""
Migrate the models directory and models.yaml file from an existing
InvokeAI 2.3 installation to 3.0.0.
"""
import argparse
import os
import shutil
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Union
import diffusers
import transformers
import yaml
from diffusers import AutoencoderKL, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from omegaconf import DictConfig, OmegaConf
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager
from invokeai.backend.model_management.model_probe import BaseModelType, ModelProbe, ModelProbeInfo, ModelType
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
diffusers.logging.set_verbosity_error()
# holder for paths that we will migrate
@dataclass
class ModelPaths:
models: Path
embeddings: Path
loras: Path
controlnets: Path
class MigrateTo3(object):
def __init__(
self,
from_root: Path,
to_models: Path,
model_manager: ModelManager,
src_paths: ModelPaths,
):
self.root_directory = from_root
self.dest_models = to_models
self.mgr = model_manager
self.src_paths = src_paths
@classmethod
def initialize_yaml(cls, yaml_file: Path):
with open(yaml_file, "w") as file:
file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
def create_directory_structure(self):
"""
Create the basic directory structure for the models folder.
"""
for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
for model_type in [
ModelType.Main,
ModelType.Vae,
ModelType.Lora,
ModelType.ControlNet,
ModelType.TextualInversion,
]:
path = self.dest_models / model_base.value / model_type.value
path.mkdir(parents=True, exist_ok=True)
path = self.dest_models / "core"
path.mkdir(parents=True, exist_ok=True)
@staticmethod
def copy_file(src: Path, dest: Path):
"""
copy a single file with logging
"""
if dest.exists():
logger.info(f"Skipping existing {str(dest)}")
return
logger.info(f"Copying {str(src)} to {str(dest)}")
try:
shutil.copy(src, dest)
except Exception as e:
logger.error(f"COPY FAILED: {str(e)}")
@staticmethod
def copy_dir(src: Path, dest: Path):
"""
Recursively copy a directory with logging
"""
if dest.exists():
logger.info(f"Skipping existing {str(dest)}")
return
logger.info(f"Copying {str(src)} to {str(dest)}")
try:
shutil.copytree(src, dest)
except Exception as e:
logger.error(f"COPY FAILED: {str(e)}")
def migrate_models(self, src_dir: Path):
"""
Recursively walk through src directory, probe anything
that looks like a model, and copy the model into the
appropriate location within the destination models directory.
"""
directories_scanned = set()
for root, dirs, files in os.walk(src_dir, followlinks=True):
for d in dirs:
try:
model = Path(root, d)
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = self._model_probe_to_path(info) / model.name
self.copy_dir(model, dest)
directories_scanned.add(model)
except Exception as e:
logger.error(str(e))
except KeyboardInterrupt:
raise
for f in files:
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
# let them be copied as part of a tree copy operation
try:
if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}:
continue
model = Path(root, f)
if model.parent in directories_scanned:
continue
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = self._model_probe_to_path(info) / f
self.copy_file(model, dest)
except Exception as e:
logger.error(str(e))
except KeyboardInterrupt:
raise
def migrate_support_models(self):
"""
Copy the clipseg, upscaler, and restoration models to their new
locations.
"""
dest_directory = self.dest_models
if (self.root_directory / "models/clipseg").exists():
self.copy_dir(self.root_directory / "models/clipseg", dest_directory / "core/misc/clipseg")
if (self.root_directory / "models/realesrgan").exists():
self.copy_dir(self.root_directory / "models/realesrgan", dest_directory / "core/upscaling/realesrgan")
for d in ["codeformer", "gfpgan"]:
path = self.root_directory / "models" / d
if path.exists():
self.copy_dir(path, dest_directory / f"core/face_restoration/{d}")
def migrate_tuning_models(self):
"""
Migrate the embeddings, loras and controlnets directories to their new homes.
"""
for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]:
if not src:
continue
if src.is_dir():
logger.info(f"Scanning {src}")
self.migrate_models(src)
else:
logger.info(f"{src} directory not found; skipping")
continue
def migrate_conversion_models(self):
"""
Migrate all the models that are needed by the ckpt_to_diffusers conversion
script.
"""
dest_directory = self.dest_models
kwargs = {
"cache_dir": self.root_directory / "models/hub",
# local_files_only = True
}
try:
logger.info("Migrating core tokenizers and text encoders")
target_dir = dest_directory / "core" / "convert"
self._migrate_pretrained(
BertTokenizerFast, repo_id="bert-base-uncased", dest=target_dir / "bert-base-uncased", **kwargs
)
# sd-1
repo_id = "openai/clip-vit-large-patch14"
self._migrate_pretrained(
CLIPTokenizer, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", **kwargs
)
self._migrate_pretrained(
CLIPTextModel, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", force=True, **kwargs
)
# sd-2
repo_id = "stabilityai/stable-diffusion-2"
self._migrate_pretrained(
CLIPTokenizer,
repo_id=repo_id,
dest=target_dir / "stable-diffusion-2-clip" / "tokenizer",
**{"subfolder": "tokenizer", **kwargs},
)
self._migrate_pretrained(
CLIPTextModel,
repo_id=repo_id,
dest=target_dir / "stable-diffusion-2-clip" / "text_encoder",
**{"subfolder": "text_encoder", **kwargs},
)
# VAE
logger.info("Migrating stable diffusion VAE")
self._migrate_pretrained(
AutoencoderKL, repo_id="stabilityai/sd-vae-ft-mse", dest=target_dir / "sd-vae-ft-mse", **kwargs
)
# safety checking
logger.info("Migrating safety checker")
repo_id = "CompVis/stable-diffusion-safety-checker"
self._migrate_pretrained(
AutoFeatureExtractor, repo_id=repo_id, dest=target_dir / "stable-diffusion-safety-checker", **kwargs
)
self._migrate_pretrained(
StableDiffusionSafetyChecker,
repo_id=repo_id,
dest=target_dir / "stable-diffusion-safety-checker",
**kwargs,
)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def _model_probe_to_path(self, info: ModelProbeInfo) -> Path:
return Path(self.dest_models, info.base_type.value, info.model_type.value)
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs):
if dest.exists() and not force:
logger.info(f"Skipping existing {dest}")
return
model = model_class.from_pretrained(repo_id, **kwargs)
self._save_pretrained(model, dest, overwrite=force)
def _save_pretrained(self, model, dest: Path, overwrite: bool = False):
model_name = dest.name
if overwrite:
model.save_pretrained(dest, safe_serialization=True)
else:
download_path = dest.with_name(f"{model_name}.downloading")
model.save_pretrained(download_path, safe_serialization=True)
download_path.replace(dest)
def _download_vae(self, repo_id: str, subfolder: str = None) -> Path:
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder)
info = ModelProbe().heuristic_probe(vae)
_, model_name = repo_id.split("/")
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
vae.save_pretrained(dest, safe_serialization=True)
return dest
def _vae_path(self, vae: Union[str, dict]) -> Path:
"""
Convert 2.3 VAE stanza to a straight path.
"""
vae_path = None
# First get a path
if isinstance(vae, str):
vae_path = vae
elif isinstance(vae, DictConfig):
if p := vae.get("path"):
vae_path = p
elif repo_id := vae.get("repo_id"):
if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded
vae_path = "models/core/convert/sd-vae-ft-mse"
return vae_path
else:
vae_path = self._download_vae(repo_id, vae.get("subfolder"))
assert vae_path is not None, "Couldn't find VAE for this model"
# if the VAE is in the old models directory, then we must move it into the new
# one. VAEs outside of this directory can stay where they are.
vae_path = Path(vae_path)
if vae_path.is_relative_to(self.src_paths.models):
info = ModelProbe().heuristic_probe(vae_path)
dest = self._model_probe_to_path(info) / vae_path.name
if not dest.exists():
if vae_path.is_dir():
self.copy_dir(vae_path, dest)
else:
self.copy_file(vae_path, dest)
vae_path = dest
if vae_path.is_relative_to(self.dest_models):
rel_path = vae_path.relative_to(self.dest_models)
return Path("models", rel_path)
else:
return vae_path
def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config):
"""
Migrate a locally-cached diffusers pipeline identified with a repo_id
"""
dest_dir = self.dest_models
cache = self.root_directory / "models/hub"
kwargs = {
"cache_dir": cache,
"safety_checker": None,
# local_files_only = True,
}
owner, repo_name = repo_id.split("/")
model_name = model_name or repo_name
model = cache / "--".join(["models", owner, repo_name])
if len(list(model.glob("snapshots/**/model_index.json"))) == 0:
return
revisions = [x.name for x in model.glob("refs/*")]
# if an fp16 is available we use that
revision = "fp16" if len(revisions) > 1 and "fp16" in revisions else revisions[0]
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, **kwargs)
info = ModelProbe().heuristic_probe(pipeline)
if not info:
return
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
return
dest = self._model_probe_to_path(info) / model_name
self._save_pretrained(pipeline, dest)
rel_path = Path("models", dest.relative_to(dest_dir))
self._add_model(model_name, info, rel_path, **extra_config)
def migrate_path(self, location: Path, model_name: str = None, **extra_config):
"""
Migrate a model referred to using 'weights' or 'path'
"""
# handle relative paths
dest_dir = self.dest_models
location = self.root_directory / location
model_name = model_name or location.stem
info = ModelProbe().heuristic_probe(location)
if not info:
return
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
return
# uh oh, weights is in the old models directory - move it into the new one
if Path(location).is_relative_to(self.src_paths.models):
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
if location.is_dir():
self.copy_dir(location, dest)
else:
self.copy_file(location, dest)
location = Path("models", info.base_type.value, info.model_type.value, location.name)
self._add_model(model_name, info, location, **extra_config)
def _add_model(self, model_name: str, info: ModelProbeInfo, location: Path, **extra_config):
if info.model_type != ModelType.Main:
return
self.mgr.add_model(
model_name=model_name,
base_model=info.base_type,
model_type=info.model_type,
clobber=True,
model_attributes={
"path": str(location),
"description": f"A {info.base_type.value} {info.model_type.value} model",
"model_format": info.format,
"variant": info.variant_type.value,
**extra_config,
},
)
def migrate_defined_models(self):
"""
Migrate models defined in models.yaml
"""
# find any models referred to in old models.yaml
conf = OmegaConf.load(self.root_directory / "configs/models.yaml")
for model_name, stanza in conf.items():
try:
passthru_args = {}
if vae := stanza.get("vae"):
try:
passthru_args["vae"] = str(self._vae_path(vae))
except Exception as e:
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
logger.warning(str(e))
if config := stanza.get("config"):
passthru_args["config"] = config
if description := stanza.get("description"):
passthru_args["description"] = description
if repo_id := stanza.get("repo_id"):
logger.info(f"Migrating diffusers model {model_name}")
self.migrate_repo_id(repo_id, model_name, **passthru_args)
elif location := stanza.get("weights"):
logger.info(f"Migrating checkpoint model {model_name}")
self.migrate_path(Path(location), model_name, **passthru_args)
elif location := stanza.get("path"):
logger.info(f"Migrating diffusers model {model_name}")
self.migrate_path(Path(location), model_name, **passthru_args)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def migrate(self):
self.create_directory_structure()
# the configure script is doing this
self.migrate_support_models()
self.migrate_conversion_models()
self.migrate_tuning_models()
self.migrate_defined_models()
def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths:
"""
Returns tuple of (embedding_path, lora_path, controlnet_path)
"""
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
parser.add_argument(
"--embedding_directory",
"--embedding_path",
type=Path,
dest="embedding_path",
default=Path("embeddings"),
)
parser.add_argument(
"--lora_directory",
dest="lora_path",
type=Path,
default=Path("loras"),
)
opt, _ = parser.parse_known_args([f"@{str(initfile)}"])
return ModelPaths(
models=root / "models",
embeddings=root / str(opt.embedding_path).strip('"'),
loras=root / str(opt.lora_path).strip('"'),
controlnets=root / "controlnets",
)
def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths:
"""
Returns tuple of (embedding_path, lora_path, controlnet_path)
"""
# Don't use the config object because it is unforgiving of version updates
# Just use omegaconf directly
opt = OmegaConf.load(initfile)
paths = opt.InvokeAI.Paths
models = paths.get("models_dir", "models")
embeddings = paths.get("embedding_dir", "embeddings")
loras = paths.get("lora_dir", "loras")
controlnets = paths.get("controlnet_dir", "controlnets")
return ModelPaths(
models=root / models if models else None,
embeddings=root / embeddings if embeddings else None,
loras=root / loras if loras else None,
controlnets=root / controlnets if controlnets else None,
)
def get_legacy_embeddings(root: Path) -> ModelPaths:
path = root / "invokeai.init"
if path.exists():
return _parse_legacy_initfile(root, path)
path = root / "invokeai.yaml"
if path.exists():
return _parse_legacy_yamlfile(root, path)
def do_migrate(src_directory: Path, dest_directory: Path):
"""
Migrate models from src to dest InvokeAI root directories
"""
config_file = dest_directory / "configs" / "models.yaml.3"
dest_models = dest_directory / "models.3"
version_3 = (dest_directory / "models" / "core").exists()
# Here we create the destination models.yaml file.
# If we are writing into a version 3 directory and the
# file already exists, then we write into a copy of it to
# avoid deleting its previous customizations. Otherwise we
# create a new empty one.
if version_3: # write into the dest directory
try:
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
except Exception:
MigrateTo3.initialize_yaml(config_file)
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
(dest_directory / "models").replace(dest_models)
else:
MigrateTo3.initialize_yaml(config_file)
mgr = ModelManager(config_file)
paths = get_legacy_embeddings(src_directory)
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths)
migrator.migrate()
print("Migration successful.")
if not version_3:
(dest_directory / "models").replace(src_directory / "models.orig")
print(f"Original models directory moved to {dest_directory}/models.orig")
(dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig")
print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig")
config_file.replace(config_file.with_suffix(""))
dest_models.replace(dest_models.with_suffix(""))
def main():
parser = argparse.ArgumentParser(
prog="invokeai-migrate3",
description="""
This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a
The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.
It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure
script, which will perform a full upgrade in place.""",
)
parser.add_argument(
"--from-directory",
dest="src_root",
type=Path,
required=True,
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")',
)
parser.add_argument(
"--to-directory",
dest="dest_root",
type=Path,
required=True,
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")',
)
args = parser.parse_args()
src_root = args.src_root
assert src_root.is_dir(), f"{src_root} is not a valid directory"
assert (src_root / "models").is_dir(), f"{src_root} does not contain a 'models' subdirectory"
assert (src_root / "models" / "hub").exists(), f"{src_root} does not contain a version 2.3 models directory"
assert (src_root / "invokeai.init").exists() or (
src_root / "invokeai.yaml"
).exists(), f"{src_root} does not contain an InvokeAI init file."
dest_root = args.dest_root
assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
config = InvokeAIAppConfig.get_config()
config.parse_args(["--root", str(dest_root)])
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
if not dest_is_setup:
from invokeai.backend.install.invokeai_configure import initialize_rootdir
initialize_rootdir(dest_root, True)
do_migrate(src_root, dest_root)
if __name__ == "__main__":
main()

View File

@ -1,637 +0,0 @@
"""
Utility (backend) functions used by model_install.py
"""
import os
import re
import shutil
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable, Dict, List, Optional, Set, Union
import requests
import torch
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from huggingface_hub import HfApi, HfFolder, hf_hub_url
from omegaconf import OmegaConf
from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType
from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo, SchedulerPredictionType
from invokeai.backend.util import download_with_resume
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore")
# --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(name="InvokeAI")
# the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
Config_preamble = """
# This file describes the alternative machine learning models
# available to InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
"""
LEGACY_CONFIGS = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v1-inference.yaml",
SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml",
},
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
},
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
},
}
@dataclass
class InstallSelections:
install_models: List[str] = field(default_factory=list)
remove_models: List[str] = field(default_factory=list)
@dataclass
class ModelLoadInfo:
name: str
model_type: ModelType
base_type: BaseModelType
path: Optional[Path] = None
repo_id: Optional[str] = None
subfolder: Optional[str] = None
description: str = ""
installed: bool = False
recommended: bool = False
default: bool = False
requires: Optional[List[str]] = field(default_factory=list)
class ModelInstall(object):
def __init__(
self,
config: InvokeAIAppConfig,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
model_manager: Optional[ModelManager] = None,
access_token: Optional[str] = None,
civitai_api_key: Optional[str] = None,
):
self.config = config
self.mgr = model_manager or ModelManager(config.model_conf_path)
self.datasets = OmegaConf.load(Dataset_path)
self.prediction_helper = prediction_type_helper
self.access_token = access_token or HfFolder.get_token()
self.civitai_api_key = civitai_api_key or config.civitai_api_key
self.reverse_paths = self._reverse_paths(self.datasets)
def all_models(self) -> Dict[str, ModelLoadInfo]:
"""
Return dict of model_key=>ModelLoadInfo objects.
This method consolidates and simplifies the entries in both
models.yaml and INITIAL_MODELS.yaml so that they can
be treated uniformly. It also sorts the models alphabetically
by their name, to improve the display somewhat.
"""
model_dict = {}
# first populate with the entries in INITIAL_MODELS.yaml
for key, value in self.datasets.items():
name, base, model_type = ModelManager.parse_key(key)
value["name"] = name
value["base_type"] = base
value["model_type"] = model_type
model_info = ModelLoadInfo(**value)
if model_info.subfolder and model_info.repo_id:
model_info.repo_id += f":{model_info.subfolder}"
model_dict[key] = model_info
# supplement with entries in models.yaml
installed_models = list(self.mgr.list_models())
for md in installed_models:
base = md["base_model"]
model_type = md["model_type"]
name = md["model_name"]
key = ModelManager.create_key(name, base, model_type)
if key in model_dict:
model_dict[key].installed = True
else:
model_dict[key] = ModelLoadInfo(
name=name,
base_type=base,
model_type=model_type,
path=value.get("path"),
installed=True,
)
return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}
def _is_autoloaded(self, model_info: dict) -> bool:
path = model_info.get("path")
if not path:
return False
for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]:
if autodir_path := getattr(self.config, autodir):
autodir_path = self.config.root_path / autodir_path
if Path(path).is_relative_to(autodir_path):
return True
return False
def list_models(self, model_type):
installed = self.mgr.list_models(model_type=model_type)
print()
print(f"Installed models of type `{model_type}`:")
print(f"{'Model Key':50} Model Path")
for i in installed:
print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}")
print()
# logic here a little reversed to maintain backward compatibility
def starter_models(self, all_models: bool = False) -> Set[str]:
models = set()
for key, _value in self.datasets.items():
name, base, model_type = ModelManager.parse_key(key)
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
models.add(key)
return models
def recommended_models(self) -> Set[str]:
starters = self.starter_models(all_models=True)
return {x for x in starters if self.datasets[x].get("recommended", False)}
def default_model(self) -> str:
starters = self.starter_models()
defaults = [x for x in starters if self.datasets[x].get("default", False)]
return defaults[0]
def install(self, selections: InstallSelections):
verbosity = dlogging.get_verbosity() # quench NSFW nags
dlogging.set_verbosity_error()
job = 1
jobs = len(selections.remove_models) + len(selections.install_models)
# remove requested models
for key in selections.remove_models:
name, base, mtype = self.mgr.parse_key(key)
logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]")
try:
self.mgr.del_model(name, base, mtype)
except FileNotFoundError as e:
logger.warning(e)
job += 1
# add requested models
self._remove_installed(selections.install_models)
self._add_required_models(selections.install_models)
for path in selections.install_models:
logger.info(f"Installing {path} [{job}/{jobs}]")
try:
self.heuristic_import(path)
except (ValueError, KeyError) as e:
logger.error(str(e))
job += 1
dlogging.set_verbosity(verbosity)
self.mgr.commit()
def heuristic_import(
self,
model_path_id_or_url: Union[str, Path],
models_installed: Set[Path] = None,
) -> Dict[str, AddModelResult]:
"""
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
:param models_installed: Set of installed models, used for recursive invocation
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
"""
if not models_installed:
models_installed = {}
model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
# A little hack to allow nested routines to retrieve info on the requested ID
self.current_id = model_path_id_or_url
path = Path(model_path_id_or_url)
# fix relative paths
if path.exists() and not path.is_absolute():
path = path.absolute() # make relative to current WD
# checkpoint file, or similar
if path.is_file():
models_installed.update({str(path): self._install_path(path)})
# folders style or similar
elif path.is_dir() and any(
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"pytorch_lora_weights.safetensors",
}
):
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
# recursive scan
elif path.is_dir():
for child in path.iterdir():
self.heuristic_import(child, models_installed=models_installed)
# huggingface repo
elif len(str(model_path_id_or_url).split("/")) == 2:
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
# a URL
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
else:
raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping")
return models_installed
def _remove_installed(self, model_list: List[str]):
all_models = self.all_models()
models_to_remove = []
for path in model_list:
key = self.reverse_paths.get(path)
if key and all_models[key].installed:
models_to_remove.append(path)
for path in models_to_remove:
logger.warning(f"{path} already installed. Skipping")
model_list.remove(path)
def _add_required_models(self, model_list: List[str]):
additional_models = []
all_models = self.all_models()
for path in model_list:
if not (key := self.reverse_paths.get(path)):
continue
for requirement in all_models[key].requires:
requirement_key = self.reverse_paths.get(requirement)
if not all_models[requirement_key].installed:
additional_models.append(requirement)
model_list.extend(additional_models)
# install a model from a local path. The optional info parameter is there to prevent
# the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult:
info = info or ModelProbe().heuristic_probe(path, self.prediction_helper)
if not info:
logger.warning(f"Unable to parse format of {path}")
return None
model_name = path.stem if path.is_file() else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path, info)
return self.mgr.add_model(
model_name=model_name,
base_model=info.base_type,
model_type=info.model_type,
model_attributes=attributes,
)
def _install_url(self, url: str) -> AddModelResult:
with TemporaryDirectory(dir=self.config.models_path) as staging:
CIVITAI_RE = r".*civitai.com.*"
civit_url = re.match(CIVITAI_RE, url, re.IGNORECASE)
location = download_with_resume(
url, Path(staging), access_token=self.civitai_api_key if civit_url else None
)
if not location:
logger.error(f"Unable to download {url}. Skipping.")
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
dest.parent.mkdir(parents=True, exist_ok=True)
models_path = shutil.move(location, dest)
# staged version will be garbage-collected at this time
return self._install_path(Path(models_path), info)
def _install_repo(self, repo_id: str) -> AddModelResult:
# hack to recover models stored in subfolders --
# Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster
subfolder = None
if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id):
repo_id = match.group(1)
subfolder = match.group(2)
hinfo = HfApi().model_info(repo_id)
# we try to figure out how to download this most economically
# list all the files in the repo
files = [x.rfilename for x in hinfo.siblings]
if subfolder:
files = [x for x in files if x.startswith(f"{subfolder}/")]
prefix = f"{subfolder}/" if subfolder else ""
location = None
with TemporaryDirectory(dir=self.config.models_path) as staging:
staging = Path(staging)
if f"{prefix}model_index.json" in files:
location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder) # pipeline
elif f"{prefix}unet/model.onnx" in files:
location = self._download_hf_model(repo_id, files, staging)
else:
for suffix in ["safetensors", "bin"]:
if f"{prefix}pytorch_lora_weights.{suffix}" in files:
location = self._download_hf_model(
repo_id, [f"pytorch_lora_weights.{suffix}"], staging, subfolder=subfolder
) # LoRA
break
elif (
self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files
): # vae, controlnet or some other standalone
files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
break
elif f"{prefix}diffusion_pytorch_model.{suffix}" in files:
files = ["config.json", f"diffusion_pytorch_model.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
break
elif f"{prefix}learned_embeds.{suffix}" in files:
location = self._download_hf_model(
repo_id, [f"learned_embeds.{suffix}"], staging, subfolder=subfolder
)
break
elif (
f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files
): # IP-Adapter
files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
break
elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files:
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
# by InvokeAI for use with IP-Adapters.
files = ["config.json", f"model.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
break
if not location:
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
return {}
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
if not info:
logger.warning(f"Could not probe {location}. Skipping install.")
return {}
dest = (
self.config.models_path
/ info.base_type.value
/ info.model_type.value
/ self._get_model_name(repo_id, location)
)
if dest.exists():
shutil.rmtree(dest)
shutil.copytree(location, dest)
return self._install_path(dest, info)
def _get_model_name(self, path_name: str, location: Path) -> str:
"""
Calculate a name for the model - primitive implementation.
"""
if key := self.reverse_paths.get(path_name):
(name, base, mtype) = ModelManager.parse_key(key)
return name
elif location.is_dir():
return location.name
else:
return location.stem
def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict:
model_name = path.name if path.is_dir() else path.stem
description = f"{info.base_type.value} {info.model_type.value} model {model_name}"
if key := self.reverse_paths.get(self.current_id):
if key in self.datasets:
description = self.datasets[key].get("description") or description
rel_path = self.relative_to_root(path, self.config.models_path)
attributes = {
"path": str(rel_path),
"description": str(description),
"model_format": info.format,
}
legacy_conf = None
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
attributes.update(
{
"variant": info.variant_type,
}
)
if info.format == "checkpoint":
try:
possible_conf = path.with_suffix(".yaml")
if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf))
elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
legacy_conf = Path(
self.config.legacy_conf_dir,
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
)
else:
legacy_conf = Path(
self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]
)
except KeyError:
legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess
if info.model_type == ModelType.ControlNet and info.format == "checkpoint":
possible_conf = path.with_suffix(".yaml")
if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf))
else:
legacy_conf = Path(
self.config.root_path,
"configs/controlnet",
("cldm_v15.yaml" if info.base_type == BaseModelType("sd-1") else "cldm_v21.yaml"),
)
if legacy_conf:
attributes.update({"config": str(legacy_conf)})
return attributes
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
root = root or self.config.root_path
if path.is_relative_to(root):
return path.relative_to(root)
else:
return path
def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Path:
"""
Retrieve a StableDiffusion model from cache or remote and then
does a save_pretrained() to the indicated staging area.
"""
_, name = repo_id.split("/")
precision = torch_dtype(choose_torch_device())
variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]
model = None
for variant in variants:
try:
model = DiffusionPipeline.from_pretrained(
repo_id,
variant=variant,
torch_dtype=precision,
safety_checker=None,
subfolder=subfolder,
)
except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors
if "fp16" not in str(e):
print(e)
if model:
break
if not model:
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
return None
model.save_pretrained(staging / name, safe_serialization=True)
return staging / name
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
_, name = repo_id.split("/")
location = staging / name
paths = []
for filename in files:
filePath = Path(filename)
p = hf_download_with_resume(
repo_id,
model_dir=location / filePath.parent,
model_name=filePath.name,
access_token=self.access_token,
subfolder=filePath.parent / subfolder if subfolder else filePath.parent,
)
if p:
paths.append(p)
else:
logger.warning(f"Could not download {filename} from {repo_id}.")
return location if len(paths) > 0 else None
@classmethod
def _reverse_paths(cls, datasets) -> dict:
"""
Reverse mapping from repo_id/path to destination name.
"""
return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()}
# -------------------------------------
def yes_or_no(prompt: str, default_yes=True):
default = "y" if default_yes else "n"
response = input(f"{prompt} [{default}] ") or default
if default_yes:
return response[0] not in ("n", "N")
else:
return response[0] in ("y", "Y")
# ---------------------------------------------
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
logger = InvokeAILogger.get_logger("InvokeAI")
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
model = model_class.from_pretrained(
model_name,
resume_download=True,
**kwargs,
)
model.save_pretrained(destination, safe_serialization=True)
return destination
# ---------------------------------------------
def hf_download_with_resume(
repo_id: str,
model_dir: str,
model_name: str,
model_dest: Path = None,
access_token: str = None,
subfolder: str = None,
) -> Path:
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
os.makedirs(model_dir, exist_ok=True)
url = hf_hub_url(repo_id, model_name, subfolder=subfolder)
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
open_mode = "wb"
exist_size = 0
if os.path.exists(model_dest):
exist_size = os.path.getsize(model_dest)
header["Range"] = f"bytes={exist_size}-"
open_mode = "ab"
resp = requests.get(url, headers=header, stream=True)
total = int(resp.headers.get("content-length", 0))
if resp.status_code == 416: # "range not satisfiable", which means nothing to return
logger.info(f"{model_name}: complete file found. Skipping.")
return model_dest
elif resp.status_code == 404:
logger.warning("File not found")
return None
elif resp.status_code != 200:
logger.warning(f"{model_name}: {resp.reason}")
elif exist_size > 0:
logger.info(f"{model_name}: partial file found. Resuming...")
else:
logger.info(f"{model_name}: Downloading...")
try:
with (
open(model_dest, open_mode) as file,
tqdm(
desc=model_name,
initial=exist_size,
total=total + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
) as bar,
):
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
except Exception as e:
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
return None
return model_dest

View File

@ -9,8 +9,8 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from .resampler import Resampler
from ..raw_model import RawModel from ..raw_model import RawModel
from .resampler import Resampler
class ImageProjModel(torch.nn.Module): class ImageProjModel(torch.nn.Module):

View File

@ -10,6 +10,7 @@ from safetensors.torch import load_file
from typing_extensions import Self from typing_extensions import Self
from invokeai.backend.model_manager import BaseModelType from invokeai.backend.model_manager import BaseModelType
from .raw_model import RawModel from .raw_model import RawModel
@ -366,6 +367,7 @@ class IA3Layer(LoRALayerBase):
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
class LoRAModelRaw(RawModel): # (torch.nn.Module): class LoRAModelRaw(RawModel): # (torch.nn.Module):
_name: str _name: str
layers: Dict[str, AnyLoRALayer] layers: Dict[str, AnyLoRALayer]

View File

@ -1,27 +0,0 @@
# Model Cache
## `glibc` Memory Allocator Fragmentation
Python (and PyTorch) relies on the memory allocator from the C Standard Library (`libc`). On linux, with the GNU C Standard Library implementation (`glibc`), our memory access patterns have been observed to cause severe memory fragmentation. This fragmentation results in large amounts of memory that has been freed but can't be released back to the OS. Loading models from disk and moving them between CPU/CUDA seem to be the operations that contribute most to the fragmentation. This memory fragmentation issue can result in OOM crashes during frequent model switching, even if `max_cache_size` is set to a reasonable value (e.g. a OOM crash with `max_cache_size=16` on a system with 32GB of RAM).
This problem may also exist on other OSes, and other `libc` implementations. But, at the time of writing, it has only been investigated on linux with `glibc`.
To better understand how the `glibc` memory allocator works, see these references:
- Basics: https://www.gnu.org/software/libc/manual/html_node/The-GNU-Allocator.html
- Details: https://sourceware.org/glibc/wiki/MallocInternals
Note the differences between memory allocated as chunks in an arena vs. memory allocated with `mmap`. Under `glibc`'s default configuration, most model tensors get allocated as chunks in an arena making them vulnerable to the problem of fragmentation.
We can work around this memory fragmentation issue by setting the following env var:
```bash
# Force blocks >1MB to be allocated with `mmap` so that they are released to the system immediately when they are freed.
MALLOC_MMAP_THRESHOLD_=1048576
```
See the following references for more information about the `malloc` tunable parameters:
- https://www.gnu.org/software/libc/manual/html_node/Malloc-Tunable-Parameters.html
- https://www.gnu.org/software/libc/manual/html_node/Memory-Allocation-Tunables.html
- https://man7.org/linux/man-pages/man3/mallopt.3.html
The model cache emits debug logs that provide visibility into the state of the `libc` memory allocator. See the `LibcUtil` class for more info on how these `libc` malloc stats are collected.

View File

@ -1,20 +0,0 @@
# ruff: noqa: I001, F401
"""
Initialization file for invokeai.backend.model_management
"""
# This import must be first
from .model_manager import AddModelResult, LoadedModelInfo, ModelManager, SchedulerPredictionType
from .lora import ModelPatcher, ONNXModelPatcher
from .model_cache import ModelCache
from .models import (
BaseModelType,
DuplicateModelException,
ModelNotFoundException,
ModelType,
ModelVariantType,
SubModelType,
)
# This import must be last
from .model_merge import MergeInterpolationMethod, ModelMerger

View File

@ -1,31 +0,0 @@
# Copyright (c) 2024 Lincoln Stein and the InvokeAI Development Team
"""
This module exports the function has_baked_in_sdxl_vae().
It returns True if an SDXL checkpoint model has the original SDXL 1.0 VAE,
which doesn't work properly in fp16 mode.
"""
import hashlib
from pathlib import Path
from safetensors.torch import load_file
SDXL_1_0_VAE_HASH = "bc40b16c3a0fa4625abdfc01c04ffc21bf3cefa6af6c7768ec61eb1f1ac0da51"
def has_baked_in_sdxl_vae(checkpoint_path: Path) -> bool:
"""Return true if the checkpoint contains a custom (non SDXL-1.0) VAE."""
hash = _vae_hash(checkpoint_path)
return hash != SDXL_1_0_VAE_HASH
def _vae_hash(checkpoint_path: Path) -> str:
checkpoint = load_file(checkpoint_path, device="cpu")
vae_keys = [x for x in checkpoint.keys() if x.startswith("first_stage_model.")]
hash = hashlib.new("sha256")
for key in vae_keys:
value = checkpoint[key]
hash.update(bytes(key, "UTF-8"))
hash.update(bytes(str(value), "UTF-8"))
return hash.hexdigest()

View File

@ -1,582 +0,0 @@
from __future__ import annotations
import pickle
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
from .models.lora import LoRAModel
"""
loras = [
(lora_model1, 0.7),
(lora_model2, 0.4),
]
with LoRAHelper.apply_lora_unet(unet, loras):
# unet with applied loras
# unmodified unet
"""
# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
if not lora_key.startswith(prefix):
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
module = model
module_key = ""
key_parts = lora_key[len(prefix) :].split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: UNet2DConditionModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder2(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield
@classmethod
@contextmanager
def apply_lora(
cls,
model: torch.nn.Module,
loras: List[Tuple[LoRAModel, float]], # THIS IS INCORRECT. IT IS ACTUALLY A LoRAModelRaw
prefix: str,
):
original_weights = {}
try:
with torch.no_grad():
for lora, lora_weight in loras:
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
if module_key not in original_weights:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device="cpu")
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
layer_weight = layer_weight.reshape(module.weight.shape)
module.weight += layer_weight.to(dtype=dtype)
yield # wait for context manager exit
finally:
with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)
@classmethod
@contextmanager
def apply_ti(
cls,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
ti_list: List[Tuple[str, Any]],
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
init_tokens_count = None
new_tokens_added = None
# TODO: This is required since Transformers 4.32 see
# https://github.com/huggingface/transformers/pull/25088
# More information by NVIDIA:
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
# This value might need to be changed in the future and take the GPUs model into account as there seem
# to be ideal values for different GPUS. This value is temporary!
# For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817
pad_to_multiple_of = 8
try:
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
# exiting this `apply_ti(...)` context manager.
#
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings
def _get_trigger(ti_name, index):
trigger = ti_name
if index > 0:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
def _get_ti_embedding(model_embeddings, ti):
print(f"DEBUG: model_embeddings={type(model_embeddings)}, ti={type(ti)}")
print(f"DEBUG: is it an nn.Module? {isinstance(model_embeddings, torch.nn.Module)}")
# for SDXL models, select the embedding that matches the text encoder's dimensions
if ti.embedding_2 is not None:
return (
ti.embedding_2
if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0]
else ti.embedding
)
else:
print(f"DEBUG: ti.embedding={type(ti.embedding)}")
return ti.embedding
# modify tokenizer
new_tokens_added = 0
for ti_name, ti in ti_list:
ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
for i in range(ti_embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
# Modify text_encoder.
# resize_token_embeddings(...) constructs a new torch.nn.Embedding internally. Initializing the weights of
# this embedding is slow and unnecessary, so we wrap this step in skip_torch_weight_init() to save some
# time.
with skip_torch_weight_init():
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list:
ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
ti_tokens = []
for i in range(ti_embedding.shape[0]):
embedding = ti_embedding[i]
trigger = _get_trigger(ti_name, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id:
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
if model_embeddings.weight.data[token_id].shape != embedding.shape:
raise ValueError(
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
f" {embedding.shape[0]}, but the current model has token dimension"
f" {model_embeddings.weight.data[token_id].shape[0]}."
)
model_embeddings.weight.data[token_id] = embedding.to(
device=text_encoder.device, dtype=text_encoder.dtype
)
ti_tokens.append(token_id)
if len(ti_tokens) > 1:
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
yield ti_tokenizer, ti_manager
finally:
if init_tokens_count and new_tokens_added:
text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of)
@classmethod
@contextmanager
def apply_clip_skip(
cls,
text_encoder: CLIPTextModel,
clip_skip: int,
):
skipped_layers = []
try:
for _i in range(clip_skip):
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
yield
finally:
while len(skipped_layers) > 0:
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
@classmethod
@contextmanager
def apply_freeu(
cls,
unet: UNet2DConditionModel,
freeu_config: Optional[FreeUConfig] = None,
):
did_apply_freeu = False
try:
if freeu_config is not None:
unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2)
did_apply_freeu = True
yield
finally:
if did_apply_freeu:
unet.disable_freeu()
class TextualInversionModel:
embedding: torch.Tensor # [n, 768]|[n, 1280]
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if not isinstance(file_path, Path):
file_path = Path(file_path)
result = cls() # TODO:
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
# both v1 and v2 format embeddings
# difference mostly in metadata
if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1:
print(
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
" token will be used.",
)
result.embedding = next(iter(state_dict["string_to_param"].values()))
# v3 (easynegative)
elif "emb_params" in state_dict:
result.embedding = state_dict["emb_params"]
# v5(sdxl safetensors file)
elif "clip_g" in state_dict and "clip_l" in state_dict:
result.embedding = state_dict["clip_g"]
result.embedding_2 = state_dict["clip_l"]
# v4(diffusers bin files)
else:
result.embedding = next(iter(state_dict.values()))
if len(result.embedding.shape) == 1:
result.embedding = result.embedding.unsqueeze(0)
if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}")
return result
class TextualInversionManager(BaseTextualInversionManager):
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
if len(self.pad_tokens) == 0:
return token_ids
if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id")
new_token_ids = []
for token_id in token_ids:
new_token_ids.append(token_id)
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
# which first removes and then adds back the start and end tokens.
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
if len(new_token_ids) > max_length:
new_token_ids = new_token_ids[0:max_length]
return new_token_ids
class ONNXModelPatcher:
from diffusers import OnnxRuntimeModel
from .models.base import IAIOnnxRuntimeModel
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: OnnxRuntimeModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: OnnxRuntimeModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
# based on
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
@classmethod
@contextmanager
def apply_lora(
cls,
model: IAIOnnxRuntimeModel,
loras: List[Tuple[LoRAModel, float]],
prefix: str,
):
from .models.base import IAIOnnxRuntimeModel
if not isinstance(model, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported")
orig_weights = {}
try:
blended_loras = {}
for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
layer.to(dtype=torch.float32)
layer_key = layer_key.replace(prefix, "")
# TODO: rewrite to pass original tensor weight(required by ia3)
layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
if layer_key is blended_loras:
blended_loras[layer_key] += layer_weight
else:
blended_loras[layer_key] = layer_weight
node_names = {}
for node in model.nodes.values():
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
for layer_key, lora_weight in blended_loras.items():
conv_key = layer_key + "_Conv"
gemm_key = layer_key + "_Gemm"
matmul_key = layer_key + "_MatMul"
if conv_key in node_names or gemm_key in node_names:
if conv_key in node_names:
conv_node = model.nodes[node_names[conv_key]]
else:
conv_node = model.nodes[node_names[gemm_key]]
weight_name = [n for n in conv_node.input if ".weight" in n][0]
orig_weight = model.tensors[weight_name]
if orig_weight.shape[-2:] == (1, 1):
if lora_weight.shape[-2:] == (1, 1):
new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
else:
new_weight = orig_weight.squeeze((3, 2)) + lora_weight
new_weight = np.expand_dims(new_weight, (2, 3))
else:
if orig_weight.shape != lora_weight.shape:
new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
else:
new_weight = orig_weight + lora_weight
orig_weights[weight_name] = orig_weight
model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)
elif matmul_key in node_names:
weight_node = model.nodes[node_names[matmul_key]]
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
orig_weight = model.tensors[matmul_name]
new_weight = orig_weight + lora_weight.transpose()
orig_weights[matmul_name] = orig_weight
model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)
else:
# warn? err?
pass
yield
finally:
# restore original weights
for name, orig_weight in orig_weights.items():
model.tensors[name] = orig_weight
@classmethod
@contextmanager
def apply_ti(
cls,
tokenizer: CLIPTokenizer,
text_encoder: IAIOnnxRuntimeModel,
ti_list: List[Tuple[str, Any]],
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
from .models.base import IAIOnnxRuntimeModel
if not isinstance(text_encoder, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported")
orig_embeddings = None
try:
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
# exiting this `apply_ti(...)` context manager.
#
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer)
def _get_trigger(ti_name, index):
trigger = ti_name
if index > 0:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
# modify text_encoder
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
# modify tokenizer
new_tokens_added = 0
for ti_name, ti in ti_list:
if ti.embedding_2 is not None:
ti_embedding = (
ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
)
else:
ti_embedding = ti.embedding
for i in range(ti_embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
embeddings = np.concatenate(
(np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
axis=0,
)
for ti_name, _ in ti_list:
ti_tokens = []
for i in range(ti_embedding.shape[0]):
embedding = ti_embedding[i].detach().numpy()
trigger = _get_trigger(ti_name, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id:
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
if embeddings[token_id].shape != embedding.shape:
raise ValueError(
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
f" {embedding.shape[0]}, but the current model has token dimension"
f" {embeddings[token_id].shape[0]}."
)
embeddings[token_id] = embedding
ti_tokens.append(token_id)
if len(ti_tokens) > 1:
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
orig_embeddings.dtype
)
yield ti_tokenizer, ti_manager
finally:
# restore
if orig_embeddings is not None:
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings

View File

@ -1,99 +0,0 @@
import gc
from typing import Optional
import psutil
import torch
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
GB = 2**30 # 1 GB
class MemorySnapshot:
"""A snapshot of RAM and VRAM usage. All values are in bytes."""
def __init__(self, process_ram: int, vram: Optional[int], malloc_info: Optional[Struct_mallinfo2]):
"""Initialize a MemorySnapshot.
Most of the time, `MemorySnapshot` will be constructed with `MemorySnapshot.capture()`.
Args:
process_ram (int): CPU RAM used by the current process.
vram (Optional[int]): VRAM used by torch.
malloc_info (Optional[Struct_mallinfo2]): Malloc info obtained from LibcUtil.
"""
self.process_ram = process_ram
self.vram = vram
self.malloc_info = malloc_info
@classmethod
def capture(cls, run_garbage_collector: bool = True):
"""Capture and return a MemorySnapshot.
Note: This function has significant overhead, particularly if `run_garbage_collector == True`.
Args:
run_garbage_collector (bool, optional): If true, gc.collect() will be run before checking the process RAM
usage. Defaults to True.
Returns:
MemorySnapshot
"""
if run_garbage_collector:
gc.collect()
# According to the psutil docs (https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info), rss is
# supported on all platforms.
process_ram = psutil.Process().memory_info().rss
if torch.cuda.is_available():
vram = torch.cuda.memory_allocated()
else:
# TODO: We could add support for mps.current_allocated_memory() as well. Leaving out for now until we have
# time to test it properly.
vram = None
try:
malloc_info = LibcUtil().mallinfo2()
except (OSError, AttributeError):
# OSError: This is expected in environments that do not have the 'libc.so.6' shared library.
# AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33)
# TODO: Does `mallinfo` work?
malloc_info = None
return cls(process_ram, vram, malloc_info)
def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str:
"""Get a pretty string describing the difference between two `MemorySnapshot`s."""
def get_msg_line(prefix: str, val1: int, val2: int):
diff = val2 - val1
return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n"
msg = ""
if snapshot_1 is None or snapshot_2 is None:
return msg
msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram)
if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
msg += get_msg_line("libc mmap allocated", snapshot_1.malloc_info.hblkhd, snapshot_2.malloc_info.hblkhd)
msg += get_msg_line("libc arena used", snapshot_1.malloc_info.uordblks, snapshot_2.malloc_info.uordblks)
msg += get_msg_line("libc arena free", snapshot_1.malloc_info.fordblks, snapshot_2.malloc_info.fordblks)
libc_total_allocated_1 = snapshot_1.malloc_info.arena + snapshot_1.malloc_info.hblkhd
libc_total_allocated_2 = snapshot_2.malloc_info.arena + snapshot_2.malloc_info.hblkhd
msg += get_msg_line("libc total allocated", libc_total_allocated_1, libc_total_allocated_2)
libc_total_used_1 = snapshot_1.malloc_info.uordblks + snapshot_1.malloc_info.hblkhd
libc_total_used_2 = snapshot_2.malloc_info.uordblks + snapshot_2.malloc_info.hblkhd
msg += get_msg_line("libc total used", libc_total_used_1, libc_total_used_2)
if snapshot_1.vram is not None and snapshot_2.vram is not None:
msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
return msg

View File

@ -1,553 +0,0 @@
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:
cache = ModelCache(max_cache_size=7.5)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2)
"""
import gc
import hashlib
import math
import os
import sys
import time
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union, types
import torch
import invokeai.backend.util.logging as logger
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
from ..util.devices import choose_torch_device
from .models import BaseModelType, ModelBase, ModelType, SubModelType
if choose_torch_device() == torch.device("mps"):
from torch import mps
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
# Size of a MB in bytes.
MB = 2**20
@dataclass
class CacheStats(object):
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
# {submodel_key => size}
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelLocker(object):
"Forward declaration"
pass
class ModelCache(object):
"Forward declaration"
pass
class _CacheRecord:
size: int
model: Any
cache: ModelCache
_locks: int
def __init__(self, cache, model: Any, size: int):
self.size = size
self.model = model
self.cache = cache
self._locks = 0
def lock(self):
self._locks += 1
def unlock(self):
self._locks -= 1
assert self._locks >= 0
@property
def locked(self):
return self._locks > 0
@property
def loaded(self):
if self.model is not None and hasattr(self.model, "device"):
return self.model.device != self.cache.storage_device
else:
return False
class ModelCache(object):
def __init__(
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
logger: types.ModuleType = logger,
log_memory_usage: bool = False,
):
"""
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
"""
self.model_infos: Dict[str, ModelBase] = {}
# allow lazy offloading only when vram cache enabled
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self.precision: torch.dtype = precision
self.max_cache_size: float = max_cache_size
self.max_vram_cache_size: float = max_vram_cache_size
self.execution_device: torch.device = execution_device
self.storage_device: torch.device = storage_device
self.sha_chunksize = sha_chunksize
self.logger = logger
self._log_memory_usage = log_memory_usage
# used for stats collection
self.stats = None
self._cached_models = {}
self._cache_stack = []
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()
return None
def get_key(
self,
model_path: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
):
key = f"{model_path}:{base_model}:{model_type}"
if submodel_type:
key += f":{submodel_type}"
return key
def _get_model_info(
self,
model_path: str,
model_class: Type[ModelBase],
base_model: BaseModelType,
model_type: ModelType,
):
model_info_key = self.get_key(
model_path=model_path,
base_model=base_model,
model_type=model_type,
submodel_type=None,
)
if model_info_key not in self.model_infos:
self.model_infos[model_info_key] = model_class(
model_path,
base_model,
model_type,
)
return self.model_infos[model_info_key]
# TODO: args
def get_model(
self,
model_path: Union[str, Path],
model_class: Type[ModelBase],
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
gpu_load: bool = True,
) -> Any:
if not isinstance(model_path, Path):
model_path = Path(model_path)
if not os.path.exists(model_path):
raise Exception(f"Model not found: {model_path}")
model_info = self._get_model_info(
model_path=model_path,
model_class=model_class,
base_model=base_model,
model_type=model_type,
)
key = self.get_key(
model_path=model_path,
base_model=base_model,
model_type=model_type,
submodel_type=submodel,
)
# TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(
f"Loading model {model_path}, type"
f" {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
)
if self.stats:
self.stats.misses += 1
self_reported_model_size_before_load = model_info.get_size(submodel)
# Remove old models from the cache to make room for the new model.
self._make_cache_room(self_reported_model_size_before_load)
# Load the model from disk and capture a memory snapshot before/after.
start_load_time = time.time()
snapshot_before = self._capture_memory_snapshot()
with skip_torch_weight_init():
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
snapshot_after = self._capture_memory_snapshot()
end_load_time = time.time()
self_reported_model_size_after_load = model_info.get_size(submodel)
self.logger.debug(
f"Moved model '{key}' from disk to cpu in {(end_load_time-start_load_time):.2f}s.\n"
f"Self-reported size before/after load: {(self_reported_model_size_before_load/GIG):.3f}GB /"
f" {(self_reported_model_size_after_load/GIG):.3f}GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if abs(self_reported_model_size_after_load - self_reported_model_size_before_load) > 10 * MB:
self.logger.debug(
f"Model '{key}' mis-reported its size before load. Self-reported size before/after load:"
f" {(self_reported_model_size_before_load/GIG):.2f}GB /"
f" {(self_reported_model_size_after_load/GIG):.2f}GB."
)
cache_entry = _CacheRecord(self, model, self_reported_model_size_after_load)
self._cached_models[key] = cache_entry
else:
if self.stats:
self.stats.hits += 1
if self.stats:
self.stats.cache_size = self.max_cache_size * GIG
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[key] = max(
self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
)
with suppress(Exception):
self._cache_stack.remove(key)
self._cache_stack.append(key)
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
def _move_model_to_device(self, key: str, target_device: torch.device):
cache_entry = self._cached_models[key]
source_device = cache_entry.model.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support
# multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
cache_entry.model.to(target_device)
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self.logger.debug(
f"Moved model '{key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n"
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.
if not math.isclose(
vram_change,
cache_entry.size,
rel_tol=0.1,
abs_tol=10 * MB,
):
self.logger.debug(
f"Moving model '{key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
class ModelLocker(object):
def __init__(self, cache, key, model, gpu_load, size_needed):
"""
:param cache: The model_cache object
:param key: The key of the model to lock in GPU
:param model: The model to lock
:param gpu_load: True if load into gpu
:param size_needed: Size of the model to load
"""
self.gpu_load = gpu_load
self.cache = cache
self.key = key
self.model = model
self.size_needed = size_needed
self.cache_entry = self.cache._cached_models[self.key]
def __enter__(self) -> Any:
if not hasattr(self.model, "to"):
return self.model
# NOTE that the model has to have the to() method in order for this
# code to move it into GPU!
if self.gpu_load:
self.cache_entry.lock()
try:
if self.cache.lazy_offloading:
self.cache._offload_unlocked_models(self.size_needed)
self.cache._move_model_to_device(self.key, self.cache.execution_device)
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
self.cache._print_cuda_stats()
except Exception:
self.cache_entry.unlock()
raise
# TODO: not fully understand
# in the event that the caller wants the model in RAM, we
# move it into CPU if it is in GPU and not locked
elif self.cache_entry.loaded and not self.cache_entry.locked:
self.cache._move_model_to_device(self.key, self.cache.storage_device)
return self.model
def __exit__(self, type, value, traceback):
if not hasattr(self.model, "to"):
return
self.cache_entry.unlock()
if not self.cache.lazy_offloading:
self.cache._offload_unlocked_models()
self.cache._print_cuda_stats()
# TODO: should it be called untrack_model?
def uncache_model(self, cache_id: str):
with suppress(ValueError):
self._cache_stack.remove(cache_id)
self._cached_models.pop(cache_id, None)
def model_hash(
self,
model_path: Union[str, Path],
) -> str:
"""
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
:param model_path: Path to model file/directory on disk.
"""
return self._local_model_hash(model_path)
def cache_size(self) -> float:
"""Return the current size of the cache, in GB."""
return self._cache_size() / GIG
def _has_cuda(self) -> bool:
return self.execution_device.type == "cuda"
def _print_cuda_stats(self):
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % self.cache_size()
cached_models = 0
loaded_models = 0
locked_models = 0
for model_info in self._cached_models.values():
cached_models += 1
if model_info.loaded:
loaded_models += 1
if model_info.locked:
locked_models += 1
self.logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ ="
f" {cached_models}/{loaded_models}/{locked_models}"
)
def _cache_size(self) -> int:
return sum([m.size for m in self._cached_models.values()])
def _make_cache_room(self, model_size):
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = self._cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GIG):.2f} GB"
)
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
pos = 0
models_cleared = 0
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
refs = sys.getrefcount(cache_entry.model)
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
# https://docs.python.org/3/library/gc.html#gc.get_referrers
# manualy clear local variable references of just finished function calls
# for some reason python don't want to collect it even by gc.collect() immidiately
if refs > 2:
while True:
cleared = False
for referrer in gc.get_referrers(cache_entry.model):
if type(referrer).__name__ == "frame":
# RuntimeError: cannot clear an executing frame
with suppress(RuntimeError):
referrer.clear()
cleared = True
# break
# repeat if referrers changes(due to frame clear), else exit loop
if cleared:
gc.collect()
else:
break
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
f" refs: {refs}"
)
# Expected refs:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug(
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1
if self.stats:
self.stats.cleared += 1
del self._cache_stack[pos]
del self._cached_models[model_key]
del cache_entry
else:
pos += 1
if models_cleared > 0:
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
# is high even if no garbage gets collected.)
#
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
# - If models had to be cleared, it's a signal that we are close to our memory limit.
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
# collected.
#
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
# immediately when their reference count hits 0.
gc.collect()
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
def _offload_unlocked_models(self, size_needed: int = 0):
reserved = self.max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated()
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.locked and cache_entry.loaded:
self._move_model_to_device(model_key, self.storage_device)
vram_in_use = torch.cuda.memory_allocated()
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
sha = hashlib.sha256()
path = Path(model_path)
hashpath = path / "checksum.sha256"
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
with open(hashpath) as f:
hash = f.read()
return hash
self.logger.debug(f"computing hash of model {path.name}")
for file in list(path.rglob("*.ckpt")) + list(path.rglob("*.safetensors")) + list(path.rglob("*.pth")):
with open(file, "rb") as f:
while chunk := f.read(self.sha_chunksize):
sha.update(chunk)
hash = sha.hexdigest()
with open(hashpath, "w") as f:
f.write(hash)
return hash

View File

@ -1,30 +0,0 @@
from contextlib import contextmanager
import torch
def _no_op(*args, **kwargs):
pass
@contextmanager
def skip_torch_weight_init():
"""A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.)
to skip weight initialization.
By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
monkey-patches common torch layers to skip the weight initialization step.
"""
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
saved_functions = [m.reset_parameters for m in torch_modules]
try:
for torch_module in torch_modules:
torch_module.reset_parameters = _no_op
yield None
finally:
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
torch_module.reset_parameters = saved_function

File diff suppressed because it is too large Load Diff

View File

@ -1,140 +0,0 @@
"""
invokeai.backend.model_management.model_merge exports:
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import warnings
from enum import Enum
from pathlib import Path
from typing import List, Optional, Union
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
import invokeai.backend.util.logging as logger
from ...backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType
class MergeInterpolationMethod(str, Enum):
WeightedSum = "weighted_sum"
Sigmoid = "sigmoid"
InvSigmoid = "inv_sigmoid"
AddDifference = "add_difference"
class ModelMerger(object):
def __init__(self, manager: ModelManager):
self.manager = manager
def merge_diffusion_models(
self,
model_paths: List[Path],
alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
pipe = DiffusionPipeline.from_pretrained(
model_paths[0],
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_paths,
alpha=alpha,
interp=interp.value if interp else None, # diffusers API treats None as "weighted sum"
force=force,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe
def merge_diffusion_models_and_save(
self,
model_names: List[str],
base_model: Union[BaseModelType, str],
merged_model_name: str,
alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
merge_dest_directory: Optional[Path] = None,
**kwargs,
) -> AddModelResult:
"""
:param models: up to three models, designated by their InvokeAI models.yaml model name
:param base_model: base model (must be the same for all merged models!)
:param merged_model_name: name for new model
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_paths = []
config = self.manager.app_config
base_model = BaseModelType(base_model)
vae = None
for mod in model_names:
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert (
info["model_format"] == "diffusers"
), f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
assert (
len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference
), "When merging three models, only the 'add_difference' merge method is supported"
# pick up the first model's vae
if mod == model_names[0]:
vae = info.get("vae")
model_paths.extend([(config.root_path / info["path"]).as_posix()])
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
logger.debug(f"interp = {interp}, merge_method={merge_method}")
merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, **kwargs)
dump_path = (
Path(merge_dest_directory)
if merge_dest_directory
else config.models_path / base_model.value / ModelType.Main.value
)
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = (dump_path / merged_model_name).as_posix()
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
attributes = {
"path": dump_path,
"description": f"Merge of models {', '.join(model_names)}",
"model_format": "diffusers",
"variant": ModelVariantType.Normal.value,
"vae": vae,
}
return self.manager.add_model(
merged_model_name,
base_model=base_model,
model_type=ModelType.Main,
model_attributes=attributes,
clobber=True,
)

View File

@ -1,664 +0,0 @@
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Literal, Optional, Union
import safetensors.torch
import torch
from diffusers import ConfigMixin, ModelMixin
from picklescan.scanner import scan_file_path
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
from .models import (
BaseModelType,
InvalidModelException,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SilenceWarnings,
)
from .models.base import read_checkpoint_meta
from .util import lora_token_vector_length
@dataclass
class ModelProbeInfo(object):
model_type: ModelType
base_type: BaseModelType
variant_type: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
image_size: int
name: Optional[str] = None
description: Optional[str] = None
class ProbeBase(object):
"""forward declaration"""
pass
class ModelProbe(object):
PROBES = {
"diffusers": {},
"checkpoint": {},
"onnx": {},
}
CLASS2TYPE = {
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"AutoencoderTiny": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
}
@classmethod
def register_probe(
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
):
cls.PROBES[format][model_type] = probe_class
@classmethod
def heuristic_probe(
cls,
model: Union[Dict, ModelMixin, Path],
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
) -> ModelProbeInfo:
if isinstance(model, Path):
return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
else:
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
@classmethod
def probe(
cls,
model_path: Path,
model: Optional[Union[Dict, ModelMixin]] = None,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> ModelProbeInfo:
"""
Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. If the model is
already loaded into memory, you may provide it as model in order to avoid
opening it a second time. The prediction_type_helper callable is a function that receives
the path to the model and returns the SchedulerPredictionType.
"""
if model_path:
format_type = "diffusers" if model_path.is_dir() else "checkpoint"
else:
format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"
model_info = None
try:
model_type = (
cls.get_model_type_from_folder(model_path, model)
if format_type == "diffusers"
else cls.get_model_type_from_checkpoint(model_path, model)
)
format_type = "onnx" if model_type == ModelType.ONNX else format_type
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
return None
probe = probe_class(model_path, model, prediction_type_helper)
base_type = probe.get_base_type()
variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type()
name = cls.get_model_name(model_path)
description = f"{base_type.value} {model_type.value} model {name}"
format = probe.get_format()
model_info = ModelProbeInfo(
model_type=model_type,
base_type=base_type,
variant_type=variant_type,
prediction_type=prediction_type,
name=name,
description=description,
upcast_attention=(
base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction
),
format=format,
image_size=(
1024
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
else (
768
if (
base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction
)
else 512
)
),
)
except Exception:
raise
return model_info
@classmethod
def get_model_name(cls, model_path: Path) -> str:
if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
return model_path.stem
else:
return model_path.name
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
return None
if model_path.name == "learned_embeds.bin":
return ModelType.TextualInversion
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
ckpt = ckpt.get("state_dict", ckpt)
for key in ckpt.keys():
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
return ModelType.Vae
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.Lora
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
return ModelType.Lora
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion
else:
# diffusers-ti
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
raise InvalidModelException(f"Unable to determine model type for {model_path}")
@classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
"""
Get the model type of a hugging-face style folder.
"""
class_name = None
error_hint = None
if model:
class_name = model.__class__.__name__
else:
for suffix in ["bin", "safetensors"]:
if (folder_path / f"learned_embeds.{suffix}").exists():
return ModelType.TextualInversion
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
return ModelType.Lora
if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX
if (folder_path / "image_encoder.txt").exists():
return ModelType.IPAdapter
i = folder_path / "model_index.json"
c = folder_path / "config.json"
config_path = i if i.exists() else c if c.exists() else None
if config_path:
with open(config_path, "r") as file:
conf = json.load(file)
if "_class_name" in conf:
class_name = conf["_class_name"]
elif "architectures" in conf:
class_name = conf["architectures"][0]
else:
class_name = None
else:
error_hint = f"No model_index.json or config.json found in {folder_path}."
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type
else:
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
# give up
raise InvalidModelException(
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
)
@classmethod
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model_path, model_path)
return torch.load(model_path, map_location="cpu")
else:
return safetensors.torch.load_file(model_path)
@classmethod
def _scan_model(cls, model_name, checkpoint):
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
"""
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
# ##################################################3
# Checkpoint probing
# ##################################################3
class ProbeBase(object):
def get_base_type(self) -> BaseModelType:
pass
def get_variant_type(self) -> ModelVariantType:
pass
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
pass
def get_format(self) -> str:
pass
class CheckpointProbeBase(ProbeBase):
def __init__(
self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None
) -> BaseModelType:
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
self.checkpoint_path = checkpoint_path
self.helper = helper
def get_base_type(self) -> BaseModelType:
pass
def get_format(self) -> str:
return "checkpoint"
def get_variant_type(self) -> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
if model_type != ModelType.Main:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
else:
raise InvalidModelException(
f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
)
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
else:
raise InvalidModelException("Cannot determine base type")
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
"""Return model prediction type."""
# if there is a .yaml associated with this checkpoint, then we do not need
# to probe for the prediction type as it will be ignored.
if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists():
return None
type = self.get_base_type()
if type == BaseModelType.StableDiffusion2:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if "global_step" in checkpoint:
if checkpoint["global_step"] == 220000:
return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000:
return SchedulerPredictionType.VPrediction
if self.helper and self.checkpoint_path:
if helper_guess := self.helper(self.checkpoint_path):
return helper_guess
return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts
elif type == BaseModelType.StableDiffusion1:
if self.helper and self.checkpoint_path:
if helper_guess := self.helper(self.checkpoint_path):
return helper_guess
return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts
else:
return None
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase):
def get_format(self) -> str:
return "lycoris"
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
token_vector_length = lora_token_vector_length(checkpoint)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
class TextualInversionCheckpointProbe(CheckpointProbeBase):
def get_format(self) -> str:
return None
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
if "string_to_token" in checkpoint:
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
elif "emb_params" in checkpoint:
token_dim = checkpoint["emb_params"].shape[-1]
elif "clip_g" in checkpoint:
token_dim = checkpoint["clip_g"].shape[-1]
else:
token_dim = list(checkpoint.values())[0].shape[-1]
if token_dim == 768:
return BaseModelType.StableDiffusion1
elif token_dim == 1024:
return BaseModelType.StableDiffusion2
elif token_dim == 1280:
return BaseModelType.StableDiffusionXL
else:
return None
class ControlNetCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
for key_name in (
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
):
if key_name not in checkpoint:
continue
if checkpoint[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
elif checkpoint[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
elif self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
class IPAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used
self.model = model
self.folder_path = folder_path
def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal
def get_format(self) -> str:
return "diffusers"
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
if self.model:
unet_conf = self.model.unet.config
else:
with open(self.folder_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f"Unknown base model for {self.folder_path}")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
if self.model:
scheduler_conf = self.model.scheduler.config
else:
with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
scheduler_conf = json.load(file)
if scheduler_conf["prediction_type"] == "v_prediction":
return SchedulerPredictionType.VPrediction
elif scheduler_conf["prediction_type"] == "epsilon":
return SchedulerPredictionType.Epsilon
else:
return None
def get_variant_type(self) -> ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the
# "normal" variant type
try:
if self.model:
conf = self.model.unet.config
else:
config_file = self.folder_path / "unet" / "config.json"
with open(config_file, "r") as file:
conf = json.load(file)
in_channels = conf["in_channels"]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
except Exception:
pass
return ModelVariantType.Normal
class VaeFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
if self._config_looks_like_sdxl():
return BaseModelType.StableDiffusionXL
elif self._name_looks_like_sdxl():
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
return BaseModelType.StableDiffusionXL
else:
return BaseModelType.StableDiffusion1
def _config_looks_like_sdxl(self) -> bool:
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
config_file = self.folder_path / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file, "r") as file:
config = json.load(file)
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
def _name_looks_like_sdxl(self) -> bool:
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
def _guess_name(self) -> str:
name = self.folder_path.name
if name == "vae":
name = self.folder_path.parent.name
return name
class TextualInversionFolderProbe(FolderProbeBase):
def get_format(self) -> str:
return None
def get_base_type(self) -> BaseModelType:
path = self.folder_path / "learned_embeds.bin"
if not path.exists():
return None
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
class ONNXFolderProbe(FolderProbeBase):
def get_format(self) -> str:
return "onnx"
def get_base_type(self) -> BaseModelType:
return BaseModelType.StableDiffusion1
def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal
class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.folder_path / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file, "r") as file:
config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768
dimension = config["cross_attention_dim"]
base_model = (
BaseModelType.StableDiffusion1
if dimension == 768
else (
BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
)
)
if not base_model:
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
return base_model
class LoRAFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
model_file = None
for suffix in ["safetensors", "bin"]:
base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
if base_file.exists():
model_file = base_file
break
if not model_file:
raise InvalidModelException("Unknown LoRA format encountered")
return LoRACheckpointProbe(model_file, None).get_base_type()
class IPAdapterFolderProbe(FolderProbeBase):
def get_format(self) -> str:
return IPAdapterModelFormat.InvokeAI.value
def get_base_type(self) -> BaseModelType:
model_file = self.folder_path / "ip_adapter.bin"
if not model_file.exists():
raise InvalidModelException("Unknown IP-Adapter model format.")
state_dict = torch.load(model_file, map_location="cpu")
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
return BaseModelType.StableDiffusion1
elif cross_attention_dim == 1024:
return BaseModelType.StableDiffusion2
elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
class CLIPVisionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.folder_path / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file, "r") as file:
config = json.load(file)
adapter_type = config.get("adapter_type", None)
if adapter_type == "full_adapter_xl":
return BaseModelType.StableDiffusionXL
elif adapter_type == "full_adapter" or "light_adapter":
# I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
return BaseModelType.StableDiffusion1
else:
raise InvalidModelException(
f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})."
)
############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@ -1,112 +0,0 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class for recursive directory search for models.
"""
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Set, types
import invokeai.backend.util.logging as logger
class ModelSearch(ABC):
def __init__(self, directories: List[Path], logger: types.ModuleType = logger):
"""
Initialize a recursive model directory search.
:param directories: List of directory Paths to recurse through
:param logger: Logger to use
"""
self.directories = directories
self.logger = logger
self._items_scanned = 0
self._models_found = 0
self._scanned_dirs = set()
self._scanned_paths = set()
self._pruned_paths = set()
@abstractmethod
def on_search_started(self):
"""
Called before the scan starts.
"""
pass
@abstractmethod
def on_model_found(self, model: Path):
"""
Process a found model. Raise an exception if something goes wrong.
:param model: Model to process - could be a directory or checkpoint.
"""
pass
@abstractmethod
def on_search_completed(self):
"""
Perform some activity when the scan is completed. May use instance
variables, items_scanned and models_found
"""
pass
def search(self):
self.on_search_started()
for dir in self.directories:
self.walk_directory(dir)
self.on_search_completed()
def walk_directory(self, path: Path):
for root, dirs, files in os.walk(path, followlinks=True):
if str(Path(root).name).startswith("."):
self._pruned_paths.add(root)
if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
continue
self._items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in self._scanned_paths or path.parent in self._scanned_dirs:
self._scanned_dirs.add(path)
continue
if any(
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
}
):
try:
self.on_model_found(path)
self._models_found += 1
self._scanned_dirs.add(path)
except Exception as e:
self.logger.warning(f"Failed to process '{path}': {e}")
for f in files:
path = Path(root) / f
if path.parent in self._scanned_dirs:
continue
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
try:
self.on_model_found(path)
self._models_found += 1
except Exception as e:
self.logger.warning(f"Failed to process '{path}': {e}")
class FindModels(ModelSearch):
def on_search_started(self):
self.models_found: Set[Path] = set()
def on_model_found(self, model: Path):
self.models_found.add(model)
def on_search_completed(self):
pass
def list_models(self) -> List[Path]:
self.search()
return list(self.models_found)

View File

@ -1,167 +0,0 @@
import inspect
from enum import Enum
from typing import Literal, get_origin
from pydantic import BaseModel, ConfigDict, create_model
from .base import ( # noqa: F401
BaseModelType,
DuplicateModelException,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelError,
ModelNotFoundException,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SilenceWarnings,
SubModelType,
)
from .clip_vision import CLIPVisionModel
from .controlnet import ControlNetModel # TODO:
from .ip_adapter import IPAdapterModel
from .lora import LoRAModel
from .sdxl import StableDiffusionXLModel
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
from .t2i_adapter import T2IAdapterModel
from .textual_inversion import TextualInversionModel
from .vae import VaeModel
MODEL_CLASSES = {
BaseModelType.StableDiffusion1: {
ModelType.ONNX: ONNXStableDiffusion1Model,
ModelType.Main: StableDiffusion1Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.StableDiffusion2: {
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.Main: StableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.StableDiffusionXL: {
ModelType.Main: StableDiffusionXLModel,
ModelType.Vae: VaeModel,
# will not work until support written
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.StableDiffusionXLRefiner: {
ModelType.Main: StableDiffusionXLModel,
ModelType.Vae: VaeModel,
# will not work until support written
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.Any: {
ModelType.CLIPVision: CLIPVisionModel,
# The following model types are not expected to be used with BaseModelType.Any.
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.Main: StableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
# BaseModelType.Kandinsky2_1: {
# ModelType.Main: Kandinsky2_1Model,
# ModelType.MoVQ: MoVQModel,
# ModelType.Lora: LoRAModel,
# ModelType.ControlNet: ControlNetModel,
# ModelType.TextualInversion: TextualInversionModel,
# },
}
MODEL_CONFIGS = []
OPENAPI_MODEL_CONFIGS = []
class OpenAPIModelInfoBase(BaseModel):
model_name: str
base_model: BaseModelType
model_type: ModelType
model_config = ConfigDict(protected_namespaces=())
for _base_model, models in MODEL_CLASSES.items():
for model_type, model_class in models.items():
model_configs = set(model_class._get_configs().values())
model_configs.discard(None)
MODEL_CONFIGS.extend(model_configs)
# LS: sort to get the checkpoint configs first, which makes
# for a better template in the Swagger docs
for cfg in sorted(model_configs, key=lambda x: str(x)):
model_name, cfg_name = cfg.__qualname__.split(".")[-2:]
openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars():
continue
api_wrapper = create_model(
openapi_cfg_name,
__base__=(cfg, OpenAPIModelInfoBase),
model_type=(Literal[model_type], model_type), # type: ignore
)
vars()[openapi_cfg_name] = api_wrapper
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
def get_model_config_enums():
enums = []
for model_config in MODEL_CONFIGS:
if hasattr(inspect, "get_annotations"):
fields = inspect.get_annotations(model_config)
else:
fields = model_config.__annotations__
try:
field = fields["model_format"]
except Exception:
raise Exception("format field not found")
# model_format: None
# model_format: SomeModelFormat
# model_format: Literal[SomeModelFormat.Diffusers]
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
enums.append(field)
elif get_origin(field) is Literal and all(
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
):
enums.append(type(field.__args__[0]))
elif field is None:
pass
else:
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
return enums

View File

@ -1,681 +0,0 @@
import inspect
import json
import os
import sys
import typing
import warnings
from abc import ABCMeta, abstractmethod
from contextlib import suppress
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union
import numpy as np
import onnx
import safetensors.torch
import torch
from diffusers import ConfigMixin, DiffusionPipeline
from diffusers import logging as diffusers_logging
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from picklescan.scanner import scan_file_path
from pydantic import BaseModel, ConfigDict, Field
from transformers import logging as transformers_logging
class DuplicateModelException(Exception):
pass
class InvalidModelException(Exception):
pass
class ModelNotFoundException(Exception):
pass
class BaseModelType(str, Enum):
Any = "any" # For models that are not associated with any particular base model.
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
# Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
ONNX = "onnx"
Main = "main"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
class SubModelType(str, Enum):
UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
# MoVQ = "movq"
class ModelVariantType(str, Enum):
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class SchedulerPredictionType(str, Enum):
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelError(str, Enum):
NotFound = "not_found"
def model_config_json_schema_extra(schema: dict[str, Any]) -> None:
if "required" not in schema:
schema["required"] = []
schema["required"].append("model_type")
class ModelConfigBase(BaseModel):
path: str # or Path
description: Optional[str] = Field(None)
model_format: Optional[str] = Field(None)
error: Optional[ModelError] = Field(None)
model_config = ConfigDict(
use_enum_values=True, protected_namespaces=(), json_schema_extra=model_config_json_schema_extra
)
class EmptyConfigLoader(ConfigMixin):
@classmethod
def load_config(cls, *args, **kwargs):
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
T_co = TypeVar("T_co", covariant=True)
class classproperty(Generic[T_co]):
def __init__(self, fget: Callable[[Any], T_co]) -> None:
self.fget = fget
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
return self.fget(owner)
def __set__(self, instance: Optional[Any], value: Any) -> None:
raise AttributeError("cannot set attribute")
class ModelBase(metaclass=ABCMeta):
# model_path: str
# base_model: BaseModelType
# model_type: ModelType
def __init__(
self,
model_path: str,
base_model: BaseModelType,
model_type: ModelType,
):
self.model_path = model_path
self.base_model = base_model
self.model_type = model_type
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
if all(t is None for t in subtypes):
return None
elif any(t is None for t in subtypes):
raise Exception(f"Unsupported definition: {subtypes}")
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
else:
res_type = sys.modules["diffusers"]
res_type = res_type.pipelines
for subtype in subtypes:
res_type = getattr(res_type, subtype)
return res_type
@classmethod
def _get_configs(cls):
with suppress(Exception):
return cls.__configs
configs = {}
for name in dir(cls):
if name.startswith("__"):
continue
value = getattr(cls, name)
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
continue
if hasattr(inspect, "get_annotations"):
fields = inspect.get_annotations(value)
else:
fields = value.__annotations__
try:
field = fields["model_format"]
except Exception:
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
for model_format in field:
configs[model_format.value] = value
elif typing.get_origin(field) is Literal and all(
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
):
for model_format in field.__args__:
configs[model_format.value] = value
elif field is None:
configs[None] = value
else:
raise Exception(f"Unsupported format definition in {cls.__qualname__}")
cls.__configs = configs
return cls.__configs
@classmethod
def create_config(cls, **kwargs) -> ModelConfigBase:
if "model_format" not in kwargs:
raise Exception("Field 'model_format' not found in model config")
configs = cls._get_configs()
return configs[kwargs["model_format"]](**kwargs)
@classmethod
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
return cls.create_config(
path=path,
model_format=cls.detect_format(path),
)
@classmethod
@abstractmethod
def detect_format(cls, path: str) -> str:
raise NotImplementedError()
@classproperty
@abstractmethod
def save_to_config(cls) -> bool:
raise NotImplementedError()
@abstractmethod
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
raise NotImplementedError()
@abstractmethod
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> Any:
raise NotImplementedError()
class DiffusersModel(ModelBase):
# child_types: Dict[str, Type]
# child_sizes: Dict[str, int]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = {}
self.child_sizes: Dict[str, int] = {}
try:
config_data = DiffusionPipeline.load_config(self.model_path)
# config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except Exception:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
config_data.pop("_ignore_files", None)
# retrieve all folder_names that contain relevant files
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
for child_name in child_components:
child_type = self._hf_definition_to_type(config_data[child_name])
self.child_types[child_name] = child_type
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
# return pipeline in different function to pass more arguments
if child_type is None:
raise Exception("Child model type can't be null on diffusers model")
if child_type not in self.child_types:
return None # TODO: or raise
if torch_dtype == torch.float16:
variants = ["fp16", None]
else:
variants = [None, "fp16"]
# TODO: better error handling(differentiate not found from others)
for variant in variants:
try:
# TODO: set cache_dir to /dev/null to be sure that cache not used?
model = self.child_types[child_type].from_pretrained(
self.model_path,
subfolder=child_type.value,
torch_dtype=torch_dtype,
variant=variant,
local_files_only=True,
)
break
except Exception as e:
if not str(e).startswith("Error no file"):
print("====ERR LOAD====")
print(f"{variant}: {e}")
pass
else:
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
# calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
# def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, variant: Optional[str] = None):
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
# this can happen when, for example, the safety checker
# is not downloaded.
if not os.path.exists(model_path):
return 0
all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
other_files = set(all_files) - fp16_files - bit8_files
if variant is None:
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.endswith(index_postfix):
continue
try:
with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except Exception:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.endswith(file_format)]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = os.stat(os.path.join(model_path, model_file))
model_size += file_stats.st_size
return model_size
# raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
def calc_model_size_by_data(model) -> int:
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return _calc_model_by_data(model)
elif isinstance(model, IAIOnnxRuntimeModel):
return _calc_onnx_model_by_data(model)
else:
return 0
def _calc_pipeline_by_data(pipeline) -> int:
res = 0
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
return res
def _calc_model_by_data(model) -> int:
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
return mem
def _calc_onnx_model_by_data(model) -> int:
tensor_size = model.tensors.size() * 2 # The session doubles this
mem = tensor_size # in bytes
return mem
def _fast_safetensors_reader(path: str):
checkpoint = {}
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), "little")
definition_json = f.read(definition_len)
definition = json.loads(definition_json)
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
"pt",
"torch",
"pytorch",
}:
raise Exception("Supported only pytorch safetensors files")
definition.pop("__metadata__", None)
for key, info in definition.items():
dtype = {
"I8": torch.int8,
"I16": torch.int16,
"I32": torch.int32,
"I64": torch.int64,
"F16": torch.float16,
"F32": torch.float32,
"F64": torch.float64,
}[info["dtype"]]
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
return checkpoint
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"):
try:
checkpoint = _fast_safetensors_reader(path)
except Exception:
# TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu")
else:
if scan:
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
class SilenceWarnings(object):
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self):
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
def __exit__(self, type, value, traceback):
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter("default")
ONNX_WEIGHTS_NAME = "model.onnx"
class IAIOnnxRuntimeModel:
class _tensor_access:
def __init__(self, model):
self.model = model
self.indexes = {}
for idx, obj in enumerate(self.model.proto.graph.initializer):
self.indexes[obj.name] = idx
def __getitem__(self, key: str):
value = self.model.proto.graph.initializer[self.indexes[key]]
return numpy_helper.to_array(value)
def __setitem__(self, key: str, value: np.ndarray):
new_node = numpy_helper.from_array(value)
# set_external_data(new_node, location="in-memory-location")
new_node.name = key
# new_node.ClearField("raw_data")
del self.model.proto.graph.initializer[self.indexes[key]]
self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
# self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
# __delitem__
def __contains__(self, key: str):
return self.indexes[key] in self.model.proto.graph.initializer
def items(self):
raise NotImplementedError("tensor.items")
# return [(obj.name, obj) for obj in self.raw_proto]
def keys(self):
return self.indexes.keys()
def values(self):
raise NotImplementedError("tensor.values")
# return [obj for obj in self.raw_proto]
def size(self):
bytesSum = 0
for node in self.model.proto.graph.initializer:
bytesSum += sys.getsizeof(node.raw_data)
return bytesSum
class _access_helper:
def __init__(self, raw_proto):
self.indexes = {}
self.raw_proto = raw_proto
for idx, obj in enumerate(raw_proto):
self.indexes[obj.name] = idx
def __getitem__(self, key: str):
return self.raw_proto[self.indexes[key]]
def __setitem__(self, key: str, value):
index = self.indexes[key]
del self.raw_proto[index]
self.raw_proto.insert(index, value)
# __delitem__
def __contains__(self, key: str):
return key in self.indexes
def items(self):
return [(obj.name, obj) for obj in self.raw_proto]
def keys(self):
return self.indexes.keys()
def values(self):
return list(self.raw_proto)
def __init__(self, model_path: str, provider: Optional[str]):
self.path = model_path
self.session = None
self.provider = provider
"""
self.data_path = self.path + "_data"
if not os.path.exists(self.data_path):
print(f"Moving model tensors to separate file: {self.data_path}")
tmp_proto = onnx.load(model_path, load_external_data=True)
onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
del tmp_proto
gc.collect()
self.proto = onnx.load(model_path, load_external_data=False)
"""
self.proto = onnx.load(model_path, load_external_data=True)
# self.data = dict()
# for tensor in self.proto.graph.initializer:
# name = tensor.name
# if tensor.HasField("raw_data"):
# npt = numpy_helper.to_array(tensor)
# orv = OrtValue.ortvalue_from_numpy(npt)
# # self.data[name] = orv
# # set_external_data(tensor, location="in-memory-location")
# tensor.name = name
# # tensor.ClearField("raw_data")
self.nodes = self._access_helper(self.proto.graph.node)
# self.initializers = self._access_helper(self.proto.graph.initializer)
# print(self.proto.graph.input)
# print(self.proto.graph.initializer)
self.tensors = self._tensor_access(self)
# TODO: integrate with model manager/cache
def create_session(self, height=None, width=None):
if self.session is None or self.session_width != width or self.session_height != height:
# onnx.save(self.proto, "tmp.onnx")
# onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
# TODO: something to be able to get weight when they already moved outside of model proto
# (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
sess = SessionOptions()
# self._external_data.update(**external_data)
# sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
# sess.enable_profiling = True
# sess.intra_op_num_threads = 1
# sess.inter_op_num_threads = 1
# sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
# sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
# sess.enable_cpu_mem_arena = True
# sess.enable_mem_pattern = True
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
self.session_height = height
self.session_width = width
if height and width:
sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
sess.add_free_dimension_override_by_name("unet_time_batch", 1)
providers = []
if self.provider:
providers.append(self.provider)
else:
providers = get_available_providers()
if "TensorrtExecutionProvider" in providers:
providers.remove("TensorrtExecutionProvider")
try:
self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
except Exception as e:
raise e
# self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
# self.io_binding = self.session.io_binding()
def release_session(self):
self.session = None
import gc
gc.collect()
return
def __call__(self, **kwargs):
if self.session is None:
raise Exception("You should call create_session before running model")
inputs = {k: np.array(v) for k, v in kwargs.items()}
# output_names = self.session.get_outputs()
# for k in inputs:
# self.io_binding.bind_cpu_input(k, inputs[k])
# for name in output_names:
# self.io_binding.bind_output(name.name)
# self.session.run_with_iobinding(self.io_binding, None)
# return self.io_binding.copy_outputs_to_cpu()
return self.session.run(None, inputs)
# compatability with diffusers load code
@classmethod
def from_pretrained(
cls,
model_id: Union[str, Path],
subfolder: Union[str, Path] = None,
file_name: Optional[str] = None,
provider: Optional[str] = None,
sess_options: Optional["SessionOptions"] = None,
**kwargs,
):
file_name = file_name or ONNX_WEIGHTS_NAME
if os.path.isdir(model_id):
model_path = model_id
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
model_path = os.path.join(model_path, file_name)
else:
model_path = model_id
# load model from local directory
if not os.path.isfile(model_path):
raise Exception(f"Model not found: {model_path}")
# TODO: session options
return cls(model_path, provider=provider)

View File

@ -1,82 +0,0 @@
import os
from enum import Enum
from typing import Literal, Optional
import torch
from transformers import CLIPVisionModelWithProjection
from invokeai.backend.model_management.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelType,
SubModelType,
calc_model_size_by_data,
calc_model_size_by_fs,
classproperty,
)
class CLIPVisionModelFormat(str, Enum):
Diffusers = "diffusers"
class CLIPVisionModel(ModelBase):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[CLIPVisionModelFormat.Diffusers]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.CLIPVision
super().__init__(model_path, base_model, model_type)
self.model_size = calc_model_size_by_fs(self.model_path)
@classmethod
def detect_format(cls, path: str) -> str:
if not os.path.exists(path):
raise ModuleNotFoundError(f"No CLIP Vision model at path '{path}'.")
if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")):
return CLIPVisionModelFormat.Diffusers
raise InvalidModelException(f"Unexpected CLIP Vision model format: {path}")
@classproperty
def save_to_config(cls) -> bool:
return True
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
if child_type is not None:
raise ValueError("There are no child models in a CLIP Vision model.")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> CLIPVisionModelWithProjection:
if child_type is not None:
raise ValueError("There are no child models in a CLIP Vision model.")
model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype)
# Calculate a more accurate model size.
self.model_size = calc_model_size_by_data(model)
return model
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == CLIPVisionModelFormat.Diffusers:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")

View File

@ -1,162 +0,0 @@
import os
from enum import Enum
from pathlib import Path
from typing import Literal, Optional
import torch
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from .base import (
BaseModelType,
EmptyConfigLoader,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelNotFoundException,
ModelType,
SubModelType,
calc_model_size_by_data,
calc_model_size_by_fs,
classproperty,
)
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class ControlNetModel(ModelBase):
# model_class: Type
# model_size: int
class DiffusersConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Diffusers]
class CheckpointConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Checkpoint]
config: str
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
super().__init__(model_path, base_model, model_type)
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
# config = json.loads(os.path.join(self.model_path, "config.json"))
except Exception:
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
model_class_name = config.get("_class_name", None)
if model_class_name not in {"ControlNetModel"}:
raise Exception(f"Invalid ControlNet model! Unknown _class_name: {model_class_name}")
try:
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except Exception:
raise Exception("Invalid ControlNet model!")
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in controlnet model")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There are no child models in controlnet model")
model = None
for variant in ["fp16", None]:
try:
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
variant=variant,
)
break
except Exception:
pass
if not model:
raise ModelNotFoundException()
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException()
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):
return ControlNetModelFormat.Diffusers
if os.path.isfile(path):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]):
return ControlNetModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {path}")
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
return _convert_controlnet_ckpt_and_cache(
model_path=model_path,
model_config=config.config,
output_path=output_path,
base_model=base_model,
)
else:
return model_path
def _convert_controlnet_ckpt_and_cache(
model_path: str,
output_path: str,
base_model: BaseModelType,
model_config: str,
) -> str:
"""
Convert the controlnet from checkpoint format to diffusers format,
cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_path / model_path
output_path = Path(output_path)
logger.info(f"Converting {weights} to diffusers format")
# return cached version if it exists
if output_path.exists():
return output_path
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
convert_controlnet_to_diffusers(
weights,
output_path,
original_config_file=app_config.root_path / model_config,
image_size=512,
scan_needed=True,
from_safetensors=weights.suffix == ".safetensors",
)
return output_path

View File

@ -1,98 +0,0 @@
import os
import typing
from enum import Enum
from typing import Literal, Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, build_ip_adapter
from invokeai.backend.model_management.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelType,
SubModelType,
calc_model_size_by_fs,
classproperty,
)
class IPAdapterModelFormat(str, Enum):
# The custom IP-Adapter model format defined by InvokeAI.
InvokeAI = "invokeai"
class IPAdapterModel(ModelBase):
class InvokeAIConfig(ModelConfigBase):
model_format: Literal[IPAdapterModelFormat.InvokeAI]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.IPAdapter
super().__init__(model_path, base_model, model_type)
self.model_size = calc_model_size_by_fs(self.model_path)
@classmethod
def detect_format(cls, path: str) -> str:
if not os.path.exists(path):
raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")
if os.path.isdir(path):
model_file = os.path.join(path, "ip_adapter.bin")
image_encoder_config_file = os.path.join(path, "image_encoder.txt")
if os.path.exists(model_file) and os.path.exists(image_encoder_config_file):
return IPAdapterModelFormat.InvokeAI
raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
@classproperty
def save_to_config(cls) -> bool:
return True
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
return self.model_size
def get_model(
self,
torch_dtype: torch.dtype,
child_type: Optional[SubModelType] = None,
) -> typing.Union[IPAdapter, IPAdapterPlus]:
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
model = build_ip_adapter(
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"),
device=torch.device("cpu"),
dtype=torch_dtype,
)
self.model_size = model.calc_size()
return model
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == IPAdapterModelFormat.InvokeAI:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")
def get_ip_adapter_image_encoder_model_id(model_path: str):
"""Read the ID of the image encoder associated with the IP-Adapter at `model_path`."""
image_encoder_config_file = os.path.join(model_path, "image_encoder.txt")
with open(image_encoder_config_file, "r") as f:
image_encoder_model = f.readline().strip()
return image_encoder_model

View File

@ -1,696 +0,0 @@
import bisect
import os
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Union
import torch
from safetensors.torch import load_file
from .base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelNotFoundException,
ModelType,
SubModelType,
classproperty,
)
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
Diffusers = "diffusers"
class LoRAModel(ModelBase):
# model_size: int
class Config(ModelConfigBase):
model_format: LoRAModelFormat # TODO:
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in lora")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in lora")
model = LoRAModelRaw.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
base_model=self.base_model,
)
self.model_size = model.calc_size()
return model
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException()
if os.path.isdir(path):
for ext in ["safetensors", "bin"]:
if os.path.exists(os.path.join(path, f"pytorch_lora_weights.{ext}")):
return LoRAModelFormat.Diffusers
if os.path.isfile(path):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return LoRAModelFormat.LyCORIS
raise InvalidModelException(f"Not a valid model: {path}")
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder
path = Path(model_path, f"pytorch_lora_weights.{ext}")
if path.exists():
return path
else:
return model_path
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: dict,
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def get_weight(self, orig_weight: torch.Tensor):
raise NotImplementedError()
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid = values["lora_mid.weight"]
else:
self.mid = None
self.rank = self.down.shape[0]
def get_weight(self, orig_weight: torch.Tensor):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1 = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2 = values["hada_t2"]
else:
self.t2 = None
self.rank = self.w1_b.shape[0]
def get_weight(self, orig_weight: torch.Tensor):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1 = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2 = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2 = values["lokr_t2"]
else:
self.t2 = None
if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
else:
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class FullLayer(LoRALayerBase):
# weight: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.weight = values["diff"]
if len(values.keys()) > 1:
_keys = list(values.keys())
_keys.remove("diff")
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor):
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
class IA3Layer(LoRALayerBase):
# weight: torch.Tensor
# on_input: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.weight = values["weight"]
self.on_input = values["on_input"]
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor):
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
return orig_weight * weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
):
self._name = name
self.layers = layers
@property
def name(self):
return self._name
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
# TODO: try revert if exception?
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict):
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
return new_state_dict
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
):
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
name=file_path.stem, # TODO:
layers={},
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "weight" in values and "on_input" in values:
layer = IA3Layer(layer_key, values)
else:
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: dict):
state_dict_groupped = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map():
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
}

View File

@ -1,148 +0,0 @@
import json
import os
from enum import Enum
from pathlib import Path
from typing import Literal, Optional
from omegaconf import OmegaConf
from pydantic import Field
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management.detect_baked_in_vae import has_baked_in_sdxl_vae
from invokeai.backend.util.logging import InvokeAILogger
from .base import (
BaseModelType,
DiffusersModel,
InvalidModelException,
ModelConfigBase,
ModelType,
ModelVariantType,
classproperty,
read_checkpoint_meta,
)
class StableDiffusionXLModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusionXLModel(DiffusersModel):
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: str
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}
assert model_type == ModelType.Main
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusionXL,
model_type=ModelType.Main,
)
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == StableDiffusionXLModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get("state_dict", checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusionXLModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config["in_channels"]
else:
raise InvalidModelException(f"{path} is not a recognized Stable Diffusion diffusers model")
else:
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 5:
variant = ModelVariantType.Depth
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 2.* model format")
if ckpt_config_path is None:
# avoid circular import
from .stable_diffusion import _select_ckpt_config
ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant)
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return StableDiffusionXLModelFormat.Diffusers
else:
return StableDiffusionXLModelFormat.Checkpoint
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
# The convert script adapted from the diffusers package uses
# strings for the base model type. To avoid making too many
# source code changes, we simply translate here
if Path(output_path).exists():
return output_path
if isinstance(config, cls.CheckpointConfig):
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
# Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed,
# then we bake it into the converted model unless there is already
# a nonstandard VAE installed.
kwargs = {}
app_config = InvokeAIAppConfig.get_config()
vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix"
if vae_path.exists() and not has_baked_in_sdxl_vae(Path(model_path)):
InvokeAILogger.get_logger().warning("No baked-in VAE detected. Inserting sdxl-vae-fp16-fix.")
kwargs["vae_path"] = vae_path
return _convert_ckpt_and_cache(
version=base_model,
model_config=config,
output_path=output_path,
use_safetensors=True,
**kwargs,
)
else:
return model_path

View File

@ -1,337 +0,0 @@
import json
import os
from enum import Enum
from pathlib import Path
from typing import Literal, Optional, Union
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
from omegaconf import OmegaConf
from pydantic import Field
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from .base import (
BaseModelType,
DiffusersModel,
InvalidModelException,
ModelConfigBase,
ModelNotFoundException,
ModelType,
ModelVariantType,
SilenceWarnings,
classproperty,
read_checkpoint_meta,
)
from .sdxl import StableDiffusionXLModel
class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: str
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.Main
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.Main,
)
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == StableDiffusion1ModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get("state_dict", checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusion1ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config["in_channels"]
else:
raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")
else:
raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 1.* model format")
if ckpt_config_path is None:
ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion1, variant)
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
if not os.path.exists(model_path):
raise ModelNotFoundException()
if os.path.isdir(model_path):
if os.path.exists(os.path.join(model_path, "model_index.json")):
return StableDiffusion1ModelFormat.Diffusers
if os.path.isfile(model_path):
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return StableDiffusion1ModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {model_path}")
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion1,
model_config=config,
load_safety_checker=False,
output_path=output_path,
)
else:
return model_path
class StableDiffusion2ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: str
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion2
assert model_type == ModelType.Main
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion2,
model_type=ModelType.Main,
)
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == StableDiffusion2ModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get("state_dict", checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusion2ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config["in_channels"]
else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
else:
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 5:
variant = ModelVariantType.Depth
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 2.* model format")
if ckpt_config_path is None:
ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion2, variant)
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
if not os.path.exists(model_path):
raise ModelNotFoundException()
if os.path.isdir(model_path):
if os.path.exists(os.path.join(model_path, "model_index.json")):
return StableDiffusion2ModelFormat.Diffusers
if os.path.isfile(model_path):
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return StableDiffusion2ModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {model_path}")
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2,
model_config=config,
output_path=output_path,
)
else:
return model_path
# TODO: rework
# pass precision - currently defaulting to fp16
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[
StableDiffusion1Model.CheckpointConfig,
StableDiffusion2Model.CheckpointConfig,
StableDiffusionXLModel.CheckpointConfig,
],
output_path: str,
use_save_model: bool = False,
**kwargs,
) -> str:
"""
Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
weights = app_config.models_path / model_config.path
config_file = app_config.root_path / model_config.config
output_path = Path(output_path)
variant = model_config.variant
pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline
# return cached version if it exists
if output_path.exists():
return output_path
# to avoid circular import errors
from ...util.devices import choose_torch_device, torch_dtype
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
}
logger.info(f"Converting {weights} to diffusers format")
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
output_path,
model_type=model_base_to_model_type[version],
model_version=version,
model_variant=model_config.variant,
original_config_file=config_file,
extract_ema=True,
scan_needed=True,
pipeline_class=pipeline_class,
from_safetensors=weights.suffix == ".safetensors",
precision=torch_dtype(choose_torch_device()),
**kwargs,
)
return output_path
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ckpt_configs = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
ModelVariantType.Inpaint: None,
ModelVariantType.Depth: None,
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
ModelVariantType.Inpaint: None,
ModelVariantType.Depth: None,
},
}
app_config = InvokeAIAppConfig.get_config()
try:
config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
if config_path.is_relative_to(app_config.root_path):
config_path = config_path.relative_to(app_config.root_path)
return str(config_path)
except Exception:
return None

View File

@ -1,150 +0,0 @@
from enum import Enum
from typing import Literal
from diffusers import OnnxRuntimeModel
from .base import (
BaseModelType,
DiffusersModel,
IAIOnnxRuntimeModel,
ModelConfigBase,
ModelType,
ModelVariantType,
SchedulerPredictionType,
classproperty,
)
class StableDiffusionOnnxModelFormat(str, Enum):
Olive = "olive"
Onnx = "onnx"
class ONNXStableDiffusion1Model(DiffusersModel):
class Config(ModelConfigBase):
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.ONNX
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.ONNX,
)
for child_name, child_type in self.child_types.items():
if child_type is OnnxRuntimeModel:
self.child_types[child_name] = IAIOnnxRuntimeModel
# TODO: check that no optimum models provided
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
in_channels = 4 # TODO:
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 1.* model format")
return cls.create_config(
path=path,
model_format=model_format,
variant=variant,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
# TODO: Detect onnx vs olive
return StableDiffusionOnnxModelFormat.Onnx
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
return model_path
class ONNXStableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly
class Config(ModelConfigBase):
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion2
assert model_type == ModelType.ONNX
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion2,
model_type=ModelType.ONNX,
)
for child_name, child_type in self.child_types.items():
if child_type is OnnxRuntimeModel:
self.child_types[child_name] = IAIOnnxRuntimeModel
# TODO: check that no optimum models provided
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
in_channels = 4 # TODO:
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 5:
variant = ModelVariantType.Depth
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 2.* model format")
if variant == ModelVariantType.Normal:
prediction_type = SchedulerPredictionType.VPrediction
upcast_attention = True
else:
prediction_type = SchedulerPredictionType.Epsilon
upcast_attention = False
return cls.create_config(
path=path,
model_format=model_format,
variant=variant,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
# TODO: Detect onnx vs olive
return StableDiffusionOnnxModelFormat.Onnx
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
return model_path

View File

@ -1,102 +0,0 @@
import os
from enum import Enum
from typing import Literal, Optional
import torch
from diffusers import T2IAdapter
from invokeai.backend.model_management.models.base import (
BaseModelType,
EmptyConfigLoader,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelNotFoundException,
ModelType,
SubModelType,
calc_model_size_by_data,
calc_model_size_by_fs,
classproperty,
)
class T2IAdapterModelFormat(str, Enum):
Diffusers = "diffusers"
class T2IAdapterModel(ModelBase):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[T2IAdapterModelFormat.Diffusers]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.T2IAdapter
super().__init__(model_path, base_model, model_type)
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
model_class_name = config.get("_class_name", None)
if model_class_name not in {"T2IAdapter"}:
raise InvalidModelException(f"Invalid T2I-Adapter model. Unknown _class_name: '{model_class_name}'.")
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> T2IAdapter:
if child_type is not None:
raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.")
model = None
for variant in ["fp16", None]:
try:
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
variant=variant,
)
break
except Exception:
pass
if not model:
raise ModelNotFoundException()
# Calculate a more accurate size after loading the model into memory.
self.model_size = calc_model_size_by_data(model)
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException(f"Model not found at '{path}'.")
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):
return T2IAdapterModelFormat.Diffusers
raise InvalidModelException(f"Unsupported T2I-Adapter format: '{path}'.")
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == T2IAdapterModelFormat.Diffusers:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")

View File

@ -1,87 +0,0 @@
import os
from typing import Optional
import torch
# TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw
from .base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelNotFoundException,
ModelType,
SubModelType,
classproperty,
)
class TextualInversionModel(ModelBase):
# model_size: int
class Config(ModelConfigBase):
model_format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.TextualInversion
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
checkpoint_path = self.model_path
if os.path.isdir(checkpoint_path):
checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin")
if not os.path.exists(checkpoint_path):
raise ModelNotFoundException()
model = TextualInversionModelRaw.from_checkpoint(
file_path=checkpoint_path,
dtype=torch_dtype,
)
self.model_size = model.embedding.nelement() * model.embedding.element_size()
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException()
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "learned_embeds.bin")):
return None # diffusers-ti
if os.path.isfile(path):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]):
return None
raise InvalidModelException(f"Not a valid model: {path}")
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
return model_path

View File

@ -1,179 +0,0 @@
import os
from enum import Enum
from pathlib import Path
from typing import Optional
import safetensors
import torch
from omegaconf import OmegaConf
from invokeai.app.services.config import InvokeAIAppConfig
from .base import (
BaseModelType,
EmptyConfigLoader,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelNotFoundException,
ModelType,
ModelVariantType,
SubModelType,
calc_model_size_by_data,
calc_model_size_by_fs,
classproperty,
)
class VaeModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class VaeModel(ModelBase):
# vae_class: Type
# model_size: int
class Config(ModelConfigBase):
model_format: VaeModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae
super().__init__(model_path, base_model, model_type)
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
# config = json.loads(os.path.join(self.model_path, "config.json"))
except Exception:
raise Exception("Invalid vae model! (config.json not found or invalid)")
try:
vae_class_name = config.get("_class_name", "AutoencoderKL")
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except Exception:
raise Exception("Invalid vae model! (Unkown vae type)")
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in vae model")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in vae model")
model = self.vae_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException(f"Does not exist as local file: {path}")
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):
return VaeModelFormat.Diffusers
if os.path.isfile(path):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return VaeModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {path}")
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
return _convert_vae_ckpt_and_cache(
weights_path=model_path,
output_path=output_path,
base_model=base_model,
model_config=config,
)
else:
return model_path
# TODO: rework
def _convert_vae_ckpt_and_cache(
weights_path: str,
output_path: str,
base_model: BaseModelType,
model_config: ModelConfigBase,
) -> str:
"""
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
object, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
weights_path = app_config.root_dir / weights_path
output_path = Path(output_path)
"""
this size used only in when tiling enabled to separate input in tiles
sizes in configs from stable diffusion githubs(1 and 2) set to 256
on huggingface it:
1.5 - 512
1.5-inpainting - 256
2-inpainting - 512
2-depth - 256
2-base - 512
2 - 768
2.1-base - 768
2.1 - 768
"""
image_size = 512
# return cached version if it exists
if output_path.exists():
return output_path
if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
from .stable_diffusion import _select_ckpt_config
# all sd models use same vae settings
config_file = _select_ckpt_config(base_model, ModelVariantType.Normal)
else:
raise Exception(f"Vae conversion not supported for model type: {base_model}")
# this avoids circular import error
from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
if weights_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
else:
checkpoint = torch.load(weights_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
config = OmegaConf.load(app_config.root_path / config_file)
vae_model = convert_ldm_vae_to_diffusers(
checkpoint=checkpoint,
vae_config=config,
image_size=image_size,
)
vae_model.save_pretrained(output_path, safe_serialization=True)
return output_path

View File

@ -1,84 +0,0 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import Callable, List, Union
import torch.nn as nn
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
def _conv_forward_asymmetric(self, input, weight, bias):
"""
Patch for Conv2d._conv_forward that supports asymmetric padding
"""
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
return nn.functional.conv2d(
working,
weight,
bias,
self.stride,
nn.modules.utils._pair(0),
self.dilation,
self.groups,
)
@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
try:
# Hard coded to skip down block layers, allowing for seamless tiling at the expense of prompt adherence
skipped_layers = 1
for m_name, m in model.named_modules():
if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
continue
if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name:
# down_blocks.1.resnets.1.conv1
_, block_num, _, resnet_num, submodule_name = m_name.split(".")
block_num = int(block_num)
resnet_num = int(resnet_num)
if block_num >= len(model.down_blocks) - skipped_layers:
continue
# Skip the second resnet (could be configurable)
if resnet_num > 0:
continue
# Skip Conv2d layers (could be configurable)
if submodule_name == "conv2":
continue
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
yield
finally:
for module, orig_conv_forward in to_restore:
module._conv_forward = orig_conv_forward
if hasattr(module, "asymmetric_padding_mode"):
del module.asymmetric_padding_mode
if hasattr(module, "asymmetric_padding"):
del module.asymmetric_padding

View File

@ -1,79 +0,0 @@
# Copyright (c) 2023 The InvokeAI Development Team
"""Utilities used by the Model Manager"""
def lora_token_vector_length(checkpoint: dict) -> int:
"""
Given a checkpoint in memory, return the lora token vector length
:param checkpoint: The checkpoint
"""
def _get_shape_1(key: str, tensor, checkpoint) -> int:
lora_token_vector_length = None
if "." not in key:
return lora_token_vector_length # wrong key format
model_key, lora_key = key.split(".", 1)
# check lora/locon
if lora_key == "lora_down.weight":
lora_token_vector_length = tensor.shape[1]
# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
elif lora_key in ["hada_w1_b", "hada_w2_b"]:
lora_token_vector_length = tensor.shape[1]
# check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
elif "lokr_" in lora_key:
if model_key + ".lokr_w1" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
elif model_key + "lokr_w1_b" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
else:
return lora_token_vector_length # unknown format
if model_key + ".lokr_w2" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2"]
elif model_key + "lokr_w2_b" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
else:
return lora_token_vector_length # unknown format
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
elif lora_key == "diff":
lora_token_vector_length = tensor.shape[1]
# ia3 can be detected only by shape[0] in text encoder
elif lora_key == "weight" and "lora_unet_" not in model_key:
lora_token_vector_length = tensor.shape[0]
return lora_token_vector_length
lora_token_vector_length = None
lora_te1_length = None
lora_te2_length = None
for key, tensor in checkpoint.items():
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_unet_") and (
"time_emb_proj.lora_down" in key
): # recognizes format at https://civitai.com/models/224641
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_te") and "_self_attn_" in key:
tmp_length = _get_shape_1(key, tensor, checkpoint)
if key.startswith("lora_te_"):
lora_token_vector_length = tmp_length
elif key.startswith("lora_te1_"):
lora_te1_length = tmp_length
elif key.startswith("lora_te2_"):
lora_te2_length = tmp_length
if lora_te1_length is not None and lora_te2_length is not None:
lora_token_vector_length = lora_te1_length + lora_te2_length
if lora_token_vector_length is not None:
break
return lora_token_vector_length

View File

@ -1,5 +1,4 @@
"""Re-export frequently-used symbols from the Model Manager backend.""" """Re-export frequently-used symbols from the Model Manager backend."""
from .config import ( from .config import (
AnyModel, AnyModel,
AnyModelConfig, AnyModelConfig,
@ -33,3 +32,42 @@ __all__ = [
"SchedulerPredictionType", "SchedulerPredictionType",
"SubModelType", "SubModelType",
] ]
########## to help populate the openapi_schema with format enums for each config ###########
# This code is no longer necessary?
# leave it here just in case
#
# import inspect
# from enum import Enum
# from typing import Any, Iterable, Dict, get_args, Set
# def _expand(something: Any) -> Iterable[type]:
# if isinstance(something, type):
# yield something
# else:
# for x in get_args(something):
# for y in _expand(x):
# yield y
# def _find_format(cls: type) -> Iterable[Enum]:
# if hasattr(inspect, "get_annotations"):
# fields = inspect.get_annotations(cls)
# else:
# fields = cls.__annotations__
# if "format" in fields:
# for x in get_args(fields["format"]):
# yield x
# for parent_class in cls.__bases__:
# for x in _find_format(parent_class):
# yield x
# return None
# def get_model_config_formats() -> Dict[str, Set[Enum]]:
# result: Dict[str, Set[Enum]] = {}
# for model_config in _expand(AnyModelConfig):
# for field in _find_format(model_config):
# if field is None:
# continue
# if not result.get(model_config.__qualname__):
# result[model_config.__qualname__] = set()
# result[model_config.__qualname__].add(field)
# return result

View File

@ -6,12 +6,22 @@ from importlib import import_module
from pathlib import Path from pathlib import Path
from .convert_cache.convert_cache_default import ModelConvertCache from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import AnyModelLoader, LoadedModel from .load_base import LoadedModel, ModelLoaderBase
from .load_default import ModelLoader
from .model_cache.model_cache_default import ModelCache from .model_cache.model_cache_default import ModelCache
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
# This registers the subclasses that implement loaders of specific model types # This registers the subclasses that implement loaders of specific model types
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"] loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
for module in loaders: for module in loaders:
import_module(f"{__package__}.model_loaders.{module}") import_module(f"{__package__}.model_loaders.{module}")
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"] __all__ = [
"LoadedModel",
"ModelCache",
"ModelConvertCache",
"ModelLoaderBase",
"ModelLoader",
"ModelLoaderRegistryBase",
"ModelLoaderRegistry",
]

View File

@ -1,37 +1,22 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
""" """
Base class for model loading in InvokeAI. Base class for model loading in InvokeAI.
Use like this:
loader = AnyModelLoader(...)
loaded_model = loader.get_model('019ab39adfa1840455')
with loaded_model as model: # context manager moves model into VRAM
# do something with loaded_model
""" """
import hashlib
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type from typing import Any, Optional
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModel, AnyModel,
AnyModelConfig, AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
SubModelType, SubModelType,
VaeCheckpointConfig,
VaeDiffusersConfig,
) )
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.util.logging import InvokeAILogger
@dataclass @dataclass
@ -56,6 +41,14 @@ class LoadedModel:
return self._locker.model return self._locker.model
# TODO(MM2):
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
# know about. I think the problem may be related to this class being an ABC.
#
# For example, GenericDiffusersLoader defines `get_hf_load_class()`, and StableDiffusionDiffusersModel attempts to
# call it. However, the method is not defined in the ABC, so it is not guaranteed to be implemented.
class ModelLoaderBase(ABC): class ModelLoaderBase(ABC):
"""Abstract base class for loading models into RAM/VRAM.""" """Abstract base class for loading models into RAM/VRAM."""
@ -71,7 +64,7 @@ class ModelLoaderBase(ABC):
pass pass
@abstractmethod @abstractmethod
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel: def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
""" """
Return a model given its confguration. Return a model given its confguration.
@ -90,106 +83,3 @@ class ModelLoaderBase(ABC):
) -> int: ) -> int:
"""Return size in bytes of the model, calculated before loading.""" """Return size in bytes of the model, calculated before loading."""
pass pass
# TO DO: Better name?
class AnyModelLoader:
"""This class manages the model loaders and invokes the correct one to load a model of given base and type."""
# this tracks the loader subclasses
_registry: Dict[str, Type[ModelLoaderBase]] = {}
_logger: Logger = InvokeAILogger.get_logger()
def __init__(
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize AnyModelLoader with its dependencies."""
self._app_config = app_config
self._logger = logger
self._ram_cache = ram_cache
self._convert_cache = convert_cache
@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache associated used by the loaders."""
return self._ram_cache
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated used by the loaders."""
return self._convert_cache
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Return a model given its configuration.
:param key: model key, as known to the config backend
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
implementation, model_config, submodel_type = self.__class__.get_implementation(model_config, submodel_type)
return implementation(
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
convert_cache=self._convert_cache,
).load_model(model_config, submodel_type)
@staticmethod
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
return "-".join([base.value, type.value, format.value])
@classmethod
def get_implementation(
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation:
raise NotImplementedError(
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
)
return implementation, conf2, submodel_type
@classmethod
def _handle_subtype_overrides(
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
model_path = Path(config.vae)
config_class = (
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
)
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
submodel_type = None
else:
new_conf = config
return new_conf, submodel_type
@classmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
"""Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}")
key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
)
cls._registry[key] = subclass
return subclass
return decorator

View File

@ -1,13 +1,9 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Default implementation of model loading in InvokeAI.""" """Default implementation of model loading in InvokeAI."""
import sys
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Tuple from typing import Optional, Tuple
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
@ -25,17 +21,6 @@ from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_
from invokeai.backend.util.devices import choose_torch_device, torch_dtype from invokeai.backend.util.devices import choose_torch_device, torch_dtype
class ConfigLoader(ConfigMixin):
"""Subclass of ConfigMixin for loading diffusers configuration files."""
@classmethod
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Load a diffusrs ConfigMixin configuration."""
cls.config_name = kwargs.pop("config_name")
# Diffusers doesn't provide typing info
return super().load_config(*args, **kwargs) # type: ignore
# TO DO: The loader is not thread safe! # TO DO: The loader is not thread safe!
class ModelLoader(ModelLoaderBase): class ModelLoader(ModelLoaderBase):
"""Default implementation of ModelLoaderBase.""" """Default implementation of ModelLoaderBase."""
@ -137,43 +122,6 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if hasattr(config, "repo_variant") else None, variant=config.repo_variant if hasattr(config, "repo_variant") else None,
) )
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
return ConfigLoader.load_config(model_path, config_name=config_name)
# TO DO: Add exception handling
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
if module in ["diffusers", "transformers"]:
res_type = sys.modules[module]
else:
res_type = sys.modules["diffusers"].pipelines
result: ModelMixin = getattr(res_type, class_name)
return result
# TO DO: Add exception handling
def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
if submodel_type:
try:
config = self._load_diffusers_config(model_path, config_name="model_index.json")
module, class_name = config[submodel_type.value]
return self._hf_definition_to_type(module=module, class_name=class_name)
except KeyError as e:
raise InvalidModelConfigException(
f'The "{submodel_type}" submodel is not available for this model.'
) from e
else:
try:
config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config.get("_class_name", None)
if class_name:
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
if config.get("model_type", None) == "clip_vision_model":
class_name = config.get("architectures")[0]
return self._hf_definition_to_type(module="transformers", class_name=class_name)
if not class_name:
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
# This needs to be implemented in subclasses that handle checkpoints # This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
raise NotImplementedError raise NotImplementedError

View File

@ -55,7 +55,7 @@ class MemorySnapshot:
vram = None vram = None
try: try:
malloc_info = LibcUtil().mallinfo2() # type: ignore malloc_info = LibcUtil().mallinfo2()
except (OSError, AttributeError): except (OSError, AttributeError):
# OSError: This is expected in environments that do not have the 'libc.so.6' shared library. # OSError: This is expected in environments that do not have the 'libc.so.6' shared library.
# AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33) # AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33)

View File

@ -0,0 +1,122 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
"""
This module implements a system in which model loaders register the
type, base and format of models that they know how to load.
Use like this:
cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model = cls(
app_config=app_config,
logger=logger,
ram_cache=ram_cache,
convert_cache=convert_cache
).load_model(model_config, submodel_type)
"""
import hashlib
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type
from ..config import (
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
SubModelType,
VaeCheckpointConfig,
VaeDiffusersConfig,
)
from . import ModelLoaderBase
class ModelLoaderRegistryBase(ABC):
"""This class allows model loaders to register their type, base and format."""
@classmethod
@abstractmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
"""Define a decorator which registers the subclass of loader."""
@classmethod
@abstractmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""
Get subclass of ModelLoaderBase registered to handle base and type.
Parameters:
:param config: Model configuration record, as returned by ModelRecordService
:param submodel_type: Submodel to fetch (main models only)
:return: tuple(loader_class, model_config, submodel_type)
Note that the returned model config may be different from one what passed
in, in the event that a submodel type is provided.
"""
class ModelLoaderRegistry:
"""
This class allows model loaders to register their type, base and format.
"""
_registry: Dict[str, Type[ModelLoaderBase]] = {}
@classmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
"""Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
)
cls._registry[key] = subclass
return subclass
return decorator
@classmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation:
raise NotImplementedError(
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
)
return implementation, conf2, submodel_type
@classmethod
def _handle_subtype_overrides(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
model_path = Path(config.vae)
config_class = (
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
)
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
submodel_type = None
else:
new_conf = config
return new_conf, submodel_type
@staticmethod
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
return "-".join([base.value, type.value, format.value])

View File

@ -13,13 +13,13 @@ from invokeai.backend.model_manager import (
ModelType, ModelType,
) )
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader from .generic_diffusers import GenericDiffusersLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlnetLoader(GenericDiffusersLoader): class ControlnetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models.""" """Class to load ControlNet models."""

View File

@ -1,24 +1,27 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for simple diffusers model loading in InvokeAI.""" """Class for simple diffusers model loading in InvokeAI."""
import sys
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Any, Dict, Optional
from diffusers import ConfigMixin, ModelMixin
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModel, AnyModel,
BaseModelType, BaseModelType,
InvalidModelConfigException,
ModelFormat, ModelFormat,
ModelRepoVariant, ModelRepoVariant,
ModelType, ModelType,
SubModelType, SubModelType,
) )
from ..load_base import AnyModelLoader from .. import ModelLoader, ModelLoaderRegistry
from ..load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
class GenericDiffusersLoader(ModelLoader): class GenericDiffusersLoader(ModelLoader):
"""Class to load simple diffusers models.""" """Class to load simple diffusers models."""
@ -28,9 +31,60 @@ class GenericDiffusersLoader(ModelLoader):
model_variant: Optional[ModelRepoVariant] = None, model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
) -> AnyModel: ) -> AnyModel:
model_class = self._get_hf_load_class(model_path) model_class = self.get_hf_load_class(model_path)
if submodel_type is not None: if submodel_type is not None:
raise Exception(f"There are no submodels in models of type {model_class}") raise Exception(f"There are no submodels in models of type {model_class}")
variant = model_variant.value if model_variant else None variant = model_variant.value if model_variant else None
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore
return result return result
# TO DO: Add exception handling
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
if submodel_type:
try:
config = self._load_diffusers_config(model_path, config_name="model_index.json")
module, class_name = config[submodel_type.value]
result = self._hf_definition_to_type(module=module, class_name=class_name)
except KeyError as e:
raise InvalidModelConfigException(
f'The "{submodel_type}" submodel is not available for this model.'
) from e
else:
try:
config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config.get("_class_name", None)
if class_name:
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
if config.get("model_type", None) == "clip_vision_model":
class_name = config.get("architectures")
assert class_name is not None
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
if not class_name:
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
return result
# TO DO: Add exception handling
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
if module in ["diffusers", "transformers"]:
res_type = sys.modules[module]
else:
res_type = sys.modules["diffusers"].pipelines
result: ModelMixin = getattr(res_type, class_name)
return result
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
return ConfigLoader.load_config(model_path, config_name=config_name)
class ConfigLoader(ConfigMixin):
"""Subclass of ConfigMixin for loading diffusers configuration files."""
@classmethod
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Load a diffusrs ConfigMixin configuration."""
cls.config_name = kwargs.pop("config_name")
# Diffusers doesn't provide typing info
return super().load_config(*args, **kwargs) # type: ignore

View File

@ -15,11 +15,10 @@ from invokeai.backend.model_manager import (
ModelType, ModelType,
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.load.load_base import AnyModelLoader from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
class IPAdapterInvokeAILoader(ModelLoader): class IPAdapterInvokeAILoader(ModelLoader):
"""Class to load IP Adapter diffusers models.""" """Class to load IP Adapter diffusers models."""

View File

@ -18,13 +18,13 @@ from invokeai.backend.model_manager import (
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from .. import ModelLoader, ModelLoaderRegistry
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
class LoraLoader(ModelLoader): class LoraLoader(ModelLoader):
"""Class to load LoRA models.""" """Class to load LoRA models."""

View File

@ -13,13 +13,14 @@ from invokeai.backend.model_manager import (
ModelType, ModelType,
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
class OnnyxDiffusersModel(ModelLoader): class OnnyxDiffusersModel(GenericDiffusersLoader):
"""Class to load onnx models.""" """Class to load onnx models."""
def _load_model( def _load_model(
@ -30,7 +31,7 @@ class OnnyxDiffusersModel(ModelLoader):
) -> AnyModel: ) -> AnyModel:
if not submodel_type is not None: if not submodel_type is not None:
raise Exception("A submodel type must be provided when loading onnx pipelines.") raise Exception("A submodel type must be provided when loading onnx pipelines.")
load_class = self._get_hf_load_class(model_path, submodel_type) load_class = self.get_hf_load_class(model_path, submodel_type)
variant = model_variant.value if model_variant else None variant = model_variant.value if model_variant else None
model_path = model_path / submodel_type.value model_path = model_path / submodel_type.value
result: AnyModel = load_class.from_pretrained( result: AnyModel = load_class.from_pretrained(

View File

@ -19,13 +19,14 @@ from invokeai.backend.model_manager import (
) )
from invokeai.backend.model_manager.config import MainCheckpointConfig from invokeai.backend.model_manager.config import MainCheckpointConfig
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
class StableDiffusionDiffusersModel(ModelLoader): class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models.""" """Class to load main models."""
model_base_to_model_type = { model_base_to_model_type = {
@ -43,7 +44,7 @@ class StableDiffusionDiffusersModel(ModelLoader):
) -> AnyModel: ) -> AnyModel:
if not submodel_type is not None: if not submodel_type is not None:
raise Exception("A submodel type must be provided when loading main pipelines.") raise Exception("A submodel type must be provided when loading main pipelines.")
load_class = self._get_hf_load_class(model_path, submodel_type) load_class = self.get_hf_load_class(model_path, submodel_type)
variant = model_variant.value if model_variant else None variant = model_variant.value if model_variant else None
model_path = model_path / submodel_type.value model_path = model_path / submodel_type.value
result: AnyModel = load_class.from_pretrained( result: AnyModel = load_class.from_pretrained(

View File

@ -5,7 +5,6 @@
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple from typing import Optional, Tuple
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModel, AnyModel,
AnyModelConfig, AnyModelConfig,
@ -15,12 +14,15 @@ from invokeai.backend.model_manager import (
ModelType, ModelType,
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.load.load_base import AnyModelLoader from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.model_manager.load.load_default import ModelLoader
from .. import ModelLoader, ModelLoaderRegistry
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder) @ModelLoaderRegistry.register(
base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder
)
class TextualInversionLoader(ModelLoader): class TextualInversionLoader(ModelLoader):
"""Class to load TI models.""" """Class to load TI models."""

View File

@ -14,14 +14,14 @@ from invokeai.backend.model_manager import (
ModelType, ModelType,
) )
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader from .generic_diffusers import GenericDiffusersLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
class VaeLoader(GenericDiffusersLoader): class VaeLoader(GenericDiffusersLoader):
"""Class to load VAE models.""" """Class to load VAE models."""

View File

@ -1,16 +1,16 @@
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Generator
import torch import torch
def _no_op(*args, **kwargs): def _no_op(*args: Any, **kwargs: Any) -> None:
pass pass
@contextmanager @contextmanager
def skip_torch_weight_init(): def skip_torch_weight_init() -> Generator[None, None, None]:
"""A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) """Monkey patch several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) to skip weight initialization.
to skip weight initialization.
By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
@ -18,13 +18,14 @@ def skip_torch_weight_init():
monkey-patches common torch layers to skip the weight initialization step. monkey-patches common torch layers to skip the weight initialization step.
""" """
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding] torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
saved_functions = [m.reset_parameters for m in torch_modules] saved_functions = [hasattr(m, "reset_parameters") and m.reset_parameters for m in torch_modules]
try: try:
for torch_module in torch_modules: for torch_module in torch_modules:
assert hasattr(torch_module, "reset_parameters")
torch_module.reset_parameters = _no_op torch_module.reset_parameters = _no_op
yield None yield None
finally: finally:
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True): for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
assert hasattr(torch_module, "reset_parameters")
torch_module.reset_parameters = saved_function torch_module.reset_parameters = saved_function

View File

@ -13,7 +13,7 @@ from typing import Any, List, Optional, Set
import torch import torch
from diffusers import AutoPipelineForText2Image from diffusers import AutoPipelineForText2Image
from diffusers import logging as dlogging from diffusers.utils import logging as dlogging
from invokeai.app.services.model_install import ModelInstallServiceBase from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.backend.util.devices import choose_torch_device, torch_dtype from invokeai.backend.util.devices import choose_torch_device, torch_dtype
@ -76,7 +76,7 @@ class ModelMerger(object):
custom_pipeline="checkpoint_merger", custom_pipeline="checkpoint_merger",
torch_dtype=dtype, torch_dtype=dtype,
variant=variant, variant=variant,
) ) # type: ignore
merged_pipe = pipe.merge( merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_paths, pretrained_model_name_or_path_list=model_paths,
alpha=alpha, alpha=alpha,

View File

@ -54,8 +54,8 @@ class LicenseRestrictions(BaseModel):
AllowDifferentLicense: bool = Field( AllowDifferentLicense: bool = Field(
description="if true, derivatives of this model be redistributed under a different license", default=False description="if true, derivatives of this model be redistributed under a different license", default=False
) )
AllowCommercialUse: CommercialUsage = Field( AllowCommercialUse: Optional[CommercialUsage] = Field(
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default=None
) )
@ -139,6 +139,9 @@ class CivitaiMetadata(ModelMetadataWithFiles):
@property @property
def allow_commercial_use(self) -> bool: def allow_commercial_use(self) -> bool:
"""Return True if commercial use is allowed.""" """Return True if commercial use is allowed."""
if self.restrictions.AllowCommercialUse is None:
return False
else:
return self.restrictions.AllowCommercialUse != CommercialUsage("None") return self.restrictions.AllowCommercialUse != CommercialUsage("None")
@property @property

View File

@ -8,7 +8,6 @@ import torch
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.util.util import SilenceWarnings from invokeai.backend.util.util import SilenceWarnings
from .config import ( from .config import (
@ -23,6 +22,7 @@ from .config import (
SchedulerPredictionType, SchedulerPredictionType,
) )
from .hash import FastModelHash from .hash import FastModelHash
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
CkptType = Dict[str, Any] CkptType = Dict[str, Any]
@ -53,6 +53,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
}, },
} }
class ProbeBase(object): class ProbeBase(object):
"""Base class for probes.""" """Base class for probes."""

View File

@ -116,9 +116,9 @@ class ModelSearch(ModelSearchBase):
# returns all models that have 'anime' in the path # returns all models that have 'anime' in the path
""" """
models_found: Optional[Set[Path]] = Field(default=None) models_found: Set[Path] = Field(default_factory=set)
scanned_dirs: Optional[Set[Path]] = Field(default=None) scanned_dirs: Set[Path] = Field(default_factory=set)
pruned_paths: Optional[Set[Path]] = Field(default=None) pruned_paths: Set[Path] = Field(default_factory=set)
def search_started(self) -> None: def search_started(self) -> None:
self.models_found = set() self.models_found = set()

View File

@ -35,7 +35,7 @@ class Struct_mallinfo2(ctypes.Structure):
("keepcost", ctypes.c_size_t), ("keepcost", ctypes.c_size_t),
] ]
def __str__(self): def __str__(self) -> str:
s = "" s = ""
s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n" s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n"
s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n" s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n"
@ -62,7 +62,7 @@ class LibcUtil:
TODO: Improve cross-OS compatibility of this class. TODO: Improve cross-OS compatibility of this class.
""" """
def __init__(self): def __init__(self) -> None:
self._libc = ctypes.cdll.LoadLibrary("libc.so.6") self._libc = ctypes.cdll.LoadLibrary("libc.so.6")
def mallinfo2(self) -> Struct_mallinfo2: def mallinfo2(self) -> Struct_mallinfo2:
@ -72,4 +72,5 @@ class LibcUtil:
""" """
mallinfo2 = self._libc.mallinfo2 mallinfo2 = self._libc.mallinfo2
mallinfo2.restype = Struct_mallinfo2 mallinfo2.restype = Struct_mallinfo2
return mallinfo2() result: Struct_mallinfo2 = mallinfo2()
return result

View File

@ -1,12 +1,15 @@
"""Utilities for parsing model files, used mostly by probe.py""" """Utilities for parsing model files, used mostly by probe.py"""
import json import json
import torch
from typing import Union
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union
import safetensors
import torch
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
def _fast_safetensors_reader(path: str):
def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
checkpoint = {} checkpoint = {}
device = torch.device("meta") device = torch.device("meta")
with open(path, "rb") as f: with open(path, "rb") as f:
@ -37,10 +40,12 @@ def _fast_safetensors_reader(path: str):
return checkpoint return checkpoint
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str, torch.Tensor]:
if str(path).endswith(".safetensors"): if str(path).endswith(".safetensors"):
try: try:
checkpoint = _fast_safetensors_reader(path) path_str = path.as_posix() if isinstance(path, Path) else path
checkpoint = _fast_safetensors_reader(path_str)
except Exception: except Exception:
# TODO: create issue for support "meta"? # TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu") checkpoint = safetensors.torch.load_file(path, device="cpu")
@ -52,14 +57,15 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
checkpoint = torch.load(path, map_location=torch.device("meta")) checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint return checkpoint
def lora_token_vector_length(checkpoint: dict) -> int:
def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
""" """
Given a checkpoint in memory, return the lora token vector length Given a checkpoint in memory, return the lora token vector length
:param checkpoint: The checkpoint :param checkpoint: The checkpoint
""" """
def _get_shape_1(key: str, tensor, checkpoint) -> int: def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
lora_token_vector_length = None lora_token_vector_length = None
if "." not in key: if "." not in key:

View File

@ -8,6 +8,7 @@ import numpy as np
import onnx import onnx
from onnx import numpy_helper from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from ..raw_model import RawModel from ..raw_model import RawModel
ONNX_WEIGHTS_NAME = "model.onnx" ONNX_WEIGHTS_NAME = "model.onnx"
@ -15,7 +16,7 @@ ONNX_WEIGHTS_NAME = "model.onnx"
# NOTE FROM LS: This was copied from Stalker's original implementation. # NOTE FROM LS: This was copied from Stalker's original implementation.
# I have not yet gone through and fixed all the type hints # I have not yet gone through and fixed all the type hints
class IAIOnnxRuntimeModel: class IAIOnnxRuntimeModel(RawModel):
class _tensor_access: class _tensor_access:
def __init__(self, model): # type: ignore def __init__(self, model): # type: ignore
self.model = model self.model = model

View File

@ -10,5 +10,6 @@ The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
that adds additional methods and attributes. that adds additional methods and attributes.
""" """
class RawModel: class RawModel:
"""Base class for 'Raw' model wrappers.""" """Base class for 'Raw' model wrappers."""

View File

@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from typing import List, Union from typing import Callable, List, Union
import torch.nn as nn import torch.nn as nn
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
def _conv_forward_asymmetric(self, input, weight, bias): def _conv_forward_asymmetric(self, input, weight, bias):
@ -26,51 +27,32 @@ def _conv_forward_asymmetric(self, input, weight, bias):
@contextmanager @contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
try: try:
to_restore = [] # Hard coded to skip down block layers, allowing for seamless tiling at the expense of prompt adherence
skipped_layers = 1
for m_name, m in model.named_modules(): for m_name, m in model.named_modules():
if isinstance(model, UNet2DConditionModel): if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
if ".attentions." in m_name:
continue continue
if ".resnets." in m_name: if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name:
if ".conv2" in m_name: # down_blocks.1.resnets.1.conv1
continue _, block_num, _, resnet_num, submodule_name = m_name.split(".")
if ".conv_shortcut" in m_name: block_num = int(block_num)
resnet_num = int(resnet_num)
if block_num >= len(model.down_blocks) - skipped_layers:
continue continue
""" # Skip the second resnet (could be configurable)
if isinstance(model, UNet2DConditionModel): if resnet_num > 0:
if False and ".upsamplers." in m_name:
continue continue
if False and ".downsamplers." in m_name: # Skip Conv2d layers (could be configurable)
if submodule_name == "conv2":
continue continue
if True and ".resnets." in m_name:
if True and ".conv1" in m_name:
if False and "down_blocks" in m_name:
continue
if False and "mid_block" in m_name:
continue
if False and "up_blocks" in m_name:
continue
if True and ".conv2" in m_name:
continue
if True and ".conv_shortcut" in m_name:
continue
if True and ".attentions." in m_name:
continue
if False and m_name in ["conv_in", "conv_out"]:
continue
"""
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.asymmetric_padding_mode = {} m.asymmetric_padding_mode = {}
m.asymmetric_padding = {} m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"

View File

@ -8,8 +8,10 @@ from compel.embeddings_provider import BaseTextualInversionManager
from safetensors.torch import load_file from safetensors.torch import load_file
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
from typing_extensions import Self from typing_extensions import Self
from .raw_model import RawModel from .raw_model import RawModel
class TextualInversionModelRaw(RawModel): class TextualInversionModelRaw(RawModel):
embedding: torch.Tensor # [n, 768]|[n, 1280] embedding: torch.Tensor # [n, 768]|[n, 1280]
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models

View File

@ -42,7 +42,7 @@ def install_and_load_model(
# If the requested model is already installed, return its LoadedModel # If the requested model is already installed, return its LoadedModel
with contextlib.suppress(UnknownModelException): with contextlib.suppress(UnknownModelException):
# TODO: Replace with wrapper call # TODO: Replace with wrapper call
loaded_model: LoadedModel = model_manager.load.load_model_by_attr( loaded_model: LoadedModel = model_manager.load_model_by_attr(
model_name=model_name, base_model=base_model, model_type=model_type model_name=model_name, base_model=base_model, model_type=model_type
) )
return loaded_model return loaded_model
@ -53,7 +53,7 @@ def install_and_load_model(
assert job.complete assert job.complete
try: try:
loaded_model = model_manager.load.load_model_by_config(job.config_out) loaded_model = model_manager.load_model_by_config(job.config_out)
return loaded_model return loaded_model
except UnknownModelException as e: except UnknownModelException as e:
raise Exception( raise Exception(

View File

@ -4,18 +4,27 @@ Test model loading
from pathlib import Path from pathlib import Path
from invokeai.app.services.model_install import ModelInstallServiceBase from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.model_load import ModelLoadServiceBase
from invokeai.backend.textual_inversion import TextualInversionModelRaw from invokeai.backend.textual_inversion import TextualInversionModelRaw
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase, embedding_file: Path):
store = mm2_installer.record_store def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Path):
store = mm2_model_manager.store
matches = store.search_by_attr(model_name="test_embedding") matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0 assert len(matches) == 0
key = mm2_installer.register_path(embedding_file) key = mm2_model_manager.install.register_path(embedding_file)
loaded_model = mm2_loader.load_model_by_config(store.get_model(key)) loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key))
assert loaded_model is not None assert loaded_model is not None
assert loaded_model.config.key == key assert loaded_model.config.key == key
with loaded_model as model: with loaded_model as model:
assert isinstance(model, TextualInversionModelRaw) assert isinstance(model, TextualInversionModelRaw)
loaded_model_2 = mm2_model_manager.load_model_by_key(key)
assert loaded_model.config.key == loaded_model_2.config.key
loaded_model_3 = mm2_model_manager.load_model_by_attr(
model_name=loaded_model.config.name,
model_type=loaded_model.config.type,
base_model=loaded_model.config.base,
)
assert loaded_model.config.key == loaded_model_3.config.key

View File

@ -6,17 +6,17 @@ from pathlib import Path
from typing import Any, Dict, List from typing import Any, Dict, List
import pytest import pytest
from pytest import FixtureRequest
from pydantic import BaseModel from pydantic import BaseModel
from pytest import FixtureRequest
from requests.sessions import Session from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase, DownloadQueueService from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_manager import ModelManagerServiceBase, ModelManagerService
from invokeai.app.services.model_load import ModelLoadServiceBase, ModelLoadService
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
@ -95,9 +95,7 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
@pytest.fixture @pytest.fixture
def mm2_download_queue(mm2_session: Session, def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:
request: FixtureRequest
) -> DownloadQueueServiceBase:
download_queue = DownloadQueueService(requests_session=mm2_session) download_queue = DownloadQueueService(requests_session=mm2_session)
download_queue.start() download_queue.start()
@ -107,26 +105,30 @@ def mm2_download_queue(mm2_session: Session,
request.addfinalizer(stop_queue) request.addfinalizer(stop_queue)
return download_queue return download_queue
@pytest.fixture @pytest.fixture
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase: def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
return mm2_record_store.metadata_store return mm2_record_store.metadata_store
@pytest.fixture @pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase: def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
ram_cache = ModelCache( ram_cache = ModelCache(
logger=InvokeAILogger.get_logger(), logger=InvokeAILogger.get_logger(),
max_cache_size=mm2_app_config.ram_cache_size, max_cache_size=mm2_app_config.ram_cache_size,
max_vram_cache_size=mm2_app_config.vram_cache_size max_vram_cache_size=mm2_app_config.vram_cache_size,
) )
convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path) convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
return ModelLoadService(app_config=mm2_app_config, return ModelLoadService(
record_store=mm2_record_store, app_config=mm2_app_config,
ram_cache=ram_cache, ram_cache=ram_cache,
convert_cache=convert_cache, convert_cache=convert_cache,
) )
@pytest.fixture @pytest.fixture
def mm2_installer(mm2_app_config: InvokeAIAppConfig, def mm2_installer(
mm2_app_config: InvokeAIAppConfig,
mm2_download_queue: DownloadQueueServiceBase, mm2_download_queue: DownloadQueueServiceBase,
mm2_session: Session, mm2_session: Session,
request: FixtureRequest, request: FixtureRequest,
@ -213,15 +215,13 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas
store.add_model("test_config_5", raw5) store.add_model("test_config_5", raw5)
return store return store
@pytest.fixture @pytest.fixture
def mm2_model_manager(mm2_record_store: ModelRecordServiceBase, def mm2_model_manager(
mm2_installer: ModelInstallServiceBase, mm2_record_store: ModelRecordServiceBase, mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase
mm2_loader: ModelLoadServiceBase) -> ModelManagerServiceBase: ) -> ModelManagerServiceBase:
return ModelManagerService( return ModelManagerService(store=mm2_record_store, install=mm2_installer, load=mm2_loader)
store=mm2_record_store,
install=mm2_installer,
load=mm2_loader
)
@pytest.fixture @pytest.fixture
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
@ -306,5 +306,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
), ),
) )
return sess return sess

View File

@ -5,8 +5,8 @@
import pytest import pytest
import torch import torch
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.lora import LoRALayer, LoRAModelRaw from invokeai.backend.lora import LoRALayer, LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -1,7 +1,8 @@
import pytest import pytest
from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2
def test_memory_snapshot_capture(): def test_memory_snapshot_capture():
"""Smoke test of MemorySnapshot.capture().""" """Smoke test of MemorySnapshot.capture()."""