add router API support for model manager heuristic_import()

Lincoln Stein 2023-06-23 16:35:39 -04:00
parent 54b74427f4
commit 466ec3ab5e
6 changed files with 74 additions and 19 deletions

@@ -1,13 +1,13 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
-from typing import Annotated, Literal, Optional, Union, Dict
+from typing import Literal, Optional, Union
 from fastapi import Query
 from fastapi.routing import APIRouter, HTTPException
 from pydantic import BaseModel, Field, parse_obj_as
 from ..dependencies import ApiDependencies
 from invokeai.backend import BaseModelType, ModelType
-from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
+from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
 MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 models_router = APIRouter(prefix="/v1/models", tags=["models"])
@@ -51,12 +51,15 @@ class CreateModelResponse(BaseModel):
     info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
     status: str = Field(description="The status of the API response")
 
+class ImportModelRequest(BaseModel):
+    name: str = Field(description="A model path, repo_id or URL to import")
+    prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files')
 
 class ConversionRequest(BaseModel):
     name: str = Field(description="The name of the new model")
     info: CkptModelInfo = Field(description="The converted model info")
     save_location: str = Field(description="The path to save the converted model weights")
 
 class ConvertedModelResponse(BaseModel):
     name: str = Field(description="The name of the new model")
     info: DiffusersModelInfo = Field(description="The converted model info")
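Annotation (not part of the commit): the new ImportModelRequest schema validates payloads like the one below; the repo_id is an arbitrary example value, not something this commit references.

    # Sketch: validating an import request payload with the new schema (pydantic v1).
    request = ImportModelRequest.parse_obj({
        "name": "stabilityai/stable-diffusion-2-1",   # illustrative repo_id
        "prediction_type": "v_prediction",
    })
    assert request.prediction_type == "v_prediction"
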
@@ -105,6 +108,28 @@ async def update_model(
     return model_response
 
+@models_router.post(
+    "/",
+    operation_id="import_model",
+    responses={200: {"status": "success"}},
+)
+async def import_model(
+    model_request: ImportModelRequest
+) -> None:
+    """ Add Model """
+    items_to_import = set([model_request.name])
+    prediction_types = { x.value: x for x in SchedulerPredictionType }
+    logger = ApiDependencies.invoker.services.logger
+
+    installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
+        items_to_import = items_to_import,
+        prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type)
+    )
+    if len(installed_models) > 0:
+        logger.info(f'Successfully imported {model_request.name}')
+    else:
+        logger.error(f'Model {model_request.name} not imported')
+        raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported')
+
 @models_router.delete(
     "/{model_name}",

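Annotation: once this router is mounted, the import endpoint can be exercised with a plain HTTP POST. A minimal sketch, assuming a local InvokeAI server on port 9090 and an /api/v1/models mount point (both assumptions, not stated in the diff):

    # Hypothetical client call; host, port and mount prefix are assumptions.
    import requests

    resp = requests.post(
        "http://localhost:9090/api/v1/models/",
        json={
            "name": "runwayml/stable-diffusion-v1-5",  # a path, repo_id or URL
            "prediction_type": None,  # only meaningful for SD-2 checkpoint files
        },
    )
    resp.raise_for_status()  # a failed import surfaces as HTTP 500
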
@@ -93,9 +93,10 @@ class ModelInstall(object):
     def __init__(self,
                  config:InvokeAIAppConfig,
                  prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
+                 model_manager: ModelManager = None,
                  access_token:str = None):
         self.config = config
-        self.mgr = ModelManager(config.model_conf_path)
+        self.mgr = model_manager or ModelManager(config.model_conf_path)
         self.datasets = OmegaConf.load(Dataset_path)
         self.prediction_helper = prediction_type_helper
         self.access_token = access_token or HfFolder.get_token()
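Annotation: the new model_manager argument lets a caller hand an already-initialized ModelManager to ModelInstall instead of having it parse the config file a second time. A rough sketch (variable names are illustrative):

    # Sketch: reusing an existing ModelManager inside ModelInstall.
    config = InvokeAIAppConfig.get_config()
    mgr = ModelManager(config.model_conf_path)
    installer = ModelInstall(
        config=config,
        model_manager=mgr,  # reuse instead of building a second manager
    )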

@@ -151,13 +151,11 @@ import os
 import hashlib
 import textwrap
 from dataclasses import dataclass
-from packaging import version
 from pathlib import Path
-from typing import Dict, Optional, List, Tuple, Union, types
+from typing import Optional, List, Tuple, Union, Set, Callable, types
 from shutil import rmtree
 
 import torch
-from huggingface_hub import scan_cache_dir
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
@@ -165,9 +163,13 @@ from pydantic import BaseModel
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.util import CUDA_DEVICE, download_with_resume
+from invokeai.backend.util import CUDA_DEVICE
 from .model_cache import ModelCache, ModelLocker
-from .models import BaseModelType, ModelType, SubModelType, ModelError, MODEL_CLASSES
+from .models import (
+    BaseModelType, ModelType, SubModelType,
+    ModelError, SchedulerPredictionType, MODEL_CLASSES,
+    ModelConfigBase,
+)
 
 # We are only starting to number the config file with release 3.
 # The config file version doesn't have to start at release version, but it will help
@@ -686,3 +688,34 @@ class ModelManager(object):
         if new_models_found and self.config_path:
             self.commit()
+
+    def heuristic_import(self,
+                         items_to_import: Set[str],
+                         prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
+                         )->Set[str]:
+        '''
+        Import a list of paths, repo_ids or URLs. Returns the
+        set of successfully imported items. The prediction_type_helper
+        is a callback that receives the Path of a checkpoint or diffusers
+        model and returns a SchedulerPredictionType (or None).
+        '''
+        # avoid circular import here
+        from invokeai.backend.install.model_install_backend import ModelInstall
+        successfully_installed = set()
+
+        installer = ModelInstall(config = self.globals,
+                                 prediction_type_helper = prediction_type_helper,
+                                 model_manager = self)
+        for thing in items_to_import:
+            try:
+                installer.heuristic_install(thing)
+                successfully_installed.add(thing)
+            except Exception as e:
+                self.logger.warning(f'{thing} could not be imported: {str(e)}')
+
+        self.commit()
+        return successfully_installed
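Annotation: a usage sketch for the new method, reusing the mgr instance from the sketch above. The local path is made up, and the helper simply hard-codes a prediction type; a real helper would inspect the checkpoint or ask the user:

    # Hypothetical caller: import one local file and one HuggingFace repo_id.
    def assume_v_prediction(path):  # illustrative helper for SD-2 checkpoints
        return SchedulerPredictionType.VPrediction

    installed = mgr.heuristic_import(
        {"/tmp/mymodel.safetensors", "runwayml/stable-diffusion-v1-5"},
        prediction_type_helper=assume_v_prediction,
    )
    print(f"imported: {installed}")  # the subset of items that succeeded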

@@ -1,17 +1,14 @@
 import json
-import traceback
 import torch
 import safetensors.torch
 from dataclasses import dataclass
-from enum import Enum
-from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel
+from diffusers import ModelMixin, ConfigMixin
 from pathlib import Path
 from typing import Callable, Literal, Union, Dict
 from picklescan.scanner import scan_file_path
-import invokeai.backend.util.logging as logger
 from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings
 
 @dataclass
@@ -102,7 +99,7 @@ class ModelProbe(object):
                     and prediction_type==SchedulerPredictionType.VPrediction \
                     ) else 512,
             )
-        except Exception as e:
+        except Exception:
             return None
         return model_info
@@ -115,6 +112,9 @@ class ModelProbe(object):
             return ModelType.TextualInversion
         checkpoint = checkpoint or cls._scan_and_load_checkpoint(model_path)
         state_dict = checkpoint.get("state_dict") or checkpoint
+        if len(checkpoint) < 10 and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
+            return ModelType.TextualInversion
+
         if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
             return ModelType.Pipeline
         if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
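Annotation: the added guard classifies any small checkpoint whose values are all tensors as a textual-inversion embedding. The same check, demonstrated on a dummy in-memory checkpoint (the tensor name and shape are mine):

    # Demonstrating the small-tensor-dict heuristic on a fake TI embedding.
    import torch

    checkpoint = {"emb_params": torch.zeros(8, 768)}  # TI files hold a few tensors
    is_textual_inversion = (
        len(checkpoint) < 10
        and all(isinstance(v, torch.Tensor) for v in checkpoint.values())
    )
    assert is_textual_inversion
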
@@ -326,13 +326,9 @@ class PipelineFolderProbe(FolderProbeBase):
     def get_base_type(self)->BaseModelType:
         if self.model:
             unet_conf = self.model.unet.config
-            scheduler_conf = self.model.scheduler.config
         else:
             with open(self.folder_path / 'unet' / 'config.json','r') as file:
                 unet_conf = json.load(file)
-            with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
-                scheduler_conf = json.load(file)
         if unet_conf['cross_attention_dim'] == 768:
             return BaseModelType.StableDiffusion1
         elif unet_conf['cross_attention_dim'] == 1024:
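Annotation: the scheduler_conf reads removed here appear to be unused; the base type is decided entirely by the UNet's cross_attention_dim (768 for SD-1, 1024 for SD-2). A standalone sketch of that mapping (function name and string return values are illustrative):

    # Standalone sketch mirroring PipelineFolderProbe.get_base_type.
    import json
    from pathlib import Path

    def guess_base_type(folder: Path) -> str:
        with open(folder / 'unet' / 'config.json', 'r') as file:
            unet_conf = json.load(file)
        return {768: 'sd-1', 1024: 'sd-2'}.get(unet_conf['cross_attention_dim'], 'unknown')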

@@ -56,7 +56,6 @@ class ModelConfigBase(BaseModel):
     class Config:
         use_enum_values = True
 
-
 class EmptyConfigLoader(ConfigMixin):
     @classmethod
     def load_config(cls, *args, **kwargs):

@@ -45,6 +45,7 @@ sd-1/pipeline/portraitplus:
   repo_id: wavymulder/portraitplus
   recommended: False
 sd-1/pipeline/seek.art_MEGA:
+  repo_id: coreco/seek.art_MEGA
   description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
   recommended: False
 sd-1/pipeline/trinart_stable_diffusion_v2: