implement vae passthru

This commit is contained in:
Lincoln Stein 2023-06-23 13:56:30 -04:00
parent afd19ab61a
commit 3043af4620

View File

@ -15,7 +15,9 @@ import warnings
from dataclasses import dataclass
from pathlib import Path
from omegaconf import OmegaConf
from omegaconf import OmegaConf, DictConfig
from typing import Union
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import (
@ -104,6 +106,9 @@ class MigrateTo3(object):
'''
copy a single file with logging
'''
if dest.exists():
logger.info(f'Skipping existing {str(dest)}')
return
logger.info(f'Copying {str(src)} to {str(dest)}')
try:
shutil.copy(src, dest)
@ -115,6 +120,10 @@ class MigrateTo3(object):
'''
Recursively copy a directory with logging
'''
if dest.exists():
logger.info(f'Skipping existing {str(dest)}')
return
logger.info(f'Copying {str(src)} to {str(dest)}')
try:
shutil.copytree(src, dest)
@ -127,7 +136,6 @@ class MigrateTo3(object):
that looks like a model, and copy the model into the
appropriate location within the destination models directory.
'''
dest_dir = self.dest_models
for root, dirs, files in os.walk(src_dir):
for f in files:
# hack - don't copy raw learned_embeds.bin, let them
@ -139,7 +147,7 @@ class MigrateTo3(object):
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = Path(dest_dir, info.base_type.value, info.model_type.value, f)
dest = self._model_probe_to_path(info) / f
self.copy_file(model, dest)
except KeyboardInterrupt:
raise
@ -151,14 +159,13 @@ class MigrateTo3(object):
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = Path(dest_dir, info.base_type.value, info.model_type.value, model.name)
dest = self._model_probe_to_path(info) / model.name
self.copy_dir(model, dest)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
# TO DO: Rewrite this to support alternate locations for esrgan and gfpgan in init file
def migrate_support_models(self):
'''
Copy the clipseg, upscaler, and restoration models to their new
@ -203,39 +210,53 @@ class MigrateTo3(object):
logger.info('Migrating core tokenizers and text encoders')
target_dir = dest_directory / 'core' / 'convert'
# bert
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
self._migrate_pretrained(BertTokenizerFast,
repo_id='bert-base-uncased',
dest = target_dir / 'bert-base-uncased',
**kwargs)
# sd-1
repo_id = 'openai/clip-vit-large-patch14'
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14' / 'tokenizer', safe_serialization=True)
pipeline = CLIPTextModel.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14' / 'text_encoder', safe_serialization=True)
self._migrate_pretrained(CLIPTokenizer,
repo_id= repo_id,
dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer',
**kwargs)
self._migrate_pretrained(CLIPTextModel,
repo_id = repo_id,
dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder',
**kwargs)
# sd-2
repo_id = "stabilityai/stable-diffusion-2"
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
self._migrate_pretrained(CLIPTokenizer,
repo_id = repo_id,
dest = target_dir / 'stable-diffusion-2-clip' / 'tokenizer',
**{'subfolder':'tokenizer',**kwargs}
)
self._migrate_pretrained(CLIPTextModel,
repo_id = repo_id,
dest = target_dir / 'stable-diffusion-2-clip' / 'text_encoder',
**{'subfolder':'text_encoder',**kwargs}
)
# VAE
logger.info('Migrating stable diffusion VAE')
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
self._migrate_pretrained(AutoencoderKL,
repo_id = 'stabilityai/sd-vae-ft-mse',
dest = target_dir / 'sd-vae-ft-mse',
**kwargs)
# safety checking
logger.info('Migrating safety checker')
repo_id = "CompVis/stable-diffusion-safety-checker"
pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
self._migrate_pretrained(AutoFeatureExtractor,
repo_id = repo_id,
dest = target_dir / 'stable-diffusion-safety-checker',
**kwargs)
self._migrate_pretrained(StableDiffusionSafetyChecker,
repo_id = repo_id,
dest = target_dir / 'stable-diffusion-safety-checker',
**kwargs)
except KeyboardInterrupt:
raise
except Exception as e:
@ -262,8 +283,72 @@ class MigrateTo3(object):
}
self.dest_yaml.write(yaml.dump(stanza))
self.dest_yaml.flush()
def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
return Path(self.dest_models, info.base_type.value, info.model_type.value)
def migrate_repo_id(self, repo_id: str, model_name :str=None):
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs):
if dest.exists():
logger.info(f'Skipping existing {dest}')
return
model = model_class.from_pretrained(repo_id, **kwargs)
self._save_pretrained(model, dest)
def _save_pretrained(self, model, dest: Path):
if dest.exists():
logger.info(f'Skipping existing {dest}')
return
model_name = dest.name
download_path = dest.with_name(f'{model_name}.downloading')
model.save_pretrained(download_path, safe_serialization=True)
download_path.replace(dest)
def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
info = ModelProbe().heuristic_probe(vae)
_, model_name = repo_id.split('/')
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
vae.save_pretrained(dest, safe_serialization=True)
return dest
def _vae_path(self, vae: Union[str,dict])->Path:
'''
Convert 2.3 VAE stanza to a straight path.
'''
vae_path = None
# First get a path
if isinstance(vae,str):
vae_path = vae
elif isinstance(vae,DictConfig):
if p := vae.get('path'):
vae_path = p
elif repo_id := vae.get('repo_id'):
if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded
vae_path = 'models/core/convert/se-vae-ft-mse'
else:
vae_path = self._download_vae(repo_id, vae.get('subfolder'))
assert vae_path is not None, "Couldn't find VAE for this model"
# if the VAE is in the old models directory, then we must move it into the new
# one. VAEs outside of this directory can stay where they are.
vae_path = Path(vae_path)
if vae_path.is_relative_to(self.src_paths.models):
info = ModelProbe().heuristic_probe(vae_path)
dest = self._model_probe_to_path(info) / vae_path.name
if not dest.exists():
self.copy_dir(vae_path,dest)
vae_path = dest
if vae_path.is_relative_to(self.dest_models):
rel_path = vae_path.relative_to(self.dest_models)
return Path('models',rel_path)
else:
return vae_path
def migrate_repo_id(self, repo_id: str, model_name :str=None, **extra_config):
'''
Migrate a locally-cached diffusers pipeline identified with a repo_id
'''
@ -295,10 +380,11 @@ class MigrateTo3(object):
if not info:
return
dest = Path(dest_dir, info.base_type.value, info.model_type.value, f'{repo_name}')
pipeline.save_pretrained(dest, safe_serialization=True)
dest = self._model_probe_to_path(info) / repo_name
self._save_pretrained(pipeline, dest)
rel_path = Path('models',dest.relative_to(dest_dir))
self.write_yaml(model_name, path=rel_path, info=info)
self.write_yaml(model_name, path=rel_path, info=info, **extra_config)
def migrate_path(self, location: Path, model_name: str=None, **extra_config):
'''
@ -332,16 +418,29 @@ class MigrateTo3(object):
for model_name, stanza in conf.items():
try:
passthru_args = {}
if vae := stanza.get('vae'):
try:
passthru_args['vae'] = str(self._vae_path(vae))
except Exception as e:
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
logger.warning(str(e))
if config := stanza.get('config'):
passthru_args['config'] = config
if repo_id := stanza.get('repo_id'):
logger.info(f'Migrating diffusers model {model_name}')
self.migrate_repo_id(repo_id, model_name)
self.migrate_repo_id(repo_id, model_name, **passthru_args)
elif location := stanza.get('weights'):
logger.info(f'Migrating checkpoint model {model_name}')
self.migrate_path(Path(location), model_name, config=stanza.get('config'))
self.migrate_path(Path(location), model_name, **passthru_args)
elif location := stanza.get('path'):
logger.info(f'Migrating diffusers model {model_name}')
self.migrate_path(Path(location), model_name, config=stanza.get('config'))
self.migrate_path(Path(location), model_name, **passthru_args)
except KeyboardInterrupt:
raise
@ -424,6 +523,7 @@ def do_migrate(src_directory: Path, dest_directory: Path):
)
migrator.migrate()
shutil.rmtree(dest_directory / 'models.orig', ignore_errors=True)
(dest_directory / 'models').replace(dest_directory / 'models.orig')
dest_models.replace(dest_directory / 'models')
@ -456,6 +556,7 @@ script, which will perform a full upgrade in place."""
required=True,
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")'
)
# TO DO: Implement full directory scanning
# parser.add_argument('--all-models',
# action="store_true",
# help='Migrate all models found in `models` directory, not just those mentioned in models.yaml',
@ -476,3 +577,5 @@ script, which will perform a full upgrade in place."""
if __name__ == '__main__':
main()