mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
implement vae passthru
This commit is contained in:
parent
afd19ab61a
commit
3043af4620
@ -15,7 +15,9 @@ import warnings
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
from typing import Union
|
||||
|
||||
from diffusers import StableDiffusionPipeline, AutoencoderKL
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import (
|
||||
@ -104,6 +106,9 @@ class MigrateTo3(object):
|
||||
'''
|
||||
copy a single file with logging
|
||||
'''
|
||||
if dest.exists():
|
||||
logger.info(f'Skipping existing {str(dest)}')
|
||||
return
|
||||
logger.info(f'Copying {str(src)} to {str(dest)}')
|
||||
try:
|
||||
shutil.copy(src, dest)
|
||||
@ -115,6 +120,10 @@ class MigrateTo3(object):
|
||||
'''
|
||||
Recursively copy a directory with logging
|
||||
'''
|
||||
if dest.exists():
|
||||
logger.info(f'Skipping existing {str(dest)}')
|
||||
return
|
||||
|
||||
logger.info(f'Copying {str(src)} to {str(dest)}')
|
||||
try:
|
||||
shutil.copytree(src, dest)
|
||||
@ -127,7 +136,6 @@ class MigrateTo3(object):
|
||||
that looks like a model, and copy the model into the
|
||||
appropriate location within the destination models directory.
|
||||
'''
|
||||
dest_dir = self.dest_models
|
||||
for root, dirs, files in os.walk(src_dir):
|
||||
for f in files:
|
||||
# hack - don't copy raw learned_embeds.bin, let them
|
||||
@ -139,7 +147,7 @@ class MigrateTo3(object):
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value, f)
|
||||
dest = self._model_probe_to_path(info) / f
|
||||
self.copy_file(model, dest)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
@ -151,14 +159,13 @@ class MigrateTo3(object):
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value, model.name)
|
||||
dest = self._model_probe_to_path(info) / model.name
|
||||
self.copy_dir(model, dest)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
# TO DO: Rewrite this to support alternate locations for esrgan and gfpgan in init file
|
||||
def migrate_support_models(self):
|
||||
'''
|
||||
Copy the clipseg, upscaler, and restoration models to their new
|
||||
@ -203,39 +210,53 @@ class MigrateTo3(object):
|
||||
logger.info('Migrating core tokenizers and text encoders')
|
||||
target_dir = dest_directory / 'core' / 'convert'
|
||||
|
||||
# bert
|
||||
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
|
||||
bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
|
||||
self._migrate_pretrained(BertTokenizerFast,
|
||||
repo_id='bert-base-uncased',
|
||||
dest = target_dir / 'bert-base-uncased',
|
||||
**kwargs)
|
||||
|
||||
# sd-1
|
||||
repo_id = 'openai/clip-vit-large-patch14'
|
||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14' / 'tokenizer', safe_serialization=True)
|
||||
|
||||
pipeline = CLIPTextModel.from_pretrained(repo_id, **kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14' / 'text_encoder', safe_serialization=True)
|
||||
self._migrate_pretrained(CLIPTokenizer,
|
||||
repo_id= repo_id,
|
||||
dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer',
|
||||
**kwargs)
|
||||
self._migrate_pretrained(CLIPTextModel,
|
||||
repo_id = repo_id,
|
||||
dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder',
|
||||
**kwargs)
|
||||
|
||||
# sd-2
|
||||
repo_id = "stabilityai/stable-diffusion-2"
|
||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
|
||||
|
||||
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
|
||||
self._migrate_pretrained(CLIPTokenizer,
|
||||
repo_id = repo_id,
|
||||
dest = target_dir / 'stable-diffusion-2-clip' / 'tokenizer',
|
||||
**{'subfolder':'tokenizer',**kwargs}
|
||||
)
|
||||
self._migrate_pretrained(CLIPTextModel,
|
||||
repo_id = repo_id,
|
||||
dest = target_dir / 'stable-diffusion-2-clip' / 'text_encoder',
|
||||
**{'subfolder':'text_encoder',**kwargs}
|
||||
)
|
||||
|
||||
# VAE
|
||||
logger.info('Migrating stable diffusion VAE')
|
||||
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
|
||||
vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
|
||||
|
||||
self._migrate_pretrained(AutoencoderKL,
|
||||
repo_id = 'stabilityai/sd-vae-ft-mse',
|
||||
dest = target_dir / 'sd-vae-ft-mse',
|
||||
**kwargs)
|
||||
|
||||
# safety checking
|
||||
logger.info('Migrating safety checker')
|
||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||
pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
|
||||
|
||||
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
|
||||
self._migrate_pretrained(AutoFeatureExtractor,
|
||||
repo_id = repo_id,
|
||||
dest = target_dir / 'stable-diffusion-safety-checker',
|
||||
**kwargs)
|
||||
self._migrate_pretrained(StableDiffusionSafetyChecker,
|
||||
repo_id = repo_id,
|
||||
dest = target_dir / 'stable-diffusion-safety-checker',
|
||||
**kwargs)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
@ -262,8 +283,72 @@ class MigrateTo3(object):
|
||||
}
|
||||
self.dest_yaml.write(yaml.dump(stanza))
|
||||
self.dest_yaml.flush()
|
||||
|
||||
def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
|
||||
return Path(self.dest_models, info.base_type.value, info.model_type.value)
|
||||
|
||||
def migrate_repo_id(self, repo_id: str, model_name :str=None):
|
||||
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs):
|
||||
if dest.exists():
|
||||
logger.info(f'Skipping existing {dest}')
|
||||
return
|
||||
model = model_class.from_pretrained(repo_id, **kwargs)
|
||||
self._save_pretrained(model, dest)
|
||||
|
||||
def _save_pretrained(self, model, dest: Path):
|
||||
if dest.exists():
|
||||
logger.info(f'Skipping existing {dest}')
|
||||
return
|
||||
model_name = dest.name
|
||||
download_path = dest.with_name(f'{model_name}.downloading')
|
||||
model.save_pretrained(download_path, safe_serialization=True)
|
||||
download_path.replace(dest)
|
||||
|
||||
def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
|
||||
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
|
||||
info = ModelProbe().heuristic_probe(vae)
|
||||
_, model_name = repo_id.split('/')
|
||||
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
|
||||
vae.save_pretrained(dest, safe_serialization=True)
|
||||
return dest
|
||||
|
||||
def _vae_path(self, vae: Union[str,dict])->Path:
|
||||
'''
|
||||
Convert 2.3 VAE stanza to a straight path.
|
||||
'''
|
||||
vae_path = None
|
||||
|
||||
# First get a path
|
||||
if isinstance(vae,str):
|
||||
vae_path = vae
|
||||
|
||||
elif isinstance(vae,DictConfig):
|
||||
if p := vae.get('path'):
|
||||
vae_path = p
|
||||
elif repo_id := vae.get('repo_id'):
|
||||
if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded
|
||||
vae_path = 'models/core/convert/se-vae-ft-mse'
|
||||
else:
|
||||
vae_path = self._download_vae(repo_id, vae.get('subfolder'))
|
||||
|
||||
assert vae_path is not None, "Couldn't find VAE for this model"
|
||||
|
||||
# if the VAE is in the old models directory, then we must move it into the new
|
||||
# one. VAEs outside of this directory can stay where they are.
|
||||
vae_path = Path(vae_path)
|
||||
if vae_path.is_relative_to(self.src_paths.models):
|
||||
info = ModelProbe().heuristic_probe(vae_path)
|
||||
dest = self._model_probe_to_path(info) / vae_path.name
|
||||
if not dest.exists():
|
||||
self.copy_dir(vae_path,dest)
|
||||
vae_path = dest
|
||||
|
||||
if vae_path.is_relative_to(self.dest_models):
|
||||
rel_path = vae_path.relative_to(self.dest_models)
|
||||
return Path('models',rel_path)
|
||||
else:
|
||||
return vae_path
|
||||
|
||||
def migrate_repo_id(self, repo_id: str, model_name :str=None, **extra_config):
|
||||
'''
|
||||
Migrate a locally-cached diffusers pipeline identified with a repo_id
|
||||
'''
|
||||
@ -295,10 +380,11 @@ class MigrateTo3(object):
|
||||
if not info:
|
||||
return
|
||||
|
||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value, f'{repo_name}')
|
||||
pipeline.save_pretrained(dest, safe_serialization=True)
|
||||
dest = self._model_probe_to_path(info) / repo_name
|
||||
self._save_pretrained(pipeline, dest)
|
||||
|
||||
rel_path = Path('models',dest.relative_to(dest_dir))
|
||||
self.write_yaml(model_name, path=rel_path, info=info)
|
||||
self.write_yaml(model_name, path=rel_path, info=info, **extra_config)
|
||||
|
||||
def migrate_path(self, location: Path, model_name: str=None, **extra_config):
|
||||
'''
|
||||
@ -332,16 +418,29 @@ class MigrateTo3(object):
|
||||
for model_name, stanza in conf.items():
|
||||
|
||||
try:
|
||||
passthru_args = {}
|
||||
|
||||
if vae := stanza.get('vae'):
|
||||
try:
|
||||
passthru_args['vae'] = str(self._vae_path(vae))
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
|
||||
logger.warning(str(e))
|
||||
|
||||
if config := stanza.get('config'):
|
||||
passthru_args['config'] = config
|
||||
|
||||
if repo_id := stanza.get('repo_id'):
|
||||
logger.info(f'Migrating diffusers model {model_name}')
|
||||
self.migrate_repo_id(repo_id, model_name)
|
||||
self.migrate_repo_id(repo_id, model_name, **passthru_args)
|
||||
|
||||
elif location := stanza.get('weights'):
|
||||
logger.info(f'Migrating checkpoint model {model_name}')
|
||||
self.migrate_path(Path(location), model_name, config=stanza.get('config'))
|
||||
self.migrate_path(Path(location), model_name, **passthru_args)
|
||||
|
||||
elif location := stanza.get('path'):
|
||||
logger.info(f'Migrating diffusers model {model_name}')
|
||||
self.migrate_path(Path(location), model_name, config=stanza.get('config'))
|
||||
self.migrate_path(Path(location), model_name, **passthru_args)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
@ -424,6 +523,7 @@ def do_migrate(src_directory: Path, dest_directory: Path):
|
||||
)
|
||||
migrator.migrate()
|
||||
|
||||
shutil.rmtree(dest_directory / 'models.orig', ignore_errors=True)
|
||||
(dest_directory / 'models').replace(dest_directory / 'models.orig')
|
||||
dest_models.replace(dest_directory / 'models')
|
||||
|
||||
@ -456,6 +556,7 @@ script, which will perform a full upgrade in place."""
|
||||
required=True,
|
||||
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")'
|
||||
)
|
||||
# TO DO: Implement full directory scanning
|
||||
# parser.add_argument('--all-models',
|
||||
# action="store_true",
|
||||
# help='Migrate all models found in `models` directory, not just those mentioned in models.yaml',
|
||||
@ -476,3 +577,5 @@ script, which will perform a full upgrade in place."""
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user