mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
model_probe working; model_install incomplete
This commit is contained in:
parent 085ab54124
commit 893f776f1d
@@ -3,4 +3,4 @@ Initialization file for invokeai.backend.model_management
 """
 from .model_manager import ModelManager, ModelInfo
 from .model_cache import ModelCache
-from .models import BaseModelType, ModelType, SubModelType
+from .models import BaseModelType, ModelType, SubModelType, VariantType
@@ -29,6 +29,7 @@ import torch
 
 from diffusers import logging as diffusers_logging
 from transformers import logging as transformers_logging
+import logging
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import get_invokeai_config
 from .lora import LoRAModel, TextualInversionModel
@@ -4,6 +4,8 @@ Routines for downloading and installing models.
 import json
+import safetensors
+import safetensors.torch
 import shutil
 import tempfile
 import torch
 import traceback
 from dataclasses import dataclass
@@ -14,8 +16,10 @@ from pathlib import Path
 
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig
-from .models import BaseModelType, ModelType
+from . import ModelManager
+from .models import BaseModelType, ModelType, VariantType
+from .model_probe import ModelProbe, ModelVariantInfo
+from .model_cache import SilenceWarnings
 
 class ModelInstall(object):
     '''
@@ -54,32 +58,49 @@ class ModelInstall(object):
         if not model_info:
             raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")
 
-        # non-pipeline; no conversion needed, just copy into right place
-        if model_info.model_type != ModelType.Pipeline:
-            destination_path = self._dest_path(model_info) / checkpoint.name
-            self._check_for_collision(destination_path)
-            shutil.copyfile(checkpoint, destination_path)
         key = ModelManager.create_key(
             model_name = checkpoint.stem,
-            base_model = model_info.base_type
-            model_type = model_info.model_type
+            base_model = model_info.base_type,
+            model_type = model_info.model_type,
         )
-        return {
+        destination_path = self._dest_path(model_info) / checkpoint
+        destination_path.parent.mkdir(parents=True, exist_ok=True)
+        self._check_for_collision(destination_path)
+        stanza = {
             key: dict(
-                name = model_name,
-                description = f'{model_info.model_type} model {model_name}',
+                name = checkpoint.stem,
+                description = f'{model_info.model_type} model {checkpoint.stem}',
+                base = model_info.base_model.value,
+                type = model_info.model_type.value,
+                variant = model_info.variant_type.value,
                 path = str(destination_path),
-                format = 'checkpoint',
-                base = str(base_model),
-                type = str(model_type),
-                variant = str(model_info.variant_type),
             )
         }
 
+        # non-pipeline; no conversion needed, just copy into right place
+        if model_info.model_type != ModelType.Pipeline:
+            shutil.copyfile(checkpoint, destination_path)
+            stanza[key].update({'format': 'checkpoint'})
+
+        # pipeline - conversion needed here
+        else:
+            destination_path = self._dest_path(model_info) / checkpoint.stem
+            config_file = self._pipeline_type_to_config_file(model_info.model_type)
+
+            from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
+            with SilenceWarnings():
+                convert_ckpt_to_diffusers(
+                    checkpoint,
+                    destination_path,
+                    extract_ema=True,
+                    original_config_file=config_file,
+                    scan_needed=False,
+                )
+            stanza[key].update({'format': 'folder',
+                                'path': destination_path, # no suffix on this
+                               })
+
+        return stanza
 
 
     def _check_for_collision(self, path: Path):
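Taken together, the rewritten install path above first builds a ModelManager key and a single stanza describing the model, and only then branches: non-pipeline checkpoints are copied verbatim, pipelines are converted to a diffusers folder. Below is a minimal sketch of the stanza this produces; the key format ('sd-1/pipeline/my-model') and all concrete values are illustrative assumptions, since ModelManager.create_key() is not shown in this commit:

    # Hypothetical output; key format and paths are assumptions.
    {
        'sd-1/pipeline/my-model': dict(
            name = 'my-model',
            description = 'ModelType.Pipeline model my-model',
            base = 'sd-1',       # model_info.base_model.value
            type = 'pipeline',   # model_info.model_type.value
            variant = 'normal',  # model_info.variant_type.value
            path = '/models/sd-1/pipeline/my-model',
            format = 'folder',   # 'checkpoint' when no conversion is needed
        )
    }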
@@ -4,13 +4,14 @@ import torch
 import safetensors.torch
 
 from dataclasses import dataclass
-from diffusers import ModelMixin
+from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel
 from pathlib import Path
 from typing import Callable, Literal, Union, Dict
 from picklescan.scanner import scan_file_path
 
 import invokeai.backend.util.logging as logger
 from .models import BaseModelType, ModelType, VariantType
+from .model_cache import SilenceWarnings
 
 @dataclass
 class ModelVariantInfo(object):
@@ -30,10 +31,9 @@ class ModelProbe(object):
     }
 
     CLASS2TYPE = {
-        "StableDiffusionPipeline" : ModelType.Pipeline,
-        "AutoencoderKL": ModelType.Vae,
-        "ControlNetModel" : ModelType.ControlNet,
-
+        'StableDiffusionPipeline' : ModelType.Pipeline,
+        'AutoencoderKL' : ModelType.Vae,
+        'ControlNetModel' : ModelType.ControlNet,
     }
 
     @classmethod
@@ -56,7 +56,11 @@ class ModelProbe(object):
         the path to the model and returns the BaseModelType. It is called to distinguish
         between V2-Base and V2-768 SD models.
         '''
-        format = 'folder' if model_path.is_dir() else 'file'
+        if model_path:
+            format = 'folder' if model_path.is_dir() else 'checkpoint'
+        else:
+            format = 'folder' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
+
         model_info = None
         try:
             model_type = cls.get_model_type_from_folder(model_path, model) \
@@ -102,6 +106,9 @@ class ModelProbe(object):
         '''
         Get the model type of a hugging-face style folder.
         '''
+        if model:
+            class_name = model.__class__.__name__
+        else:
             if (folder_path / 'learned_embeds.bin').exists():
                 return ModelType.TextualInversion
 
@@ -113,16 +120,18 @@ class ModelProbe(object):
             config_path = i if i.exists() else c if c.exists() else None
 
             if config_path:
-                conf = json.loads(config_path)
+                conf = json.load(open(config_path,'r'))
                 class_name = conf['_class_name']
 
         if type := cls.CLASS2TYPE.get(class_name):
             return type
 
         # give up
-        raise ValueError("Unable to determine model type of {model_path}")
+        raise ValueError("Unable to determine model type")
 
     @classmethod
+    def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
+        with SilenceWarnings():
             if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
                 cls._scan_model(model_path, model_path)
                 return torch.load(model_path)
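The json.loads(config_path) fix above corrects a real bug: json.loads expects a JSON string, not a file path, so passing a Path would have raised a TypeError. For background, the _class_name lookup works because every diffusers-style folder carries a JSON config naming its top-level class. A minimal standalone sketch of the same idea; the folder path is hypothetical, and the two file names follow the diffusers convention assumed by the `i`/`c` variables above:

    import json
    from pathlib import Path

    folder = Path('/path/to/some-model')   # hypothetical model folder
    i = folder / 'model_index.json'        # diffusers pipeline config
    c = folder / 'config.json'             # single-model config
    config_path = i if i.exists() else c if c.exists() else None

    if config_path:
        conf = json.load(open(config_path, 'r'))
        print(conf['_class_name'])         # e.g. 'StableDiffusionPipeline'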
@@ -255,43 +264,58 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
 #######################################################
 class FolderProbeBase(ProbeBase):
     def __init__(self,
-                 model: ModelMixin,
                  folder_path: Path,
+                 model: ModelMixin = None,
                  helper: Callable=None # not used
                 ):
         self.model = model
         self.folder_path = folder_path
 
-    def get_variant_type(self)->VariantType:
-
-        # only works for pipelines
-        config_file = self.folder_path / 'unet' / 'config.json'
-        if not config_file.exists():
-            return VariantType.Normal
-
-        conf = json.loads(config_file)
-        channels = conf['in_channels']
-        if channels == 9:
-            return VariantType.Inpainting
-        elif channels == 5:
-            return VariantType.Depth
-        elif channels == 4:
-            return VariantType.Normal
-        else:
-            return VariantType.Normal
-
 class PipelineFolderProbe(FolderProbeBase):
     def get_base_type(self)->BaseModelType:
-        config_file = self.folder_path / 'scheduler' / 'scheduler_config.json'
-        if not config_file.exists():
-            return None
-        conf = json.load(config_file)
-        if conf['prediction_type'] == "v_prediction":
+        if self.model:
+            unet_conf = self.model.unet.config
+            scheduler_conf = self.model.scheduler.config
+        else:
+            unet_conf = json.load(open(self.folder_path / 'unet' / 'config.json','r'))
+            scheduler_conf = json.load(open(self.folder_path / 'scheduler' / 'scheduler_config.json','r'))
+
+        if unet_conf['cross_attention_dim'] == 768:
+            return BaseModelType.StableDiffusion1_5
+        elif unet_conf['cross_attention_dim'] == 1024:
+            if scheduler_conf['prediction_type'] == "v_prediction":
                 return BaseModelType.StableDiffusion2
-        elif conf['prediction_type'] == 'epsilon':
+            elif scheduler_conf['prediction_type'] == 'epsilon':
                 return BaseModelType.StableDiffusion2Base
             else:
                 return BaseModelType.StableDiffusion2
+        else:
+            raise ValueError(f'Unknown base model for {self.folder_path}')
+
+    def get_variant_type(self)->VariantType:
+        # This only works for pipelines! Any kind of
+        # exception results in our returning the
+        # "normal" variant type
+        try:
+            if self.model:
+                conf = self.model.unet.config
+            else:
+                config_file = self.folder_path / 'unet' / 'config.json'
+                conf = json.load(open(config_file,'r'))
+
+            in_channels = conf['in_channels']
+            if in_channels == 9:
+                return VariantType.Inpainting
+            elif in_channels == 5:
+                return VariantType.Depth
+            elif in_channels == 4:
+                return VariantType.Normal
+        except:
+            pass
+        return VariantType.Normal
 
 class VaeFolderProbe(FolderProbeBase):
     def get_base_type(self)->BaseModelType:
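The probing rules in the rewritten PipelineFolderProbe reduce to three config lookups: cross_attention_dim in unet/config.json separates SD 1.x (768) from SD 2.x (1024); the scheduler's prediction_type separates V2-Base ('epsilon') from V2-768 ('v_prediction'); and the UNet's in_channels separates inpainting (9) and depth (5) variants from normal (4). A simplified, self-contained sketch of the same rules; the function name and return strings are illustrative, and the error handling is collapsed relative to the class above:

    import json
    from pathlib import Path

    def probe_pipeline_folder(folder: Path):
        # Hypothetical helper restating PipelineFolderProbe's rules.
        unet_conf = json.load(open(folder / 'unet' / 'config.json', 'r'))
        sched_conf = json.load(open(folder / 'scheduler' / 'scheduler_config.json', 'r'))

        if unet_conf['cross_attention_dim'] == 768:
            base = 'sd-1.5'
        elif unet_conf['cross_attention_dim'] == 1024:
            # v_prediction marks V2-768; epsilon marks V2-Base
            base = 'sd-2' if sched_conf['prediction_type'] == 'v_prediction' else 'sd-2-base'
        else:
            raise ValueError(f'unknown base model for {folder}')

        # 9 input channels = inpainting, 5 = depth, 4 (default) = normal
        variant = {9: 'inpainting', 5: 'depth'}.get(unet_conf['in_channels'], 'normal')
        return base, variant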
@@ -1,4 +1,4 @@
-from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase
+from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, VariantType
 from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
 from .vae import VaeModel
 from .lora import LoRAModel
|
@ -21,9 +21,8 @@ class BaseModelType(str, Enum):
|
||||
class ModelType(str, Enum):
|
||||
Pipeline = "pipeline"
|
||||
Vae = "vae"
|
||||
|
||||
Lora = "lora"
|
||||
#ControlNet = "controlnet"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
|
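Because ModelType subclasses both str and Enum, the newly enabled ControlNet member compares equal to its plain string value, which is what lets model_probe and the install stanzas pass these around as ordinary strings. A small illustrative check, assuming the enum is importable from invokeai.backend.model_management.models:

    from invokeai.backend.model_management.models import ModelType

    assert ModelType.ControlNet == 'controlnet'   # str mixin: compares as a string
    assert ModelType('vae') is ModelType.Vae      # round-trips from its value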