mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix model probing for controlnet checkpoint legacy config files
This commit is contained in:
@ -129,7 +129,6 @@ class ModelProbe(object):
|
||||
model_type = cls.get_model_type_from_folder(model_path)
|
||||
else:
|
||||
model_type = cls.get_model_type_from_checkpoint(model_path)
|
||||
print(f'DEBUG: model_type={model_type}')
|
||||
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
|
||||
|
||||
probe_class = cls.PROBES[format_type].get(model_type)
|
||||
@ -150,14 +149,19 @@ class ModelProbe(object):
|
||||
fields['original_hash'] = fields.get('original_hash') or hash
|
||||
fields['current_hash'] = fields.get('current_hash') or hash
|
||||
|
||||
# additional work for main models
|
||||
if fields['type'] == ModelType.Main:
|
||||
if fields['format'] == ModelFormat.Checkpoint:
|
||||
fields['config'] = cls._get_config_path(model_path, fields['base'], fields['variant'], fields['prediction_type']).as_posix()
|
||||
elif fields['format'] in [ModelFormat.Onnx, ModelFormat.Olive, ModelFormat.Diffusers]:
|
||||
fields['upcast_attention'] = fields.get('upcast_attention') or (
|
||||
fields['base'] == BaseModelType.StableDiffusion2 and fields['prediction_type'] == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
# additional fields needed for main and controlnet models
|
||||
if fields['type'] in [ModelType.Main, ModelType.ControlNet] and fields['format'] == ModelFormat.Checkpoint:
|
||||
fields['config'] = cls._get_checkpoint_config_path(model_path,
|
||||
model_type=fields['type'],
|
||||
base_type=fields['base'],
|
||||
variant_type=fields['variant'],
|
||||
prediction_type=fields['prediction_type']).as_posix()
|
||||
|
||||
# additional fields needed for main non-checkpoint models
|
||||
elif fields['type'] == ModelType.Main and fields['format'] in [ModelFormat.Onnx, ModelFormat.Olive, ModelFormat.Diffusers]:
|
||||
fields['upcast_attention'] = fields.get('upcast_attention') or (
|
||||
fields['base'] == BaseModelType.StableDiffusion2 and fields['prediction_type'] == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
|
||||
model_info = ModelConfigFactory.make_config(fields)
|
||||
return model_info
|
||||
@ -243,18 +247,27 @@ class ModelProbe(object):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_config_path(cls,
|
||||
model_path: Path,
|
||||
base_type: BaseModelType,
|
||||
variant: ModelVariantType,
|
||||
prediction_type: SchedulerPredictionType) -> Path:
|
||||
def _get_checkpoint_config_path(cls,
|
||||
model_path: Path,
|
||||
model_type: ModelType,
|
||||
base_type: BaseModelType,
|
||||
variant_type: ModelVariantType,
|
||||
prediction_type: SchedulerPredictionType) -> Path:
|
||||
|
||||
# look for a YAML file adjacent to the model file first
|
||||
possible_conf = model_path.with_suffix(".yaml")
|
||||
if possible_conf.exists():
|
||||
return possible_conf.absolute()
|
||||
config_file = LEGACY_CONFIGS[base_type][variant]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
|
||||
if model_type == ModelType.Main:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
elif model_type == ModelType.ControlNet:
|
||||
config_file = "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
|
||||
else:
|
||||
raise InvalidModelConfigException(f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}")
|
||||
assert isinstance(config_file, str)
|
||||
return Path(config_file)
|
||||
|
||||
@classmethod
|
||||
|
Reference in New Issue
Block a user