re-implement model scanning when loading legacy checkpoint files (#3012)
- This PR turns on pickle scanning before a legacy checkpoint file is loaded from disk within the checkpoint_to_diffusers module.
- Also miscellaneous diagnostic message cleanup.
- See also #3011 for a similar patch to the 2.3 branch.
Commit: bc01a96f9d
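Taken together, the hunks below gate every pickle-based (.ckpt) load behind a picklescan pass before torch.load is allowed to run. A minimal standalone sketch of the resulting control flow, using only calls that appear in the diff (the function name and the RuntimeError are illustrative; the commit itself prompts the user and may sys.exit() instead):

from pathlib import Path

import torch
from picklescan.scanner import scan_file_path
from safetensors.torch import load_file

def load_checkpoint_safely(checkpoint_path: str, scan_needed: bool = True):
    # Illustrative standalone version of the scan-gated load in this commit.
    if Path(checkpoint_path).suffix == ".ckpt":
        if scan_needed:
            result = scan_file_path(checkpoint_path)
            if result.infected_files != 0:
                raise RuntimeError(f"pickle scan flagged {checkpoint_path}")
        return torch.load(checkpoint_path)
    # .safetensors files carry no pickled code, so no scan is needed.
    return load_file(checkpoint_path)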
In the checkpoint_to_diffusers module, load_pipeline_from_original_stable_diffusion_ckpt() gains a scan_needed parameter:

@@ -1050,6 +1050,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     vae: AutoencoderKL = None,
     precision: torch.dtype = torch.float32,
     return_generator_pipeline: bool = False,
+    scan_needed:bool=True,
 ) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
     """
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
The flag gates a picklescan pass before torch.load; .safetensors files are loaded without scanning:

@@ -1084,12 +1085,13 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     verbosity = dlogging.get_verbosity()
     dlogging.set_verbosity_error()

-    checkpoint = (
-        torch.load(checkpoint_path)
-        if Path(checkpoint_path).suffix == ".ckpt"
-        else load_file(checkpoint_path)
-    )
+    if Path(checkpoint_path).suffix == '.ckpt':
+        if scan_needed:
+            ModelManager.scan_model(checkpoint_path,checkpoint_path)
+        checkpoint = torch.load(checkpoint_path)
+    else:
+        checkpoint = load_file(checkpoint_path)

     cache_dir = global_cache_dir("hub")
     pipeline_class = (
         StableDiffusionGeneratorPipeline
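A hypothetical call site for the updated signature (paths and keyword arguments shown are illustrative):

pipe = load_pipeline_from_original_stable_diffusion_ckpt(
    checkpoint_path="models/v1-5-pruned.ckpt",          # illustrative path
    original_config_file="configs/v1-inference.yaml",   # illustrative path
    scan_needed=True,  # scan the pickle before torch.load
)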
In the model manager, the CPU_DEVICE import is dropped:

@@ -34,7 +34,7 @@ from picklescan.scanner import scan_file_path
 from invokeai.backend.globals import Globals, global_cache_dir

 from ..stable_diffusion import StableDiffusionGeneratorPipeline
-from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume
+from ..util import CUDA_DEVICE, ask_user, download_with_resume


 class SDLegacyType(Enum):
     V1 = 1
Diagnostic messages are capitalized consistently:

@@ -285,13 +285,13 @@ class ModelManager(object):
         self.stack.remove(model_name)
         if delete_files:
             if weights:
-                print(f"** deleting file {weights}")
+                print(f"** Deleting file {weights}")
                 Path(weights).unlink(missing_ok=True)
             elif path:
-                print(f"** deleting directory {path}")
+                print(f"** Deleting directory {path}")
                 rmtree(path, ignore_errors=True)
             elif repo_id:
-                print(f"** deleting the cached model directory for {repo_id}")
+                print(f"** Deleting the cached model directory for {repo_id}")
                 self._delete_model_from_cache(repo_id)

     def add_model(
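For reference, both deletion helpers used here tolerate already-missing targets (paths illustrative):

from pathlib import Path
from shutil import rmtree

# unlink(missing_ok=True), available since Python 3.8, is a no-op when the
# file is already gone, mirroring rmtree(..., ignore_errors=True) below it.
Path("models/old-weights.ckpt").unlink(missing_ok=True)
rmtree("models/old-model-dir", ignore_errors=True)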
The pipeline now has embeddings attached before it is returned:

@@ -435,9 +435,9 @@ class ModelManager(object):
         # square images???
         width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
         height = width

         print(f"  | Default image dimensions = {width} x {height}")

+        self._add_embeddings_to_model(pipeline)
         return pipeline, width, height, model_hash

     def _load_ckpt_model(self, model_name, mconfig):
scan_model() becomes a @classmethod so it can be called without a ModelManager instance, and its message moves from the ">>" prefix to the quieter "  |" prefix:

@@ -517,13 +517,14 @@ class ModelManager(object):
         if self._has_cuda():
             torch.cuda.empty_cache()

+    @classmethod
     def scan_model(self, model_name, checkpoint):
         """
         Apply picklescanner to the indicated checkpoint and issue a warning
         and option to exit if an infected file is identified.
         """
         # scan model
-        print(f">> Scanning Model: {model_name}")
+        print(f"  | Scanning Model: {model_name}")
         scan_result = scan_file_path(checkpoint)
         if scan_result.infected_files != 0:
             if scan_result.infected_files == 1:
@@ -546,7 +547,7 @@ class ModelManager(object):
                 print("### Exiting InvokeAI")
                 sys.exit()
         else:
-            print(">> Model scanned ok")
+            print("  | Model scanned ok")

     def import_diffuser_model(
         self,
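Because scan_model is now a @classmethod, the converter above can call it without constructing a ModelManager. Note that the first parameter is still named self; under @classmethod the class object is bound to it, so the call form below works (path illustrative; import of ModelManager omitted):

# No ModelManager instance required after the @classmethod change.
ModelManager.scan_model("v1-5-pruned.ckpt", "v1-5-pruned.ckpt")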
The config-file heuristics now scan legacy .ckpt/.pt files before probing them with torch.load:

@@ -731,11 +732,12 @@ class ModelManager(object):
             return model_path.stem

         # another round of heuristics to guess the correct config file.
-        checkpoint = (
-            torch.load(model_path)
-            if model_path.suffix == ".ckpt"
-            else safetensors.torch.load_file(model_path)
-        )
+        checkpoint = None
+        if model_path.suffix.endswith((".ckpt",".pt")):
+            self.scan_model(model_path,model_path)
+            checkpoint = torch.load(model_path)
+        else:
+            checkpoint = safetensors.torch.load_file(model_path)

         # additional probing needed if no config file provided
         if model_config_file is None:
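The new suffix test relies on str.endswith accepting a tuple, which covers both legacy pickle extensions in one check:

from pathlib import Path

Path("model.ckpt").suffix.endswith((".ckpt", ".pt"))         # True
Path("model.pt").suffix.endswith((".ckpt", ".pt"))           # True
Path("model.safetensors").suffix.endswith((".ckpt", ".pt"))  # False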
Because the file has already been scanned at this point, the conversion step is told to skip a second scan:

@@ -788,6 +790,7 @@ class ModelManager(object):
             model_description=description,
             original_config_file=model_config_file,
             commit_to_conf=commit_to_conf,
+            scan_needed=False,
         )
         return model_name
The conversion-and-import method grows the matching scan_needed parameter:

@@ -800,6 +803,7 @@ class ModelManager(object):
         vae=None,
         original_config_file: Path = None,
         commit_to_conf: Path = None,
+        scan_needed: bool=True,
     ) -> str:
         """
         Convert a legacy ckpt weights file to diffuser model and import
...and threads it through to the checkpoint converter:

@@ -834,6 +838,7 @@ class ModelManager(object):
             extract_ema=True,
             original_config_file=original_config_file,
             vae=vae_model,
+            scan_needed=scan_needed,
         )
         print(
             f"  | Success. Optimized model is now located at {str(diffusers_path)}"
@@ -849,7 +854,7 @@ class ModelManager(object):
             self.add_model(model_name, new_config, True)
             if commit_to_conf:
                 self.commit(commit_to_conf)
-            print(">> Conversion succeeded")
+            print("  | Conversion succeeded")
         except Exception as e:
             print(f"** Conversion failed: {str(e)}")
             print(
Finally, the cache-deletion message gets the same capitalization fix:

@@ -1208,7 +1213,7 @@ class ModelManager(object):
                 hashes_to_delete.add(revision.commit_hash)
         strategy = cache_info.delete_revisions(*hashes_to_delete)
         print(
-            f"** deletion of this model is expected to free {strategy.expected_freed_size_str}"
+            f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
         )
         strategy.execute()
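The surrounding method is only partially visible in this hunk; a minimal sketch of the full cache-deletion flow it implies, assuming huggingface_hub's cache-scanning API (the function name and repo lookup are illustrative):

from huggingface_hub import scan_cache_dir

def delete_cached_model(repo_id: str):
    # Collect every revision of the repo, then delete them in one strategy.
    cache_info = scan_cache_dir()
    hashes_to_delete = set()
    for repo in cache_info.repos:
        if repo.repo_id == repo_id:
            for revision in repo.revisions:
                hashes_to_delete.add(revision.commit_hash)
    strategy = cache_info.delete_revisions(*hashes_to_delete)
    print(f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}")
    strategy.execute()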