fix model probing for controlnet checkpoint legacy config files

Lincoln Stein
2023-11-25 15:53:22 -05:00
parent 19baea1883
commit ec510d34b5
5 changed files with 269 additions and 23 deletions


@@ -129,7 +129,6 @@ class ModelProbe(object):
            model_type = cls.get_model_type_from_folder(model_path)
        else:
            model_type = cls.get_model_type_from_checkpoint(model_path)
            print(f'DEBUG: model_type={model_type}')
        format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
        probe_class = cls.PROBES[format_type].get(model_type)
@@ -150,14 +149,19 @@ class ModelProbe(object):
        fields['original_hash'] = fields.get('original_hash') or hash
        fields['current_hash'] = fields.get('current_hash') or hash

        # additional work for main models
        if fields['type'] == ModelType.Main:
            if fields['format'] == ModelFormat.Checkpoint:
                fields['config'] = cls._get_config_path(model_path, fields['base'], fields['variant'], fields['prediction_type']).as_posix()
            elif fields['format'] in [ModelFormat.Onnx, ModelFormat.Olive, ModelFormat.Diffusers]:
                fields['upcast_attention'] = fields.get('upcast_attention') or (
                    fields['base'] == BaseModelType.StableDiffusion2 and fields['prediction_type'] == SchedulerPredictionType.VPrediction
                )

        # additional fields needed for main and controlnet models
        if fields['type'] in [ModelType.Main, ModelType.ControlNet] and fields['format'] == ModelFormat.Checkpoint:
            fields['config'] = cls._get_checkpoint_config_path(model_path,
                                                               model_type=fields['type'],
                                                               base_type=fields['base'],
                                                               variant_type=fields['variant'],
                                                               prediction_type=fields['prediction_type']).as_posix()

        # additional fields needed for main non-checkpoint models
        elif fields['type'] == ModelType.Main and fields['format'] in [ModelFormat.Onnx, ModelFormat.Olive, ModelFormat.Diffusers]:
            fields['upcast_attention'] = fields.get('upcast_attention') or (
                fields['base'] == BaseModelType.StableDiffusion2 and fields['prediction_type'] == SchedulerPredictionType.VPrediction
            )

        model_info = ModelConfigFactory.make_config(fields)
        return model_info
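
As a rough illustration of the change above: with the combined branch, a ControlNet checkpoint now receives a 'config' entry just like a main checkpoint. A minimal, hypothetical sketch follows (the field values and model path are invented; ModelType, ModelFormat, BaseModelType and _get_checkpoint_config_path are the names used in the diff, assumed importable from InvokeAI's model manager):

from pathlib import Path

# Hypothetical probe fields for a ControlNet checkpoint (values invented).
fields = {
    'type': ModelType.ControlNet,
    'format': ModelFormat.Checkpoint,
    'base': BaseModelType('sd-1'),
    'variant': None,
    'prediction_type': None,
}

# Old behaviour: the "additional work for main models" block applied only to
# ModelType.Main, so ControlNet checkpoints got no legacy config attached.
# New behaviour: the combined Main/ControlNet checkpoint branch resolves one.
if fields['type'] in [ModelType.Main, ModelType.ControlNet] and fields['format'] == ModelFormat.Checkpoint:
    fields['config'] = ModelProbe._get_checkpoint_config_path(
        Path('models/controlnet/control_sd15_canny.safetensors'),  # invented path
        model_type=fields['type'],
        base_type=fields['base'],
        variant_type=fields['variant'],
        prediction_type=fields['prediction_type'],
    ).as_posix()
# With no adjacent .yaml file, fields['config'] ends up as '../controlnet/cldm_v15.yaml'.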
@@ -243,18 +247,27 @@ class ModelProbe(object):
        )

    @classmethod
    def _get_config_path(cls,
                         model_path: Path,
                         base_type: BaseModelType,
                         variant: ModelVariantType,
                         prediction_type: SchedulerPredictionType) -> Path:
    def _get_checkpoint_config_path(cls,
                                    model_path: Path,
                                    model_type: ModelType,
                                    base_type: BaseModelType,
                                    variant_type: ModelVariantType,
                                    prediction_type: SchedulerPredictionType) -> Path:
        # look for a YAML file adjacent to the model file first
        possible_conf = model_path.with_suffix(".yaml")
        if possible_conf.exists():
            return possible_conf.absolute()

        config_file = LEGACY_CONFIGS[base_type][variant]
        if isinstance(config_file, dict):  # need another tier for sd-2.x models
            config_file = config_file[prediction_type]

        if model_type == ModelType.Main:
            config_file = LEGACY_CONFIGS[base_type][variant_type]
            if isinstance(config_file, dict):  # need another tier for sd-2.x models
                config_file = config_file[prediction_type]
        elif model_type == ModelType.ControlNet:
            config_file = "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
        else:
            raise InvalidModelConfigException(f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}")

        assert isinstance(config_file, str)
        return Path(config_file)
@classmethod
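
For reference, a short sketch of the resolution order the renamed helper follows. The model filenames are hypothetical; the LEGACY_CONFIGS lookup and the cldm config strings come from the diff above, and the enum members are assumed to be importable from InvokeAI's model manager:

from pathlib import Path

# Assumes ModelProbe, ModelType, BaseModelType, ModelVariantType and
# SchedulerPredictionType are imported from InvokeAI's model manager.

# 1. A YAML file adjacent to the checkpoint always wins, for any model type:
#    foo.safetensors + foo.yaml  ->  foo.yaml
# 2. Main checkpoints fall back to the LEGACY_CONFIGS table
#    (with a second tier keyed by prediction type for sd-2.x bases).
main_cfg = ModelProbe._get_checkpoint_config_path(
    Path('models/sd-2/main/some-v2-model.safetensors'),        # invented path
    model_type=ModelType.Main,
    base_type=BaseModelType('sd-2'),
    variant_type=ModelVariantType.Normal,                      # assumed enum member
    prediction_type=SchedulerPredictionType.VPrediction,
)

# 3. ControlNet checkpoints get a fixed cldm config chosen by base type:
cn_cfg = ModelProbe._get_checkpoint_config_path(
    Path('models/controlnet/control_sd15_canny.safetensors'),  # invented path
    model_type=ModelType.ControlNet,
    base_type=BaseModelType('sd-1'),
    variant_type=None,
    prediction_type=None,
)
# cn_cfg == Path('../controlnet/cldm_v15.yaml')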