mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model_probe working; model_install incomplete
This commit is contained in:
parent
085ab54124
commit
893f776f1d
@ -3,4 +3,4 @@ Initialization file for invokeai.backend.model_management
|
|||||||
"""
|
"""
|
||||||
from .model_manager import ModelManager, ModelInfo
|
from .model_manager import ModelManager, ModelInfo
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .models import BaseModelType, ModelType, SubModelType
|
from .models import BaseModelType, ModelType, SubModelType, VariantType
|
||||||
|
@ -29,6 +29,7 @@ import torch
|
|||||||
|
|
||||||
from diffusers import logging as diffusers_logging
|
from diffusers import logging as diffusers_logging
|
||||||
from transformers import logging as transformers_logging
|
from transformers import logging as transformers_logging
|
||||||
|
import logging
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import get_invokeai_config
|
||||||
from .lora import LoRAModel, TextualInversionModel
|
from .lora import LoRAModel, TextualInversionModel
|
||||||
|
@ -4,6 +4,8 @@ Routines for downloading and installing models.
|
|||||||
import json
|
import json
|
||||||
import safetensors
|
import safetensors
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
import torch
|
import torch
|
||||||
import traceback
|
import traceback
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -14,8 +16,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from .models import BaseModelType, ModelType
|
from . import ModelManager
|
||||||
|
from .models import BaseModelType, ModelType, VariantType
|
||||||
from .model_probe import ModelProbe, ModelVariantInfo
|
from .model_probe import ModelProbe, ModelVariantInfo
|
||||||
|
from .model_cache import SilenceWarnings
|
||||||
|
|
||||||
class ModelInstall(object):
|
class ModelInstall(object):
|
||||||
'''
|
'''
|
||||||
@ -54,32 +58,49 @@ class ModelInstall(object):
|
|||||||
if not model_info:
|
if not model_info:
|
||||||
raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")
|
raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")
|
||||||
|
|
||||||
|
key = ModelManager.create_key(
|
||||||
|
model_name = checkpoint.stem,
|
||||||
|
base_model = model_info.base_type,
|
||||||
|
model_type = model_info.model_type,
|
||||||
|
)
|
||||||
|
destination_path = self._dest_path(model_info) / checkpoint
|
||||||
|
destination_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._check_for_collision(destination_path)
|
||||||
|
stanza = {
|
||||||
|
key: dict(
|
||||||
|
name = checkpoint.stem,
|
||||||
|
description = f'{model_info.model_type} model {checkpoint.stem}',
|
||||||
|
base = model_info.base_model.value,
|
||||||
|
type = model_info.model_type.value,
|
||||||
|
variant = model_info.variant_type.value,
|
||||||
|
path = str(destination_path),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
# non-pipeline; no conversion needed, just copy into right place
|
# non-pipeline; no conversion needed, just copy into right place
|
||||||
if model_info.model_type != ModelType.Pipeline:
|
if model_info.model_type != ModelType.Pipeline:
|
||||||
destination_path = self._dest_path(model_info) / checkpoint.name
|
|
||||||
self._check_for_collision(destination_path)
|
|
||||||
shutil.copyfile(checkpoint, destination_path)
|
shutil.copyfile(checkpoint, destination_path)
|
||||||
key = ModelManager.create_key(
|
stanza[key].update({'format': 'checkpoint'})
|
||||||
model_name = checkpoint.stem,
|
|
||||||
base_model = model_info.base_type
|
# pipeline - conversion needed here
|
||||||
model_type = model_info.model_type
|
else:
|
||||||
)
|
destination_path = self._dest_path(model_info) / checkpoint.stem
|
||||||
return {
|
config_file = self._pipeline_type_to_config_file(model_info.model_type)
|
||||||
key: dict(
|
|
||||||
name = model_name,
|
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||||
description = f'{model_info.model_type} model {model_name}',
|
with SilenceWarnings:
|
||||||
path = str(destination_path),
|
convert_ckpt_to_diffusers(
|
||||||
format = 'checkpoint',
|
checkpoint,
|
||||||
base = str(base_model),
|
destination_path,
|
||||||
type = str(model_type),
|
extract_ema=True,
|
||||||
variant = str(model_info.variant_type),
|
original_config_file=config_file,
|
||||||
|
scan_needed=False,
|
||||||
)
|
)
|
||||||
}
|
stanza[key].update({'format': 'folder',
|
||||||
|
'path': destination_path, # no suffix on this
|
||||||
|
})
|
||||||
destination_path = self._dest_path(model_info) / checkpoint.stem
|
|
||||||
|
|
||||||
|
|
||||||
|
return stanza
|
||||||
|
|
||||||
|
|
||||||
def _check_for_collision(self, path: Path):
|
def _check_for_collision(self, path: Path):
|
||||||
|
@ -4,13 +4,14 @@ import torch
|
|||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from diffusers import ModelMixin
|
from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Literal, Union, Dict
|
from typing import Callable, Literal, Union, Dict
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from .models import BaseModelType, ModelType, VariantType
|
from .models import BaseModelType, ModelType, VariantType
|
||||||
|
from .model_cache import SilenceWarnings
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelVariantInfo(object):
|
class ModelVariantInfo(object):
|
||||||
@ -30,10 +31,9 @@ class ModelProbe(object):
|
|||||||
}
|
}
|
||||||
|
|
||||||
CLASS2TYPE = {
|
CLASS2TYPE = {
|
||||||
"StableDiffusionPipeline" : ModelType.Pipeline,
|
'StableDiffusionPipeline' : ModelType.Pipeline,
|
||||||
"AutoencoderKL": ModelType.Vae,
|
'AutoencoderKL' : ModelType.Vae,
|
||||||
"ControlNetModel" : ModelType.ControlNet,
|
'ControlNetModel' : ModelType.ControlNet,
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -56,7 +56,11 @@ class ModelProbe(object):
|
|||||||
the path to the model and returns the BaseModelType. It is called to distinguish
|
the path to the model and returns the BaseModelType. It is called to distinguish
|
||||||
between V2-Base and V2-768 SD models.
|
between V2-Base and V2-768 SD models.
|
||||||
'''
|
'''
|
||||||
format = 'folder' if model_path.is_dir() else 'file'
|
if model_path:
|
||||||
|
format = 'folder' if model_path.is_dir() else 'checkpoint'
|
||||||
|
else:
|
||||||
|
format = 'folder' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
|
||||||
|
|
||||||
model_info = None
|
model_info = None
|
||||||
try:
|
try:
|
||||||
model_type = cls.get_model_type_from_folder(model_path, model) \
|
model_type = cls.get_model_type_from_folder(model_path, model) \
|
||||||
@ -102,32 +106,37 @@ class ModelProbe(object):
|
|||||||
'''
|
'''
|
||||||
Get the model type of a hugging-face style folder.
|
Get the model type of a hugging-face style folder.
|
||||||
'''
|
'''
|
||||||
if (folder_path / 'learned_embeds.bin').exists():
|
if model:
|
||||||
return ModelType.TextualInversion
|
class_name = model.__class__.__name__
|
||||||
|
else:
|
||||||
|
if (folder_path / 'learned_embeds.bin').exists():
|
||||||
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
if (folder_path / 'pytorch_lora_weights.bin').exists():
|
if (folder_path / 'pytorch_lora_weights.bin').exists():
|
||||||
return ModelType.Lora
|
return ModelType.Lora
|
||||||
|
|
||||||
i = folder_path / 'model_index.json'
|
i = folder_path / 'model_index.json'
|
||||||
c = folder_path / 'config.json'
|
c = folder_path / 'config.json'
|
||||||
config_path = i if i.exists() else c if c.exists() else None
|
config_path = i if i.exists() else c if c.exists() else None
|
||||||
|
|
||||||
if config_path:
|
if config_path:
|
||||||
conf = json.loads(config_path)
|
conf = json.load(open(config_path,'r'))
|
||||||
class_name = conf['_class_name']
|
class_name = conf['_class_name']
|
||||||
if type := cls.CLASS2TYPE.get(class_name):
|
|
||||||
return type
|
if type := cls.CLASS2TYPE.get(class_name):
|
||||||
|
return type
|
||||||
|
|
||||||
# give up
|
# give up
|
||||||
raise ValueError("Unable to determine model type of {model_path}")
|
raise ValueError("Unable to determine model type")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
||||||
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
with SilenceWarnings():
|
||||||
cls._scan_model(model_path, model_path)
|
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
||||||
return torch.load(model_path)
|
cls._scan_model(model_path, model_path)
|
||||||
else:
|
return torch.load(model_path)
|
||||||
return safetensors.torch.load_file(model_path)
|
else:
|
||||||
|
return safetensors.torch.load_file(model_path)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _scan_model(cls, model_name, checkpoint):
|
def _scan_model(cls, model_name, checkpoint):
|
||||||
@ -255,43 +264,58 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
|||||||
#######################################################
|
#######################################################
|
||||||
class FolderProbeBase(ProbeBase):
|
class FolderProbeBase(ProbeBase):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model: ModelMixin,
|
|
||||||
folder_path: Path,
|
folder_path: Path,
|
||||||
|
model: ModelMixin = None,
|
||||||
helper: Callable=None # not used
|
helper: Callable=None # not used
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.folder_path = folder_path
|
self.folder_path = folder_path
|
||||||
|
|
||||||
def get_variant_type(self)->VariantType:
|
def get_variant_type(self)->VariantType:
|
||||||
|
return VariantType.Normal
|
||||||
# only works for pipelines
|
|
||||||
config_file = self.folder_path / 'unet' / 'config.json'
|
|
||||||
if not config_file.exists():
|
|
||||||
return VariantType.Normal
|
|
||||||
|
|
||||||
conf = json.loads(config_file)
|
|
||||||
channels = conf['in_channels']
|
|
||||||
if channels == 9:
|
|
||||||
return VariantType.Inpainting
|
|
||||||
elif channels == 5:
|
|
||||||
return VariantType.Depth
|
|
||||||
elif channels == 4:
|
|
||||||
return VariantType.Normal
|
|
||||||
else:
|
|
||||||
return VariantType.Normal
|
|
||||||
|
|
||||||
class PipelineFolderProbe(FolderProbeBase):
|
class PipelineFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self)->BaseModelType:
|
||||||
config_file = self.folder_path / 'scheduler' / 'scheduler_config.json'
|
if self.model:
|
||||||
if not config_file.exists():
|
unet_conf = self.model.unet.config
|
||||||
return None
|
scheduler_conf = self.model.scheduler.config
|
||||||
conf = json.load(config_file)
|
|
||||||
if conf['prediction_type'] == "v_prediction":
|
|
||||||
return BaseModelType.StableDiffusion2
|
|
||||||
elif conf['prediction_type'] == 'epsilon':
|
|
||||||
return BaseModelType.StableDiffusion2Base
|
|
||||||
else:
|
else:
|
||||||
return BaseModelType.StableDiffusion2
|
unet_conf = json.load(open(self.folder_path / 'unet' / 'config.json','r'))
|
||||||
|
scheduler_conf = json.load(open(self.folder_path / 'scheduler' / 'scheduler_config.json','r'))
|
||||||
|
|
||||||
|
if unet_conf['cross_attention_dim'] == 768:
|
||||||
|
return BaseModelType.StableDiffusion1_5
|
||||||
|
elif unet_conf['cross_attention_dim'] == 1024:
|
||||||
|
if scheduler_conf['prediction_type'] == "v_prediction":
|
||||||
|
return BaseModelType.StableDiffusion2
|
||||||
|
elif scheduler_conf['prediction_type'] == 'epsilon':
|
||||||
|
return BaseModelType.StableDiffusion2Base
|
||||||
|
else:
|
||||||
|
return BaseModelType.StableDiffusion2
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown base model for {self.folder_path}')
|
||||||
|
|
||||||
|
def get_variant_type(self)->VariantType:
|
||||||
|
# This only works for pipelines! Any kind of
|
||||||
|
# exception results in our returning the
|
||||||
|
# "normal" variant type
|
||||||
|
try:
|
||||||
|
if self.model:
|
||||||
|
conf = self.model.unet.config
|
||||||
|
else:
|
||||||
|
config_file = self.folder_path / 'unet' / 'config.json'
|
||||||
|
conf = json.load(open(config_file,'r'))
|
||||||
|
|
||||||
|
in_channels = conf['in_channels']
|
||||||
|
if in_channels == 9:
|
||||||
|
return VariantType.Inpainting
|
||||||
|
elif in_channels == 5:
|
||||||
|
return VariantType.Depth
|
||||||
|
elif in_channels == 4:
|
||||||
|
return VariantType.Normal
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return VariantType.Normal
|
||||||
|
|
||||||
class VaeFolderProbe(FolderProbeBase):
|
class VaeFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self)->BaseModelType:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase
|
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, VariantType
|
||||||
from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
|
from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
|
||||||
from .vae import VaeModel
|
from .vae import VaeModel
|
||||||
from .lora import LoRAModel
|
from .lora import LoRAModel
|
||||||
|
@ -21,9 +21,8 @@ class BaseModelType(str, Enum):
|
|||||||
class ModelType(str, Enum):
|
class ModelType(str, Enum):
|
||||||
Pipeline = "pipeline"
|
Pipeline = "pipeline"
|
||||||
Vae = "vae"
|
Vae = "vae"
|
||||||
|
|
||||||
Lora = "lora"
|
Lora = "lora"
|
||||||
#ControlNet = "controlnet"
|
ControlNet = "controlnet" # used by model_probe
|
||||||
TextualInversion = "embedding"
|
TextualInversion = "embedding"
|
||||||
|
|
||||||
class SubModelType(str, Enum):
|
class SubModelType(str, Enum):
|
||||||
|
Loading…
Reference in New Issue
Block a user