mirror of https://github.com/invoke-ai/InvokeAI
re-implement model scanning when loading legacy checkpoint files (#3012)

- This PR turns on pickle scanning before a legacy checkpoint file is loaded from disk within the checkpoint_to_diffusers module.
- Also miscellaneous diagnostic message cleanup.
- See also #3011 for a similar patch to the 2.3 branch.
commit bc01a96f9d
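The core pattern this commit applies is scan-first, load-second: a .ckpt file is a pickle archive that can execute arbitrary code when unpickled, so it is checked with picklescan before being handed to torch.load, while .safetensors files need no scan. A minimal standalone sketch of the pattern (scan_file_path and its infected_files field are the same picklescan API used in the diff below; the helper function itself is illustrative, not part of the commit):

from pathlib import Path

import torch
from picklescan.scanner import scan_file_path
from safetensors.torch import load_file

def load_checkpoint_safely(checkpoint_path: str):
    # Illustrative helper, not part of the commit.
    if Path(checkpoint_path).suffix == ".ckpt":
        # picklescan walks the pickle opcodes without executing them
        result = scan_file_path(checkpoint_path)
        if result.infected_files != 0:
            raise RuntimeError(f"{checkpoint_path} failed the pickle scan")
        return torch.load(checkpoint_path)
    # safetensors files contain no executable code; load directly
    return load_file(checkpoint_path)

The diff follows; the first hunks are in the checkpoint-to-diffusers conversion module.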
@@ -372,9 +372,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
     unet_key = "model.diffusion_model."
     # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
     if sum(k.startswith("model_ema") for k in keys) > 100:
         print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
         if extract_ema:
             print(" | Extracting EMA weights (usually better for inference)")
             for key in keys:
                 if key.startswith("model.diffusion_model"):
                     flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
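As context for the hunk above: flat_ema_key drops the leading "model." segment and joins the remaining segments with no separator, which matches how LDM checkpoints name the EMA copy of each weight. A quick illustration (the key string is a made-up example):

key = "model.diffusion_model.input_blocks.0.0.weight"  # hypothetical key
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
# flat_ema_key == "model_ema.diffusion_modelinput_blocks00weight"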
@@ -393,7 +393,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
                     )
         else:
             print(
                 " | Extracting only the non-EMA weights (usually better for fine-tuning)"
             )

     for key in keys:
@@ -1050,6 +1050,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     vae: AutoencoderKL = None,
     precision: torch.dtype = torch.float32,
     return_generator_pipeline: bool = False,
+    scan_needed:bool=True,
 ) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
     """
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -1084,12 +1085,13 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     verbosity = dlogging.get_verbosity()
     dlogging.set_verbosity_error()

-    checkpoint = (
-        torch.load(checkpoint_path)
-        if Path(checkpoint_path).suffix == ".ckpt"
-        else load_file(checkpoint_path)
-    )
+    if Path(checkpoint_path).suffix == '.ckpt':
+        if scan_needed:
+            ModelManager.scan_model(checkpoint_path,checkpoint_path)
+        checkpoint = torch.load(checkpoint_path)
+    else:
+        checkpoint = load_file(checkpoint_path)

     cache_dir = global_cache_dir("hub")
     pipeline_class = (
         StableDiffusionGeneratorPipeline
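Two details of the new block above are easy to miss: scan_model takes a display name plus a file path (here checkpoint_path serves as both), and the scan_needed flag lets a caller that has already scanned the file skip a second pass. A hedged usage sketch (only checkpoint_path and scan_needed are taken from the diff; the filename is hypothetical and other arguments are omitted):

pipe = load_pipeline_from_original_stable_diffusion_ckpt(
    checkpoint_path="sd-v1-5.ckpt",  # hypothetical local file
    scan_needed=False,               # file was already scanned upstream
)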
@@ -1101,7 +1103,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     if "global_step" in checkpoint:
         global_step = checkpoint["global_step"]
     else:
         print(" | global_step key not found in model")
         global_step = None

     # sometimes there is a state_dict key and sometimes not
@@ -1214,7 +1216,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(

     # Convert the VAE model, or use the one passed
     if not vae:
         print(" | Using checkpoint model's original VAE")
         vae_config = create_vae_diffusers_config(
             original_config, image_size=image_size
         )
@@ -1225,7 +1227,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
     else:
         print(" | Using external VAE specified in config")

     # Convert the text model.
     model_type = pipeline_type
The remaining hunks are in the model manager module (class ModelManager):

@@ -34,7 +34,7 @@ from picklescan.scanner import scan_file_path
 from invokeai.backend.globals import Globals, global_cache_dir

 from ..stable_diffusion import StableDiffusionGeneratorPipeline
-from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume
+from ..util import CUDA_DEVICE, ask_user, download_with_resume

 class SDLegacyType(Enum):
     V1 = 1
@@ -285,13 +285,13 @@ class ModelManager(object):
         self.stack.remove(model_name)
         if delete_files:
             if weights:
-                print(f"** deleting file {weights}")
+                print(f"** Deleting file {weights}")
                 Path(weights).unlink(missing_ok=True)
             elif path:
-                print(f"** deleting directory {path}")
+                print(f"** Deleting directory {path}")
                 rmtree(path, ignore_errors=True)
             elif repo_id:
-                print(f"** deleting the cached model directory for {repo_id}")
+                print(f"** Deleting the cached model directory for {repo_id}")
                 self._delete_model_from_cache(repo_id)

     def add_model(
@@ -382,9 +382,9 @@ class ModelManager(object):

         print(f">> Loading diffusers model from {name_or_path}")
         if using_fp16:
             print(" | Using faster float16 precision")
         else:
             print(" | Using more accurate float32 precision")

         # TODO: scan weights maybe?
         pipeline_args: dict[str, Any] = dict(
@@ -435,9 +435,9 @@ class ModelManager(object):
         # square images???
         width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
         height = width
-
-        print(f" | Default image dimensions = {width} x {height}")
+        print(f" | Default image dimensions = {width} x {height}")

+        self._add_embeddings_to_model(pipeline)
         return pipeline, width, height, model_hash

     def _load_ckpt_model(self, model_name, mconfig):
@@ -517,13 +517,14 @@ class ModelManager(object):
         if self._has_cuda():
             torch.cuda.empty_cache()

+    @classmethod
     def scan_model(self, model_name, checkpoint):
         """
         Apply picklescanner to the indicated checkpoint and issue a warning
         and option to exit if an infected file is identified.
         """
         # scan model
-        print(f">> Scanning Model: {model_name}")
+        print(f" | Scanning Model: {model_name}")
         scan_result = scan_file_path(checkpoint)
         if scan_result.infected_files != 0:
             if scan_result.infected_files == 1:
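The @classmethod decorator added above is what allows the conversion module in the first file to call scan_model without instantiating a ModelManager. Note that the first parameter is still named self; under @classmethod it receives the class itself, so the call matches the one in the earlier hunk:

# As used by load_pipeline_from_original_stable_diffusion_ckpt above;
# the filename is a hypothetical example.
ModelManager.scan_model("sd-v1-5.ckpt", "sd-v1-5.ckpt")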
@@ -546,7 +547,7 @@ class ModelManager(object):
             print("### Exiting InvokeAI")
             sys.exit()
         else:
-            print(">> Model scanned ok")
+            print(" | Model scanned ok")

     def import_diffuser_model(
         self,
@@ -665,7 +666,7 @@ class ModelManager(object):
         print(f">> Probing {thing} for import")

         if thing.startswith(("http:", "https:", "ftp:")):
             print(f" | {thing} appears to be a URL")
             model_path = self._resolve_path(
                 thing, "models/ldm/stable-diffusion-v1"
             )  # _resolve_path does a download if needed
@@ -673,15 +674,15 @@ class ModelManager(object):
         elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
             if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
                 print(
                     f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
                 )
                 return
             else:
                 print(f" | {thing} appears to be a checkpoint file on disk")
                 model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")

         elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
             print(f" | {thing} appears to be a diffusers file on disk")
             model_name = self.import_diffuser_model(
                 thing,
                 vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
@@ -692,13 +693,13 @@ class ModelManager(object):

         elif Path(thing).is_dir():
             if (Path(thing) / "model_index.json").exists():
                 print(f" | {thing} appears to be a diffusers model.")
                 model_name = self.import_diffuser_model(
                     thing, commit_to_conf=commit_to_conf
                 )
             else:
                 print(
                     f" |{thing} appears to be a directory. Will scan for models to import"
                 )
                 for m in list(Path(thing).rglob("*.ckpt")) + list(
                     Path(thing).rglob("*.safetensors")
@@ -710,7 +711,7 @@ class ModelManager(object):
             return model_name

         elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
             print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
             model_name = self.import_diffuser_model(
                 thing, commit_to_conf=commit_to_conf
             )
@@ -727,32 +728,33 @@ class ModelManager(object):
             return

         if model_path.stem in self.config: # already imported
             print(" | Already imported. Skipping")
             return model_path.stem

         # another round of heuristics to guess the correct config file.
-        checkpoint = (
-            torch.load(model_path)
-            if model_path.suffix == ".ckpt"
-            else safetensors.torch.load_file(model_path)
-        )
+        checkpoint = None
+        if model_path.suffix.endswith((".ckpt",".pt")):
+            self.scan_model(model_path,model_path)
+            checkpoint = torch.load(model_path)
+        else:
+            checkpoint = safetensors.torch.load_file(model_path)

         # additional probing needed if no config file provided
         if model_config_file is None:
             model_type = self.probe_model_type(checkpoint)
             if model_type == SDLegacyType.V1:
                 print(" | SD-v1 model detected")
                 model_config_file = Path(
                     Globals.root, "configs/stable-diffusion/v1-inference.yaml"
                 )
             elif model_type == SDLegacyType.V1_INPAINT:
                 print(" | SD-v1 inpainting model detected")
                 model_config_file = Path(
                     Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
                 )
             elif model_type == SDLegacyType.V2_v:
                 print(
                     " | SD-v2-v model detected; model will be converted to diffusers format"
                 )
                 model_config_file = Path(
                     Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
@@ -760,7 +762,7 @@ class ModelManager(object):
                 convert = True
             elif model_type == SDLegacyType.V2_e:
                 print(
                     " | SD-v2-e model detected; model will be converted to diffusers format"
                 )
                 model_config_file = Path(
                     Globals.root, "configs/stable-diffusion/v2-inference.yaml"
@@ -788,6 +790,7 @@ class ModelManager(object):
             model_description=description,
             original_config_file=model_config_file,
             commit_to_conf=commit_to_conf,
+            scan_needed=False,
         )
         return model_name

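Passing scan_needed=False here avoids a double scan: the import path above has just run self.scan_model on the checkpoint before probing it, so the downstream conversion, which would otherwise scan again inside load_pipeline_from_original_stable_diffusion_ckpt, is told the file has already been checked.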
@@ -800,6 +803,7 @@ class ModelManager(object):
         vae=None,
         original_config_file: Path = None,
         commit_to_conf: Path = None,
+        scan_needed: bool=True,
     ) -> str:
         """
         Convert a legacy ckpt weights file to diffuser model and import
@@ -834,11 +838,12 @@ class ModelManager(object):
                 extract_ema=True,
                 original_config_file=original_config_file,
                 vae=vae_model,
+                scan_needed=scan_needed,
             )
             print(
                 f" | Success. Optimized model is now located at {str(diffusers_path)}"
             )
             print(f" | Writing new config file entry for {model_name}")
             new_config = dict(
                 path=str(diffusers_path),
                 description=model_description,
@@ -849,7 +854,7 @@ class ModelManager(object):
             self.add_model(model_name, new_config, True)
             if commit_to_conf:
                 self.commit(commit_to_conf)
-            print(">> Conversion succeeded")
+            print(" | Conversion succeeded")
         except Exception as e:
             print(f"** Conversion failed: {str(e)}")
             print(
@@ -1105,7 +1110,7 @@ class ModelManager(object):
             with open(hashpath) as f:
                 hash = f.read()
             return hash
         print(" | Calculating sha256 hash of model files")
         tic = time.time()
         sha = hashlib.sha256()
         count = 0
@@ -1117,7 +1122,7 @@ class ModelManager(object):
                 sha.update(chunk)
         hash = sha.hexdigest()
         toc = time.time()
         print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
         with open(hashpath, "w") as f:
             f.write(hash)
         return hash
@@ -1162,12 +1167,12 @@ class ModelManager(object):
             local_files_only=not Globals.internet_available,
         )

         print(f" | Loading diffusers VAE from {name_or_path}")
         if using_fp16:
             vae_args.update(torch_dtype=torch.float16)
             fp_args_list = [{"revision": "fp16"}, {}]
         else:
             print(" | Using more accurate float32 precision")
             fp_args_list = [{}]

         vae = None
@@ -1208,7 +1213,7 @@ class ModelManager(object):
             hashes_to_delete.add(revision.commit_hash)
         strategy = cache_info.delete_revisions(*hashes_to_delete)
         print(
-            f"** deletion of this model is expected to free {strategy.expected_freed_size_str}"
+            f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
         )
         strategy.execute()