further changes for ruff

Lincoln Stein
2023-11-26 17:13:31 -05:00
parent 8f4f4d48d5
commit 8ef596eac7
15 changed files with 245 additions and 212 deletions

@@ -14,15 +14,16 @@ from .config import (
from .probe import ModelProbe
from .search import ModelSearch
__all__ = ['ModelProbe', 'ModelSearch',
'InvalidModelConfigException',
'ModelConfigFactory',
'BaseModelType',
'ModelType',
'SubModelType',
'ModelVariantType',
'ModelFormat',
'SchedulerPredictionType',
'AnyModelConfig',
]
__all__ = [
"ModelProbe",
"ModelSearch",
"InvalidModelConfigException",
"ModelConfigFactory",
"BaseModelType",
"ModelType",
"SubModelType",
"ModelVariantType",
"ModelFormat",
"SchedulerPredictionType",
"AnyModelConfig",
]
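
For context, __all__ controls which names a star import re-exports from a package or module. A minimal, self-contained sketch of that behavior follows; the module and function names are hypothetical and not part of this commit:

# greetings.py -- hypothetical module
__all__ = [
    "hello",
]

def hello() -> str:
    return "hi"

def goodbye() -> str:
    return "bye"

# consumer code:
#   from greetings import *
# brings hello() into scope but not goodbye(), because only the names listed
# in __all__ are exported by a star import.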

@@ -49,6 +49,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
},
}
class ProbeBase(object):
"""Base class for probes."""
@@ -71,6 +72,7 @@ class ProbeBase(object):
"""Get model scheduler prediction type."""
return None
class ModelProbe(object):
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
"diffusers": {},
@@ -100,9 +102,9 @@ class ModelProbe(object):
@classmethod
def heuristic_probe(
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
) -> AnyModelConfig:
return cls.probe(model_path, fields)
@@ -138,29 +140,38 @@ class ModelProbe(object):
hash = FastModelHash.hash(model_path)
probe = probe_class(model_path)
fields['path'] = model_path.as_posix()
fields['type'] = fields.get('type') or model_type
fields['base'] = fields.get('base') or probe.get_base_type()
fields['variant'] = fields.get('variant') or probe.get_variant_type()
fields['prediction_type'] = fields.get('prediction_type') or probe.get_scheduler_prediction_type()
fields['name'] = fields.get('name') or cls.get_model_name(model_path)
fields['description'] = fields.get('description') or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
fields['format'] = fields.get('format') or probe.get_format()
fields['original_hash'] = fields.get('original_hash') or hash
fields['current_hash'] = fields.get('current_hash') or hash
fields["path"] = model_path.as_posix()
fields["type"] = fields.get("type") or model_type
fields["base"] = fields.get("base") or probe.get_base_type()
fields["variant"] = fields.get("variant") or probe.get_variant_type()
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
fields["description"] = (
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["original_hash"] = fields.get("original_hash") or hash
fields["current_hash"] = fields.get("current_hash") or hash
# additional fields needed for main and controlnet models
if fields['type'] in [ModelType.Main, ModelType.ControlNet] and fields['format'] == ModelFormat.Checkpoint:
fields['config'] = cls._get_checkpoint_config_path(model_path,
model_type=fields['type'],
base_type=fields['base'],
variant_type=fields['variant'],
prediction_type=fields['prediction_type']).as_posix()
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
fields["config"] = cls._get_checkpoint_config_path(
model_path,
model_type=fields["type"],
base_type=fields["base"],
variant_type=fields["variant"],
prediction_type=fields["prediction_type"],
).as_posix()
# additional fields needed for main non-checkpoint models
elif fields['type'] == ModelType.Main and fields['format'] in [ModelFormat.Onnx, ModelFormat.Olive, ModelFormat.Diffusers]:
fields['upcast_attention'] = fields.get('upcast_attention') or (
fields['base'] == BaseModelType.StableDiffusion2 and fields['prediction_type'] == SchedulerPredictionType.VPrediction
elif fields["type"] == ModelType.Main and fields["format"] in [
ModelFormat.Onnx,
ModelFormat.Olive,
ModelFormat.Diffusers,
]:
fields["upcast_attention"] = fields.get("upcast_attention") or (
fields["base"] == BaseModelType.StableDiffusion2
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
)
model_info = ModelConfigFactory.make_config(fields)
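
A hypothetical usage sketch of the probe entry points shown above. The import path is an assumption based on the package __init__ earlier in this commit, and the model file and override field are illustrative only:

from pathlib import Path

# Assumed import path; the package __init__ shown earlier re-exports ModelProbe.
from invokeai.backend.model_manager import ModelProbe

def describe_model(model_file: str) -> None:
    # heuristic_probe() defers to probe(), which fills in every field the
    # caller did not supply (base, type, variant, format, hashes, ...).
    info = ModelProbe.heuristic_probe(
        Path(model_file),
        fields={"name": "my-model"},  # caller-supplied values take precedence over probed ones
    )
    # base/type/format are assumed to surface as attributes on the resulting config
    print(info.base, info.type, info.format)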
@@ -168,7 +179,7 @@ class ModelProbe(object):
@classmethod
def get_model_name(cls, model_path: Path) -> str:
if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}:
if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
return model_path.stem
else:
return model_path.name
@@ -247,13 +258,14 @@ class ModelProbe(object):
)
@classmethod
def _get_checkpoint_config_path(cls,
model_path: Path,
model_type: ModelType,
base_type: BaseModelType,
variant_type: ModelVariantType,
prediction_type: SchedulerPredictionType) -> Path:
def _get_checkpoint_config_path(
cls,
model_path: Path,
model_type: ModelType,
base_type: BaseModelType,
variant_type: ModelVariantType,
prediction_type: SchedulerPredictionType,
) -> Path:
# look for a YAML file adjacent to the model file first
possible_conf = model_path.with_suffix(".yaml")
if possible_conf.exists():
@@ -264,9 +276,13 @@ class ModelProbe(object):
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
elif model_type == ModelType.ControlNet:
config_file = "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
config_file = (
"../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
)
else:
raise InvalidModelConfigException(f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}")
raise InvalidModelConfigException(
f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
)
assert isinstance(config_file, str)
return Path(config_file)
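
The lookup above indexes the LEGACY_CONFIGS table whose type appears in an earlier hunk: base type, then variant, with an extra tier keyed by scheduler prediction type for sd-2.x entries. A simplified sketch of that shape; the YAML file names are illustrative placeholders, not values taken from this commit, and the types and enum members are those imported at the top of the probe module:

LEGACY_CONFIGS = {
    BaseModelType("sd-1"): {
        ModelVariantType.Normal: "v1-example.yaml",  # plain string: no prediction-type tier
    },
    BaseModelType.StableDiffusion2: {
        ModelVariantType.Normal: {  # nested dict: resolved with prediction_type, as in the code above
            SchedulerPredictionType.VPrediction: "v2-vprediction-example.yaml",
        },
    },
}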
@@ -297,6 +313,7 @@ class ModelProbe(object):
# Checkpoint probing
# ##################################################3
class CheckpointProbeBase(ProbeBase):
def __init__(self, model_path: Path):
super().__init__(model_path)
@@ -446,7 +463,6 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal
@@ -537,7 +553,9 @@ class TextualInversionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
path = self.model_path / "learned_embeds.bin"
if not path.exists():
raise InvalidModelConfigException(f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file")
raise InvalidModelConfigException(
f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file"
)
return TextualInversionCheckpointProbe(path).get_base_type()
@@ -608,7 +626,9 @@ class IPAdapterFolderProbe(FolderProbeBase):
elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
raise InvalidModelConfigException(
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
)
class CLIPVisionFolderProbe(FolderProbeBase):
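
For orientation, the folder probes above follow one pattern: subclass FolderProbeBase, read a marker file under self.model_path, and raise InvalidModelConfigException when the layout is not recognized. A hypothetical subclass sketched in that style; the class name, marker file, and returned base type are illustrative only:

class ExampleFolderProbe(FolderProbeBase):
    """Hypothetical probe; not part of this commit."""

    def get_base_type(self) -> BaseModelType:
        config = self.model_path / "config.json"
        if not config.exists():
            raise InvalidModelConfigException(
                f"{self.model_path.as_posix()} does not contain expected 'config.json' file"
            )
        return BaseModelType("sd-1")  # a real probe would derive this from the config contents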

@@ -165,14 +165,14 @@ class ModelSearch(ModelSearchBase):
self.scanned_dirs.add(path)
continue
if any(
(path / x).exists()
for x in [
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
]
(path / x).exists()
for x in [
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
]
):
self.scanned_dirs.add(path)
try:
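
The reformatted condition above checks a directory for a handful of marker files (diffusers configs, embeddings, LoRA weights, and so on) to decide how to treat it during the scan. The same idiom as a standalone sketch; the helper name is hypothetical, while the marker file names are copied from the diff:

from pathlib import Path

_MODEL_MARKERS = [
    "config.json",
    "model_index.json",
    "learned_embeds.bin",
    "pytorch_lora_weights.bin",
    "image_encoder.txt",
]

def looks_like_model_dir(path: Path) -> bool:
    # True when any known marker file exists directly inside `path`
    return any((path / name).exists() for name in _MODEL_MARKERS)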