record model_variant in t2i and clip_vision configs (#5989)

- Move base of t2i and clip_vision config models to DiffusersBase, which contains
  a field to record the model variant (e.g. "fp16")
- This restore the ability to load fp16 t2i and clip_vision models
- Also add defensive coding to load the vanilla model when the fp16 model
  has been replaced (or more likely, user's preferences changed since installation)

Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
Lincoln Stein 2024-03-19 16:14:12 -04:00 committed by GitHub
parent 3f61c51c3a
commit c87497fd54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 9 deletions

View File

@ -331,7 +331,7 @@ class IPAdapterConfig(ModelConfigBase):
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}") return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
class CLIPVisionDiffusersConfig(ModelConfigBase): class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision.""" """Model config for CLIPVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
@ -342,7 +342,7 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}") return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
class T2IAdapterConfig(ModelConfigBase, ControlAdapterConfigBase): class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
"""Model config for T2I.""" """Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter

View File

@ -36,7 +36,15 @@ class GenericDiffusersLoader(ModelLoader):
if submodel_type is not None: if submodel_type is not None:
raise Exception(f"There are no submodels in models of type {model_class}") raise Exception(f"There are no submodels in models of type {model_class}")
variant = model_variant.value if model_variant else None variant = model_variant.value if model_variant else None
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore try:
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)
except OSError as e:
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)
else:
raise e
return result return result
# TO DO: Add exception handling # TO DO: Add exception handling
@ -63,7 +71,7 @@ class GenericDiffusersLoader(ModelLoader):
assert class_name is not None assert class_name is not None
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0]) result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
if not class_name: if not class_name:
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json") raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
except KeyError as e: except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
assert result is not None assert result is not None

View File

@ -44,11 +44,20 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
load_class = self.get_hf_load_class(model_path, submodel_type) load_class = self.get_hf_load_class(model_path, submodel_type)
variant = model_variant.value if model_variant else None variant = model_variant.value if model_variant else None
model_path = model_path / submodel_type.value model_path = model_path / submodel_type.value
try:
result: AnyModel = load_class.from_pretrained( result: AnyModel = load_class.from_pretrained(
model_path, model_path,
torch_dtype=self._torch_dtype, torch_dtype=self._torch_dtype,
variant=variant, variant=variant,
) # type: ignore )
except OSError as e:
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)
else:
raise e
return result return result
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: