Merge branch 'main' into bugfix/dreambooth_ema

This commit is contained in:
Lincoln Stein
2023-03-23 23:24:15 -04:00
committed by GitHub
72 changed files with 1060 additions and 790 deletions

View File

@ -1085,9 +1085,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
dlogging.set_verbosity_error()
checkpoint = (
load_file(checkpoint_path)
if Path(checkpoint_path).suffix == ".safetensors"
else torch.load(checkpoint_path)
torch.load(checkpoint_path)
if Path(checkpoint_path).suffix == ".ckpt"
else load_file(checkpoint_path)
)
cache_dir = global_cache_dir("hub")
pipeline_class = (

View File

@ -97,7 +97,7 @@ class ModelManager(object):
If on disk, will load from there.
"""
if not model_name:
return self.current_model if self.current_model else self.get_model(self.default_model())
return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
if not self.valid_model(model_name):
print(
@ -362,6 +362,7 @@ class ModelManager(object):
raise NotImplementedError(
f"Unknown model format {model_name}: {model_format}"
)
self._add_embeddings_to_model(model)
# usage statistics
toc = time.time()
@ -436,7 +437,6 @@ class ModelManager(object):
height = width
print(f" | Default image dimensions = {width} x {height}")
self._add_embeddings_to_model(pipeline)
return pipeline, width, height, model_hash
@ -732,9 +732,9 @@ class ModelManager(object):
# another round of heuristics to guess the correct config file.
checkpoint = (
safetensors.torch.load_file(model_path)
if model_path.suffix == ".safetensors"
else torch.load(model_path)
torch.load(model_path)
if model_path.suffix == ".ckpt"
else safetensors.torch.load_file(model_path)
)
# additional probing needed if no config file provided