Re-implement model scanning when loading legacy checkpoint files (#3012)

- This PR turns on pickle scanning before a legacy checkpoint file is
loaded from disk within the checkpoint_to_diffusers module.

- Also miscellaneous diagnostic message cleanup.

- See also #3011 for a similar patch to the 2.3 branch.
This commit is contained in:
Lincoln Stein 2023-03-24 08:57:07 -04:00 committed by GitHub
commit bc01a96f9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 48 deletions

View File

@ -1050,6 +1050,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
vae: AutoencoderKL = None, vae: AutoencoderKL = None,
precision: torch.dtype = torch.float32, precision: torch.dtype = torch.float32,
return_generator_pipeline: bool = False, return_generator_pipeline: bool = False,
scan_needed:bool=True,
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]: ) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
""" """
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@ -1084,12 +1085,13 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
verbosity = dlogging.get_verbosity() verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error() dlogging.set_verbosity_error()
checkpoint = ( if Path(checkpoint_path).suffix == '.ckpt':
torch.load(checkpoint_path) if scan_needed:
if Path(checkpoint_path).suffix == ".ckpt" ModelManager.scan_model(checkpoint_path,checkpoint_path)
else load_file(checkpoint_path) checkpoint = torch.load(checkpoint_path)
else:
checkpoint = load_file(checkpoint_path)
)
cache_dir = global_cache_dir("hub") cache_dir = global_cache_dir("hub")
pipeline_class = ( pipeline_class = (
StableDiffusionGeneratorPipeline StableDiffusionGeneratorPipeline

View File

@ -34,7 +34,7 @@ from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir from invokeai.backend.globals import Globals, global_cache_dir
from ..stable_diffusion import StableDiffusionGeneratorPipeline from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum): class SDLegacyType(Enum):
V1 = 1 V1 = 1
@ -285,13 +285,13 @@ class ModelManager(object):
self.stack.remove(model_name) self.stack.remove(model_name)
if delete_files: if delete_files:
if weights: if weights:
print(f"** deleting file {weights}") print(f"** Deleting file {weights}")
Path(weights).unlink(missing_ok=True) Path(weights).unlink(missing_ok=True)
elif path: elif path:
print(f"** deleting directory {path}") print(f"** Deleting directory {path}")
rmtree(path, ignore_errors=True) rmtree(path, ignore_errors=True)
elif repo_id: elif repo_id:
print(f"** deleting the cached model directory for {repo_id}") print(f"** Deleting the cached model directory for {repo_id}")
self._delete_model_from_cache(repo_id) self._delete_model_from_cache(repo_id)
def add_model( def add_model(
@ -435,9 +435,9 @@ class ModelManager(object):
# square images??? # square images???
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
height = width height = width
print(f" | Default image dimensions = {width} x {height}") print(f" | Default image dimensions = {width} x {height}")
self._add_embeddings_to_model(pipeline)
return pipeline, width, height, model_hash return pipeline, width, height, model_hash
def _load_ckpt_model(self, model_name, mconfig): def _load_ckpt_model(self, model_name, mconfig):
@ -517,13 +517,14 @@ class ModelManager(object):
if self._has_cuda(): if self._has_cuda():
torch.cuda.empty_cache() torch.cuda.empty_cache()
@classmethod
def scan_model(self, model_name, checkpoint): def scan_model(self, model_name, checkpoint):
""" """
Apply picklescanner to the indicated checkpoint and issue a warning Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified. and option to exit if an infected file is identified.
""" """
# scan model # scan model
print(f">> Scanning Model: {model_name}") print(f" | Scanning Model: {model_name}")
scan_result = scan_file_path(checkpoint) scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0: if scan_result.infected_files != 0:
if scan_result.infected_files == 1: if scan_result.infected_files == 1:
@ -546,7 +547,7 @@ class ModelManager(object):
print("### Exiting InvokeAI") print("### Exiting InvokeAI")
sys.exit() sys.exit()
else: else:
print(">> Model scanned ok") print(" | Model scanned ok")
def import_diffuser_model( def import_diffuser_model(
self, self,
@ -731,11 +732,12 @@ class ModelManager(object):
return model_path.stem return model_path.stem
# another round of heuristics to guess the correct config file. # another round of heuristics to guess the correct config file.
checkpoint = ( checkpoint = None
torch.load(model_path) if model_path.suffix.endswith((".ckpt",".pt")):
if model_path.suffix == ".ckpt" self.scan_model(model_path,model_path)
else safetensors.torch.load_file(model_path) checkpoint = torch.load(model_path)
) else:
checkpoint = safetensors.torch.load_file(model_path)
# additional probing needed if no config file provided # additional probing needed if no config file provided
if model_config_file is None: if model_config_file is None:
@ -788,6 +790,7 @@ class ModelManager(object):
model_description=description, model_description=description,
original_config_file=model_config_file, original_config_file=model_config_file,
commit_to_conf=commit_to_conf, commit_to_conf=commit_to_conf,
scan_needed=False,
) )
return model_name return model_name
@ -800,6 +803,7 @@ class ModelManager(object):
vae=None, vae=None,
original_config_file: Path = None, original_config_file: Path = None,
commit_to_conf: Path = None, commit_to_conf: Path = None,
scan_needed: bool=True,
) -> str: ) -> str:
""" """
Convert a legacy ckpt weights file to diffuser model and import Convert a legacy ckpt weights file to diffuser model and import
@ -834,6 +838,7 @@ class ModelManager(object):
extract_ema=True, extract_ema=True,
original_config_file=original_config_file, original_config_file=original_config_file,
vae=vae_model, vae=vae_model,
scan_needed=scan_needed,
) )
print( print(
f" | Success. Optimized model is now located at {str(diffusers_path)}" f" | Success. Optimized model is now located at {str(diffusers_path)}"
@ -849,7 +854,7 @@ class ModelManager(object):
self.add_model(model_name, new_config, True) self.add_model(model_name, new_config, True)
if commit_to_conf: if commit_to_conf:
self.commit(commit_to_conf) self.commit(commit_to_conf)
print(">> Conversion succeeded") print(" | Conversion succeeded")
except Exception as e: except Exception as e:
print(f"** Conversion failed: {str(e)}") print(f"** Conversion failed: {str(e)}")
print( print(
@ -1208,7 +1213,7 @@ class ModelManager(object):
hashes_to_delete.add(revision.commit_hash) hashes_to_delete.add(revision.commit_hash)
strategy = cache_info.delete_revisions(*hashes_to_delete) strategy = cache_info.delete_revisions(*hashes_to_delete)
print( print(
f"** deletion of this model is expected to free {strategy.expected_freed_size_str}" f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
) )
strategy.execute() strategy.execute()