Add unit test

This commit is contained in:
Brandon Rising
2024-03-15 14:23:30 -04:00
committed by Brandon
parent d38262a7ea
commit f78ed3a952
4 changed files with 62 additions and 4 deletions

View File

@ -132,7 +132,8 @@ class ModelProbe(object):
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None
model_type = ModelType(fields['type']) if 'type' in fields else None
model_type = fields['type'] if 'type' in fields else None
model_type = ModelType(model_type) if isinstance(model_type, str) else model_type
if not model_type:
if format_type is ModelFormat.Diffusers:
model_type = cls.get_model_type_from_folder(model_path)
@ -157,7 +158,7 @@ class ModelProbe(object):
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
fields["description"] = (
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)