implemented Stalker's suggested improvements

This commit is contained in:
Lincoln Stein 2023-06-24 12:37:26 -04:00
parent d5f742620f
commit c3c4a71173
3 changed files with 40 additions and 25 deletions

View File

@ -478,7 +478,7 @@ def _parse_legacy_initfile(root: Path, initfile: Path)->ModelPaths:
models = root / 'models',
embeddings = root / str(opt.embedding_path).strip('"'),
loras = root / str(opt.lora_path).strip('"'),
controlnets = None
controlnets = root / 'controlnets',
)
def _parse_legacy_yamlfile(root: Path, initfile: Path)->ModelPaths:

View File

@ -9,7 +9,11 @@ from pathlib import Path
from typing import Callable, Literal, Union, Dict
from picklescan.scanner import scan_file_path
from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings
from .models import (
BaseModelType, ModelType, ModelVariantType,
SchedulerPredictionType, SilenceWarnings,
)
from .models.base import read_checkpoint_meta
@dataclass
class ModelProbeInfo(object):
@ -105,29 +109,34 @@ class ModelProbe(object):
return model_info
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict)->ModelType:
if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors'):
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors','.pth'):
return None
if model_path.name=='learned_embeds.bin':
if model_path.name == "learned_embeds.bin":
return ModelType.TextualInversion
checkpoint = checkpoint or cls._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
checkpoint = checkpoint or read_checkpoint_meta(model_path, scan=True)
checkpoint = checkpoint.get("state_dict", checkpoint)
for key in checkpoint.keys():
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
return ModelType.Vae
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.Lora
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion
else:
# diffusers-ti
if len(checkpoint) < 10 and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
return ModelType.TextualInversion
if len(checkpoint) < 10 and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
return ModelType.TextualInversion
if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
return ModelType.Main
if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
return ModelType.Vae
if "string_to_token" in state_dict or "emb_params" in state_dict:
return ModelType.TextualInversion
if any([x.startswith("lora") for x in state_dict.keys()]):
return ModelType.Lora
if any([x.startswith("control_model") for x in state_dict.keys()]):
return ModelType.ControlNet
if any([x.startswith("input_blocks") for x in state_dict.keys()]):
return ModelType.ControlNet
return None # give up
raise ValueError("Unable to determine model type")
@classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:

View File

@ -1,9 +1,12 @@
import json
import os
import sys
import typing
import inspect
from enum import Enum
from abc import ABCMeta, abstractmethod
from pathlib import Path
from picklescan.scanner import scan_file_path
import torch
import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin
@ -382,15 +385,18 @@ def _fast_safetensors_reader(path: str):
return checkpoint
def read_checkpoint_meta(path: str):
if path.endswith(".safetensors"):
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"):
try:
checkpoint = _fast_safetensors_reader(path)
except:
# TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu")
else:
if scan:
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception(f"The model file \"{path}\" is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint