implemented Stalker's suggested improvements

This commit is contained in:
Lincoln Stein 2023-06-24 12:37:26 -04:00
parent d5f742620f
commit c3c4a71173
3 changed files with 40 additions and 25 deletions

View File

@@ -478,7 +478,7 @@ def _parse_legacy_initfile(root: Path, initfile: Path)->ModelPaths:
         models = root / 'models',
         embeddings = root / str(opt.embedding_path).strip('"'),
         loras = root / str(opt.lora_path).strip('"'),
-        controlnets = None
+        controlnets = root / 'controlnets',
     )

 def _parse_legacy_yamlfile(root: Path, initfile: Path)->ModelPaths:
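For orientation, a minimal sketch of the structure being filled in here, assuming ModelPaths is a plain dataclass of optional Path fields (its definition is not part of this diff, and the root path below is made up):

# Hypothetical sketch only -- ModelPaths and its field types are assumed, not taken from this diff.
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

@dataclass
class ModelPaths:
    models: Optional[Path] = None
    embeddings: Optional[Path] = None
    loras: Optional[Path] = None
    controlnets: Optional[Path] = None

root = Path("/opt/invokeai")  # made-up root directory
paths = ModelPaths(
    models=root / "models",
    embeddings=root / "embeddings",
    loras=root / "loras",
    controlnets=root / "controlnets",  # previously left as None by _parse_legacy_initfile
)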

View File

@@ -9,7 +9,11 @@ from pathlib import Path
 from typing import Callable, Literal, Union, Dict

 from picklescan.scanner import scan_file_path
-from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings
+from .models import (
+    BaseModelType, ModelType, ModelVariantType,
+    SchedulerPredictionType, SilenceWarnings,
+)
+from .models.base import read_checkpoint_meta

 @dataclass
 class ModelProbeInfo(object):
@@ -105,29 +109,34 @@ class ModelProbe(object):
         return model_info

     @classmethod
-    def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict)->ModelType:
-        if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors'):
+    def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
+        if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors','.pth'):
             return None
-        if model_path.name=='learned_embeds.bin':
+
+        if model_path.name == "learned_embeds.bin":
             return ModelType.TextualInversion
-        checkpoint = checkpoint or cls._scan_and_load_checkpoint(model_path)
-        state_dict = checkpoint.get("state_dict") or checkpoint
-        if len(checkpoint) < 10 and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
-            return ModelType.TextualInversion
-        if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
-            return ModelType.Main
-        if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
-            return ModelType.Vae
-        if "string_to_token" in state_dict or "emb_params" in state_dict:
-            return ModelType.TextualInversion
-        if any([x.startswith("lora") for x in state_dict.keys()]):
-            return ModelType.Lora
-        if any([x.startswith("control_model") for x in state_dict.keys()]):
-            return ModelType.ControlNet
-        if any([x.startswith("input_blocks") for x in state_dict.keys()]):
-            return ModelType.ControlNet
-        return None # give up
+
+        checkpoint = checkpoint or read_checkpoint_meta(model_path, scan=True)
+        checkpoint = checkpoint.get("state_dict", checkpoint)
+
+        for key in checkpoint.keys():
+            if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
+                return ModelType.Main
+            elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
+                return ModelType.Vae
+            elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
+                return ModelType.Lora
+            elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
+                return ModelType.ControlNet
+            elif key in {"emb_params", "string_to_param"}:
+                return ModelType.TextualInversion
+        else:
+            # diffusers-ti
+            if len(checkpoint) < 10 and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
+                return ModelType.TextualInversion
+
+        raise ValueError("Unable to determine model type")

     @classmethod
     def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
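To make the new probing logic concrete, here is a hedged, standalone sketch of the same key-prefix detection (guess_model_type and fake_lora_sd are made-up names for illustration; the real logic lives in ModelProbe.get_model_type_from_checkpoint above): known key prefixes map directly to a model type, and a small checkpoint of raw tensors falls through to the diffusers-style textual inversion case.

# Hypothetical standalone sketch mirroring the detection introduced in the hunk above.
import torch

def guess_model_type(checkpoint: dict) -> str:
    checkpoint = checkpoint.get("state_dict", checkpoint)
    for key in checkpoint.keys():
        if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
            return "main"
        elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
            return "vae"
        elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
            return "lora"
        elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
            return "controlnet"
        elif key in {"emb_params", "string_to_param"}:
            return "embedding"
    else:
        # no prefix matched: a handful of raw tensors looks like a diffusers-style textual inversion
        if len(checkpoint) < 10 and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
            return "embedding"
    raise ValueError("Unable to determine model type")

fake_lora_sd = {"lora_unet_down_blocks_0.alpha": torch.zeros(1)}
print(guess_model_type(fake_lora_sd))  # -> "lora"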

View File

@@ -1,9 +1,12 @@
+import json
 import os
 import sys
 import typing
 import inspect
 from enum import Enum
 from abc import ABCMeta, abstractmethod
+from pathlib import Path
+from picklescan.scanner import scan_file_path
 import torch
 import safetensors.torch
 from diffusers import DiffusionPipeline, ConfigMixin
@@ -382,15 +385,18 @@ def _fast_safetensors_reader(path: str):
     return checkpoint

-def read_checkpoint_meta(path: str):
-    if path.endswith(".safetensors"):
+def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
+    if str(path).endswith(".safetensors"):
         try:
             checkpoint = _fast_safetensors_reader(path)
         except:
             # TODO: create issue for support "meta"?
             checkpoint = safetensors.torch.load_file(path, device="cpu")
     else:
+        if scan:
+            scan_result = scan_file_path(path)
+            if scan_result.infected_files != 0:
+                raise Exception(f"The model file \"{path}\" is potentially infected by malware. Aborting import.")
         checkpoint = torch.load(path, map_location=torch.device("meta"))
     return checkpoint
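A short usage sketch of the updated helper, assuming read_checkpoint_meta from the module patched above is importable (the checkpoint path below is made up): with scan=True, non-safetensors files are scanned by picklescan before being loaded onto the meta device, and a positive scan aborts the load.

# Hypothetical usage -- the path is invented; read_checkpoint_meta is the function shown above.
from pathlib import Path

ckpt = Path("/opt/invokeai/models/some_model.ckpt")
try:
    meta = read_checkpoint_meta(ckpt, scan=True)  # picklescan runs because this is not a .safetensors file
    state_dict = meta.get("state_dict", meta)
    print(f"loaded metadata for {ckpt.name}: {len(state_dict)} keys")
except Exception as err:
    # raised when scan_file_path reports infected files
    print(f"refused to load {ckpt}: {err}")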