mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
implemented Stalker's suggested improvements
This commit is contained in:
parent
d5f742620f
commit
c3c4a71173
@ -478,7 +478,7 @@ def _parse_legacy_initfile(root: Path, initfile: Path)->ModelPaths:
|
|||||||
models = root / 'models',
|
models = root / 'models',
|
||||||
embeddings = root / str(opt.embedding_path).strip('"'),
|
embeddings = root / str(opt.embedding_path).strip('"'),
|
||||||
loras = root / str(opt.lora_path).strip('"'),
|
loras = root / str(opt.lora_path).strip('"'),
|
||||||
controlnets = None
|
controlnets = root / 'controlnets',
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_legacy_yamlfile(root: Path, initfile: Path)->ModelPaths:
|
def _parse_legacy_yamlfile(root: Path, initfile: Path)->ModelPaths:
|
||||||
|
@ -9,7 +9,11 @@ from pathlib import Path
|
|||||||
from typing import Callable, Literal, Union, Dict
|
from typing import Callable, Literal, Union, Dict
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings
|
from .models import (
|
||||||
|
BaseModelType, ModelType, ModelVariantType,
|
||||||
|
SchedulerPredictionType, SilenceWarnings,
|
||||||
|
)
|
||||||
|
from .models.base import read_checkpoint_meta
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelProbeInfo(object):
|
class ModelProbeInfo(object):
|
||||||
@ -105,29 +109,34 @@ class ModelProbe(object):
|
|||||||
return model_info
|
return model_info
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict)->ModelType:
|
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
|
||||||
if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors'):
|
if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors','.pth'):
|
||||||
return None
|
return None
|
||||||
if model_path.name=='learned_embeds.bin':
|
|
||||||
|
if model_path.name == "learned_embeds.bin":
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
checkpoint = checkpoint or cls._scan_and_load_checkpoint(model_path)
|
|
||||||
state_dict = checkpoint.get("state_dict") or checkpoint
|
checkpoint = checkpoint or read_checkpoint_meta(model_path, scan=True)
|
||||||
|
checkpoint = checkpoint.get("state_dict", checkpoint)
|
||||||
|
|
||||||
|
for key in checkpoint.keys():
|
||||||
|
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||||
|
return ModelType.Main
|
||||||
|
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||||
|
return ModelType.Vae
|
||||||
|
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||||
|
return ModelType.Lora
|
||||||
|
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||||
|
return ModelType.ControlNet
|
||||||
|
elif key in {"emb_params", "string_to_param"}:
|
||||||
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
|
else:
|
||||||
|
# diffusers-ti
|
||||||
|
if len(checkpoint) < 10 and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
|
||||||
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
if len(checkpoint) < 10 and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
|
raise ValueError("Unable to determine model type")
|
||||||
return ModelType.TextualInversion
|
|
||||||
if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
|
|
||||||
return ModelType.Main
|
|
||||||
if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
|
|
||||||
return ModelType.Vae
|
|
||||||
if "string_to_token" in state_dict or "emb_params" in state_dict:
|
|
||||||
return ModelType.TextualInversion
|
|
||||||
if any([x.startswith("lora") for x in state_dict.keys()]):
|
|
||||||
return ModelType.Lora
|
|
||||||
if any([x.startswith("control_model") for x in state_dict.keys()]):
|
|
||||||
return ModelType.ControlNet
|
|
||||||
if any([x.startswith("input_blocks") for x in state_dict.keys()]):
|
|
||||||
return ModelType.ControlNet
|
|
||||||
return None # give up
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import typing
|
import typing
|
||||||
import inspect
|
import inspect
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from picklescan.scanner import scan_file_path
|
||||||
import torch
|
import torch
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from diffusers import DiffusionPipeline, ConfigMixin
|
from diffusers import DiffusionPipeline, ConfigMixin
|
||||||
@ -382,15 +385,18 @@ def _fast_safetensors_reader(path: str):
|
|||||||
|
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
||||||
def read_checkpoint_meta(path: str):
|
if str(path).endswith(".safetensors"):
|
||||||
if path.endswith(".safetensors"):
|
|
||||||
try:
|
try:
|
||||||
checkpoint = _fast_safetensors_reader(path)
|
checkpoint = _fast_safetensors_reader(path)
|
||||||
except:
|
except:
|
||||||
# TODO: create issue for support "meta"?
|
# TODO: create issue for support "meta"?
|
||||||
checkpoint = safetensors.torch.load_file(path, device="cpu")
|
checkpoint = safetensors.torch.load_file(path, device="cpu")
|
||||||
else:
|
else:
|
||||||
|
if scan:
|
||||||
|
scan_result = scan_file_path(checkpoint)
|
||||||
|
if scan_result.infected_files != 0:
|
||||||
|
raise Exception(f"The model file \"{path}\" is potentially infected by malware. Aborting import.")
|
||||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user