add router API support for model manager heuristic_import()`

This commit is contained in:
Lincoln Stein 2023-06-23 16:35:39 -04:00
parent 54b74427f4
commit 466ec3ab5e
6 changed files with 74 additions and 19 deletions

View File

@ -1,13 +1,13 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
from typing import Annotated, Literal, Optional, Union, Dict
from typing import Literal, Optional, Union
from fastapi import Query
from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -51,11 +51,14 @@ class CreateModelResponse(BaseModel):
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
status: str = Field(description="The status of the API response")
class ImportModelRequest(BaseModel):
name: str = Field(description="A model path, repo_id or URL to import")
prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files')
class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info")
save_location: str = Field(description="The path to save the converted model weights")
class ConvertedModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
@ -105,6 +108,28 @@ async def update_model(
return model_response
@models_router.post(
"/",
operation_id="import_model",
responses={200: {"status": "success"}},
)
async def import_model(
model_request: ImportModelRequest
) -> None:
""" Add Model """
items_to_import = set([model_request.name])
prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type)
)
if len(installed_models) > 0:
logger.info(f'Successfully imported {model_request.name}')
else:
logger.error(f'Model {model_request.name} not imported')
raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported')
@models_router.delete(
"/{model_name}",

View File

@ -93,9 +93,10 @@ class ModelInstall(object):
def __init__(self,
config:InvokeAIAppConfig,
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
model_manager: ModelManager = None,
access_token:str = None):
self.config = config
self.mgr = ModelManager(config.model_conf_path)
self.mgr = model_manager or ModelManager(config.model_conf_path)
self.datasets = OmegaConf.load(Dataset_path)
self.prediction_helper = prediction_type_helper
self.access_token = access_token or HfFolder.get_token()

View File

@ -151,13 +151,11 @@ import os
import hashlib
import textwrap
from dataclasses import dataclass
from packaging import version
from pathlib import Path
from typing import Dict, Optional, List, Tuple, Union, types
from typing import Optional, List, Tuple, Union, Set, Callable, types
from shutil import rmtree
import torch
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
@ -165,9 +163,13 @@ from pydantic import BaseModel
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
from invokeai.backend.util import CUDA_DEVICE
from .model_cache import ModelCache, ModelLocker
from .models import BaseModelType, ModelType, SubModelType, ModelError, MODEL_CLASSES
from .models import (
BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES,
ModelConfigBase,
)
# We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help
@ -686,3 +688,34 @@ class ModelManager(object):
if new_models_found and self.config_path:
self.commit()
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Set[str]:
'''
Import a list of paths, repo_ids or URLs. Returns the
set of successfully imported items. The prediction_type_helper
is a callback that receives the Path of a checkpoint or diffusers
model and returns a SchedulerPredictionType (or None).
'''
# avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall
successfully_installed = set()
installer = ModelInstall(config = self.globals,
prediction_type_helper = prediction_type_helper,
model_manager = self)
for thing in items_to_import:
try:
installer.heuristic_install(thing)
successfully_installed.add(thing)
except Exception as e:
self.logger.warning(f'{thing} could not be imported: {str(e)}')
self.commit()
return successfully_installed

View File

@ -1,17 +1,14 @@
import json
import traceback
import torch
import safetensors.torch
from dataclasses import dataclass
from enum import Enum
from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel
from diffusers import ModelMixin, ConfigMixin
from pathlib import Path
from typing import Callable, Literal, Union, Dict
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings
@dataclass
@ -102,7 +99,7 @@ class ModelProbe(object):
and prediction_type==SchedulerPredictionType.VPrediction \
) else 512,
)
except Exception as e:
except Exception:
return None
return model_info
@ -115,6 +112,9 @@ class ModelProbe(object):
return ModelType.TextualInversion
checkpoint = checkpoint or cls._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
if len(checkpoint) < 10 and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
return ModelType.TextualInversion
if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
return ModelType.Pipeline
if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
@ -326,13 +326,9 @@ class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
if self.model:
unet_conf = self.model.unet.config
scheduler_conf = self.model.scheduler.config
else:
with open(self.folder_path / 'unet' / 'config.json','r') as file:
unet_conf = json.load(file)
with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
scheduler_conf = json.load(file)
if unet_conf['cross_attention_dim'] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf['cross_attention_dim'] == 1024:

View File

@ -56,7 +56,6 @@ class ModelConfigBase(BaseModel):
class Config:
use_enum_values = True
class EmptyConfigLoader(ConfigMixin):
@classmethod
def load_config(cls, *args, **kwargs):

View File

@ -45,6 +45,7 @@ sd-1/pipeline/portraitplus:
repo_id: wavymulder/portraitplus
recommended: False
sd-1/pipeline/seek.art_MEGA:
repo_id: coreco/seek.art_MEGA
description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
recommended: False
sd-1/pipeline/trinart_stable_diffusion_v2: