mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove core safetensors->diffusers conversion models
- No longer install core conversion models. Use the HuggingFace cache to load them if and when needed. - Call directly into the diffusers library to perform conversions with only shallow wrappers around them to massage arguments, etc. - At root configuration time, do not create all the possible model subdirectories, but let them be created and populated at model install time. - Remove checks for missing core conversion files, since they are no longer installed.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@ -3,9 +3,6 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@ -37,27 +34,25 @@ class ControlNetLoader(GenericDiffusersLoader):
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = config.config_path
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = config.config_path
|
||||
|
||||
if model_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(model_path, map_location="cpu")
|
||||
|
||||
# sometimes weights are hidden under "state_dict", and sometimes not
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
convert_controlnet_to_diffusers(
|
||||
model_path,
|
||||
output_path,
|
||||
original_config_file=self._app_config.root_path / config_file,
|
||||
image_size=512,
|
||||
scan_needed=True,
|
||||
from_safetensors=model_path.suffix == ".safetensors",
|
||||
image_size = (
|
||||
512
|
||||
if config.base == BaseModelType.StableDiffusion1
|
||||
else 768
|
||||
if config.base == BaseModelType.StableDiffusion2
|
||||
else 1024
|
||||
)
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
with open(self._app_config.root_path / config_file, "r") as config_stream:
|
||||
convert_controlnet_to_diffusers(
|
||||
model_path,
|
||||
output_path,
|
||||
original_config_file=config_stream,
|
||||
image_size=image_size,
|
||||
precision=self._torch_dtype,
|
||||
from_safetensors=model_path.suffix == ".safetensors",
|
||||
)
|
||||
return output_path
|
||||
|
@ -4,9 +4,6 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@ -14,7 +11,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
|
||||
@ -68,27 +65,31 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
variant = config.variant
|
||||
base = config.base
|
||||
pipeline_class = (
|
||||
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
|
||||
)
|
||||
|
||||
config_file = config.config_path
|
||||
prediction_type = config.prediction_type.value
|
||||
upcast_attention = config.upcast_attention
|
||||
image_size = (
|
||||
1024
|
||||
if base == BaseModelType.StableDiffusionXL
|
||||
else 768
|
||||
if config.prediction_type == SchedulerPredictionType.VPrediction and base == BaseModelType.StableDiffusion2
|
||||
else 512
|
||||
)
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
convert_ckpt_to_diffusers(
|
||||
model_path,
|
||||
output_path,
|
||||
model_type=self.model_base_to_model_type[base],
|
||||
model_version=base,
|
||||
model_variant=variant,
|
||||
original_config_file=self._app_config.root_path / config_file,
|
||||
extract_ema=True,
|
||||
scan_needed=True,
|
||||
pipeline_class=pipeline_class,
|
||||
from_safetensors=model_path.suffix == ".safetensors",
|
||||
precision=self._torch_dtype,
|
||||
prediction_type=prediction_type,
|
||||
image_size=image_size,
|
||||
upcast_attention=upcast_attention,
|
||||
load_safety_checker=False,
|
||||
)
|
||||
return output_path
|
||||
|
@ -57,12 +57,12 @@ class VAELoader(GenericDiffusersLoader):
|
||||
|
||||
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
|
||||
assert isinstance(ckpt_config, DictConfig)
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
vae_model = convert_ldm_vae_to_diffusers(
|
||||
checkpoint=checkpoint,
|
||||
vae_config=ckpt_config,
|
||||
image_size=512,
|
||||
precision=self._torch_dtype,
|
||||
)
|
||||
vae_model.to(self._torch_dtype) # set precision appropriately
|
||||
vae_model.save_pretrained(output_path, safe_serialization=True)
|
||||
return output_path
|
||||
|
@ -319,7 +319,7 @@ class ModelProbe(object):
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
|
||||
with SilenceWarnings():
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
|
||||
cls._scan_model(model_path.name, model_path)
|
||||
model = torch.load(model_path)
|
||||
assert isinstance(model, dict)
|
||||
|
Reference in New Issue
Block a user