mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
further changes for ruff
This commit is contained in:
@ -14,15 +14,16 @@ from .config import (
|
||||
from .probe import ModelProbe
|
||||
from .search import ModelSearch
|
||||
|
||||
__all__ = ['ModelProbe', 'ModelSearch',
|
||||
'InvalidModelConfigException',
|
||||
'ModelConfigFactory',
|
||||
'BaseModelType',
|
||||
'ModelType',
|
||||
'SubModelType',
|
||||
'ModelVariantType',
|
||||
'ModelFormat',
|
||||
'SchedulerPredictionType',
|
||||
'AnyModelConfig',
|
||||
]
|
||||
|
||||
__all__ = [
|
||||
"ModelProbe",
|
||||
"ModelSearch",
|
||||
"InvalidModelConfigException",
|
||||
"ModelConfigFactory",
|
||||
"BaseModelType",
|
||||
"ModelType",
|
||||
"SubModelType",
|
||||
"ModelVariantType",
|
||||
"ModelFormat",
|
||||
"SchedulerPredictionType",
|
||||
"AnyModelConfig",
|
||||
]
|
||||
|
@ -49,6 +49,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ProbeBase(object):
|
||||
"""Base class for probes."""
|
||||
|
||||
@ -71,6 +72,7 @@ class ProbeBase(object):
|
||||
"""Get model scheduler prediction type."""
|
||||
return None
|
||||
|
||||
|
||||
class ModelProbe(object):
|
||||
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
|
||||
"diffusers": {},
|
||||
@ -100,9 +102,9 @@ class ModelProbe(object):
|
||||
|
||||
@classmethod
|
||||
def heuristic_probe(
|
||||
cls,
|
||||
model_path: Path,
|
||||
fields: Optional[Dict[str, Any]] = None,
|
||||
cls,
|
||||
model_path: Path,
|
||||
fields: Optional[Dict[str, Any]] = None,
|
||||
) -> AnyModelConfig:
|
||||
return cls.probe(model_path, fields)
|
||||
|
||||
@ -138,29 +140,38 @@ class ModelProbe(object):
|
||||
hash = FastModelHash.hash(model_path)
|
||||
probe = probe_class(model_path)
|
||||
|
||||
fields['path'] = model_path.as_posix()
|
||||
fields['type'] = fields.get('type') or model_type
|
||||
fields['base'] = fields.get('base') or probe.get_base_type()
|
||||
fields['variant'] = fields.get('variant') or probe.get_variant_type()
|
||||
fields['prediction_type'] = fields.get('prediction_type') or probe.get_scheduler_prediction_type()
|
||||
fields['name'] = fields.get('name') or cls.get_model_name(model_path)
|
||||
fields['description'] = fields.get('description') or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||
fields['format'] = fields.get('format') or probe.get_format()
|
||||
fields['original_hash'] = fields.get('original_hash') or hash
|
||||
fields['current_hash'] = fields.get('current_hash') or hash
|
||||
fields["path"] = model_path.as_posix()
|
||||
fields["type"] = fields.get("type") or model_type
|
||||
fields["base"] = fields.get("base") or probe.get_base_type()
|
||||
fields["variant"] = fields.get("variant") or probe.get_variant_type()
|
||||
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
|
||||
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
||||
fields["description"] = (
|
||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["original_hash"] = fields.get("original_hash") or hash
|
||||
fields["current_hash"] = fields.get("current_hash") or hash
|
||||
|
||||
# additional fields needed for main and controlnet models
|
||||
if fields['type'] in [ModelType.Main, ModelType.ControlNet] and fields['format'] == ModelFormat.Checkpoint:
|
||||
fields['config'] = cls._get_checkpoint_config_path(model_path,
|
||||
model_type=fields['type'],
|
||||
base_type=fields['base'],
|
||||
variant_type=fields['variant'],
|
||||
prediction_type=fields['prediction_type']).as_posix()
|
||||
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
|
||||
fields["config"] = cls._get_checkpoint_config_path(
|
||||
model_path,
|
||||
model_type=fields["type"],
|
||||
base_type=fields["base"],
|
||||
variant_type=fields["variant"],
|
||||
prediction_type=fields["prediction_type"],
|
||||
).as_posix()
|
||||
|
||||
# additional fields needed for main non-checkpoint models
|
||||
elif fields['type'] == ModelType.Main and fields['format'] in [ModelFormat.Onnx, ModelFormat.Olive, ModelFormat.Diffusers]:
|
||||
fields['upcast_attention'] = fields.get('upcast_attention') or (
|
||||
fields['base'] == BaseModelType.StableDiffusion2 and fields['prediction_type'] == SchedulerPredictionType.VPrediction
|
||||
elif fields["type"] == ModelType.Main and fields["format"] in [
|
||||
ModelFormat.Onnx,
|
||||
ModelFormat.Olive,
|
||||
ModelFormat.Diffusers,
|
||||
]:
|
||||
fields["upcast_attention"] = fields.get("upcast_attention") or (
|
||||
fields["base"] == BaseModelType.StableDiffusion2
|
||||
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
|
||||
model_info = ModelConfigFactory.make_config(fields)
|
||||
@ -168,7 +179,7 @@ class ModelProbe(object):
|
||||
|
||||
@classmethod
|
||||
def get_model_name(cls, model_path: Path) -> str:
|
||||
if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}:
|
||||
if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
|
||||
return model_path.stem
|
||||
else:
|
||||
return model_path.name
|
||||
@ -247,13 +258,14 @@ class ModelProbe(object):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_checkpoint_config_path(cls,
|
||||
model_path: Path,
|
||||
model_type: ModelType,
|
||||
base_type: BaseModelType,
|
||||
variant_type: ModelVariantType,
|
||||
prediction_type: SchedulerPredictionType) -> Path:
|
||||
|
||||
def _get_checkpoint_config_path(
|
||||
cls,
|
||||
model_path: Path,
|
||||
model_type: ModelType,
|
||||
base_type: BaseModelType,
|
||||
variant_type: ModelVariantType,
|
||||
prediction_type: SchedulerPredictionType,
|
||||
) -> Path:
|
||||
# look for a YAML file adjacent to the model file first
|
||||
possible_conf = model_path.with_suffix(".yaml")
|
||||
if possible_conf.exists():
|
||||
@ -264,9 +276,13 @@ class ModelProbe(object):
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
elif model_type == ModelType.ControlNet:
|
||||
config_file = "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
|
||||
config_file = (
|
||||
"../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
|
||||
)
|
||||
else:
|
||||
raise InvalidModelConfigException(f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}")
|
||||
raise InvalidModelConfigException(
|
||||
f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
|
||||
)
|
||||
assert isinstance(config_file, str)
|
||||
return Path(config_file)
|
||||
|
||||
@ -297,6 +313,7 @@ class ModelProbe(object):
|
||||
# Checkpoint probing
|
||||
# ##################################################3
|
||||
|
||||
|
||||
class CheckpointProbeBase(ProbeBase):
|
||||
def __init__(self, model_path: Path):
|
||||
super().__init__(model_path)
|
||||
@ -446,7 +463,6 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
# classes for probing folders
|
||||
#######################################################
|
||||
class FolderProbeBase(ProbeBase):
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
return ModelVariantType.Normal
|
||||
|
||||
@ -537,7 +553,9 @@ class TextualInversionFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
path = self.model_path / "learned_embeds.bin"
|
||||
if not path.exists():
|
||||
raise InvalidModelConfigException(f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file")
|
||||
raise InvalidModelConfigException(
|
||||
f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file"
|
||||
)
|
||||
return TextualInversionCheckpointProbe(path).get_base_type()
|
||||
|
||||
|
||||
@ -608,7 +626,9 @@ class IPAdapterFolderProbe(FolderProbeBase):
|
||||
elif cross_attention_dim == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelConfigException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
|
||||
raise InvalidModelConfigException(
|
||||
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
|
@ -165,14 +165,14 @@ class ModelSearch(ModelSearchBase):
|
||||
self.scanned_dirs.add(path)
|
||||
continue
|
||||
if any(
|
||||
(path / x).exists()
|
||||
for x in [
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
]
|
||||
(path / x).exists()
|
||||
for x in [
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
]
|
||||
):
|
||||
self.scanned_dirs.add(path)
|
||||
try:
|
||||
|
Reference in New Issue
Block a user