# Source: InvokeAI/scripts/migrate_models_to_3.0.py (279 lines, 11 KiB, Python)

'''
Migrate the models directory and models.yaml file from an existing
InvokeAI 2.3 installation to 3.0.0.
'''
import io
import os
import argparse
import shutil
import yaml
import transformers
import diffusers
import warnings
from pathlib import Path
from omegaconf import OmegaConf
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import (
CLIPTextModel,
CLIPTokenizer,
AutoFeatureExtractor,
BertTokenizerFast,
)
import invokeai.backend.util.logging as logger
from invokeai.backend.model_management.model_probe import (
ModelProbe, ModelType, BaseModelType
)
# Keep the migration output readable: ignore Python warnings and restrict
# the transformers/diffusers loggers to error-level messages.
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
diffusers.logging.set_verbosity_error()
def create_directory_structure(dest: Path):
    '''
    Create the empty 3.0 models directory layout underneath `dest`.

    A `<base>/<type>` folder is made for each combination of the two
    Stable Diffusion base models and the five model types listed below,
    followed by a top-level `core` folder. Directories that already
    exist are left alone.
    '''
    bases = (BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2)
    kinds = (
        ModelType.Pipeline,
        ModelType.Vae,
        ModelType.Lora,
        ModelType.ControlNet,
        ModelType.TextualInversion,
    )
    for base in bases:
        for kind in kinds:
            (dest / base.value / kind.value).mkdir(parents=True, exist_ok=True)
    (dest / 'core').mkdir(parents=True, exist_ok=True)
def copy_file(src: Path, dest: Path):
    '''
    Copy the single file `src` to `dest`, logging the operation.

    This is best-effort: any exception raised by the copy is reported
    through the logger and then dropped, so the caller carries on.
    '''
    announcement = f'Copying {str(src)} to {str(dest)}'
    logger.info(announcement)
    try:
        shutil.copy(src, dest)
    except Exception as err:
        logger.error(f'COPY FAILED: {str(err)}')
def copy_dir(src: Path, dest: Path):
    '''
    Recursively copy the directory tree `src` to `dest`, logging the operation.

    This is best-effort: any exception raised by the copy is reported
    through the logger and then dropped, so the caller carries on.
    '''
    announcement = f'Copying {str(src)} to {str(dest)}'
    logger.info(announcement)
    try:
        shutil.copytree(src, dest)
    except Exception as err:
        logger.error(f'COPY FAILED: {str(err)}')
def migrate_models(src_dir: Path, dest_dir: Path):
    '''
    Recursively walk `src_dir`, probe every file and directory found, and
    copy anything the probe recognises into
    `dest_dir/<base type>/<model type>/<name>`.

    Entries for which `ModelProbe().heuristic_probe()` returns a falsy
    result are skipped. Errors raised while handling one entry are logged
    and do not stop the walk.

    A directory that is recognised as a model is copied as a whole tree and
    is then pruned from the walk. Without the pruning, os.walk would descend
    into it and its inner files and subfolders would be probed and copied a
    second time as if they were standalone models.
    '''
    for root, dirs, files in os.walk(src_dir):
        for f in files:
            # hack - don't copy raw learned_embeds.bin, let them
            # be copied as part of a tree copy operation
            if f == 'learned_embeds.bin':
                continue
            try:
                model = Path(root, f)
                info = ModelProbe().heuristic_probe(model)
                if not info:
                    continue
                dest = Path(dest_dir, info.base_type.value, info.model_type.value, f)
                copy_file(model, dest)
            except KeyboardInterrupt:
                raise
            except Exception as e:
                logger.error(str(e))

        recognised = set()
        for d in dirs:
            try:
                model = Path(root, d)
                info = ModelProbe().heuristic_probe(model)
                if not info:
                    continue
                dest = Path(dest_dir, info.base_type.value, info.model_type.value, model.name)
                copy_dir(model, dest)
                recognised.add(d)
            except KeyboardInterrupt:
                raise
            except Exception as e:
                logger.error(str(e))

        # os.walk (top-down) honours in-place edits of `dirs`: removing the
        # directories already handled as whole models stops it from
        # descending into them on later iterations.
        if recognised:
            dirs[:] = [d for d in dirs if d not in recognised]
def migrate_support_models(dest_directory: Path):
    '''
    Copy the support model folders found under `./models` (relative to the
    current working directory) into the `core` area of `dest_directory`.

    Folders that do not exist are silently skipped.
    '''
    old_models = Path('./models')
    # (folder under ./models, location relative to the destination root)
    relocations = (
        ('clipseg', 'core/misc/clipseg'),
        ('realesrgan', 'core/upscaling/realesrgan'),
        ('codeformer', 'core/face_restoration/codeformer'),
        ('gfpgan', 'core/face_restoration/gfpgan'),
    )
    for folder, target in relocations:
        source = old_models / folder
        if source.exists():
            copy_dir(source, dest_directory / target)
def migrate_conversion_models(dest_directory: Path):
    '''
    Re-save the tokenizers, text encoders, VAE and safety checker from the
    local `./models/hub` cache into `dest_directory/core/convert`.

    Everything is loaded with `local_files_only=True`. The steps run in
    sequence inside a single try block, so the first failure is logged and
    the remaining steps are not attempted.
    '''
    # These are needed for the conversion script
    hub_options = {
        'cache_dir': Path('./models/hub'),
        'local_files_only': True,
    }
    try:
        logger.info('Migrating core tokenizers and text encoders')
        target_dir = dest_directory/'core/convert'

        def export(loader, repo, folder, **extra):
            # Load one component from the cache and write it to target_dir/folder.
            component = loader.from_pretrained(repo, **extra, **hub_options)
            component.save_pretrained(target_dir/folder, safe_serialization=True)

        # bert
        export(BertTokenizerFast, "bert-base-uncased", 'bert-base-uncased')

        # sd-1 (tokenizer and text encoder share one output folder)
        clip_repo = 'openai/clip-vit-large-patch14'
        export(CLIPTokenizer, clip_repo, 'clip-vit-large-patch14')
        export(CLIPTextModel, clip_repo, 'clip-vit-large-patch14')

        # sd-2
        sd2_repo = "stabilityai/stable-diffusion-2"
        export(CLIPTokenizer, sd2_repo, 'stable-diffusion-2-tokenizer', subfolder="tokenizer")
        export(CLIPTextModel, sd2_repo, 'stable-diffusion-2-text_encoder', subfolder="text_encoder")

        # VAE
        logger.info('Migrating stable diffusion VAE')
        export(AutoencoderKL, 'stabilityai/sd-vae-ft-mse', 'sd-vae-ft-mse')

        # safety checking
        logger.info('Migrating safety checker')
        checker_repo = "CompVis/stable-diffusion-safety-checker"
        export(AutoFeatureExtractor, checker_repo, 'stable-diffusion-safety-checker-extractor')
        export(StableDiffusionSafetyChecker, checker_repo, 'stable-diffusion-safety-checker')
    except Exception as err:
        # KeyboardInterrupt is not an Exception subclass, so Ctrl-C still propagates.
        logger.error(str(err))
def migrate_tuning_models(dest: Path):
    '''
    Scan the `embeddings`, `loras` and `controlnets` folders of the current
    working directory and migrate whatever models they contain into `dest`.

    A folder that is missing is reported and skipped.
    '''
    for folder_name in ('embeddings', 'loras', 'controlnets'):
        source = Path('.', folder_name)
        if source.is_dir():
            logger.info(f'Scanning {folder_name}')
            migrate_models(source, dest)
        else:
            logger.info(f'{folder_name} directory not found; skipping')
def migrate_pipelines(dest_dir: Path, dest_yaml: io.TextIOBase):
    '''
    Re-save every diffusers pipeline found in the `./models/hub` cache into
    `dest_dir` and append a matching stanza to `dest_yaml`.

    Each `models--<owner>--<repo>` cache entry that has a `model_index.json`
    in one of its snapshots is loaded once per name found under `refs/`.
    A revision that fails anywhere in the load/probe/save/write sequence is
    reported with a warning and skipped.
    '''
    hub_cache = Path('./models/hub')
    load_options = {
        'cache_dir': hub_cache,
        'local_files_only': True,
        'safety_checker': None,
    }
    for repo_dir in hub_cache.glob('models--*'):
        # Skip cache entries with no model_index.json in any snapshot.
        if not list(repo_dir.glob('snapshots/**/model_index.json')):
            continue
        _, owner, repo_name = repo_dir.name.split('--')
        repo_id = f'{owner}/{repo_name}'
        revisions = [ref.name for ref in repo_dir.glob('refs/*')]

        for revision in revisions:
            logger.info(f'Migrating {repo_id}, revision {revision}')
            try:
                pipeline = StableDiffusionPipeline.from_pretrained(
                    repo_id, revision=revision, **load_options
                )
                info = ModelProbe().heuristic_probe(pipeline)
                if not info:
                    continue

                base = info.base_type.value
                kind = info.model_type.value
                folder = f'{repo_name}-{revision}'

                dest = Path(dest_dir, base, kind, folder)
                pipeline.save_pretrained(dest, safe_serialization=True)

                # The yaml records the location relative to a 'models' root.
                rel_path = Path('models', dest.relative_to(dest_dir))
                entry = {
                    'name': repo_name,
                    'path': str(rel_path),
                    'description': f'diffusers model {repo_id}',
                    'format': 'diffusers',
                    'image_size': info.image_size,
                    'base': info.base_type.value,
                    'variant': info.variant_type.value,
                    'prediction_type': info.prediction_type.value,
                }
                stanza = {f'{base}/{kind}/{folder}': entry}
                print(yaml.dump(stanza), file=dest_yaml, end="")
                dest_yaml.flush()
            except Exception:
                # KeyboardInterrupt is not an Exception subclass, so Ctrl-C still propagates.
                logger.warning(f'Could not load the "{revision}" version of {repo_id}. Skipping.')
def migrate_checkpoints(dest_dir: Path, dest_yaml: io.TextIOBase):
    '''
    Migrate the checkpoint models listed in the old `./configs/models.yaml`.

    For every stanza whose `format` is `ckpt`, the weights file is probed
    and a new-style stanza is appended to `dest_yaml`. Weights stored inside
    the old `models` directory are also copied into `dest_dir`, and the
    stanza then points at the new relative location. Weights stored
    elsewhere are left in place and referenced by their existing path.
    Errors for one model are logged and do not stop the loop.
    '''
    # find any checkpoints referred to in old models.yaml
    old_config = OmegaConf.load('./configs/models.yaml')
    orig_models_dir = Path.cwd() / 'models'
    install_root = orig_models_dir.parent

    for model_name, stanza in old_config.items():
        if stanza.get('format') != 'ckpt':
            continue
        try:
            logger.info(f'Migrating checkpoint model {model_name}')
            weights = install_root / stanza['weights']
            config = stanza['config']
            info = ModelProbe().heuristic_probe(weights)
            if not info:
                continue

            base = info.base_type.value
            kind = info.model_type.value

            # uh oh, weights is in the old models directory - move it into the new one
            if weights.is_relative_to(orig_models_dir):
                copy_file(weights, Path(dest_dir, base, kind, weights.name))
                weights = Path('models', base, kind, weights.name)

            entry = {
                'name': model_name,
                'path': str(weights),
                'description': f'checkpoint model {model_name}',
                'format': 'checkpoint',
                'image_size': info.image_size,
                'base': info.base_type.value,
                'variant': info.variant_type.value,
                'config': config,
            }
            print(yaml.dump({f'{base}/{kind}/{model_name}': entry}), file=dest_yaml, end="")
            dest_yaml.flush()
        except Exception as err:
            # KeyboardInterrupt is not an Exception subclass, so Ctrl-C still propagates.
            logger.error(str(err))
def main():
    '''
    Command-line entry point: parse the arguments, validate the 2.3 root
    directory, then run every migration step in order.

    The destination directory and yaml paths are resolved before changing
    into the root directory, so relative values are interpreted against the
    directory the script was started from. All migration steps themselves
    use paths relative to the root directory.

    Invalid arguments are reported through `parser.error`, which prints a
    usage message and exits with status 2.
    '''
    parser = argparse.ArgumentParser(description="Model directory migrator")
    parser.add_argument('root_directory',
                        help='Root directory (containing "models", "embeddings", "controlnets" and "loras")'
                        )
    parser.add_argument('--dest-directory',
                        default='./models-3.0',
                        help='Destination for new models directory',
                        )
    parser.add_argument('--dest-yaml',
                        default='./models.yaml-3.0',
                        help='Destination for new models.yaml file',
                        )
    args = parser.parse_args()

    root_directory = Path(args.root_directory)
    # Explicit checks instead of assert: assertions are stripped under
    # `python -O`, which would let a bad path through to os.chdir().
    if not root_directory.is_dir():
        parser.error(f"{root_directory} is not a valid directory")
    if not (root_directory / 'models').is_dir():
        parser.error(f"{root_directory} does not contain a 'models' subdirectory")

    dest_directory = Path(args.dest_directory).resolve()
    dest_yaml = Path(args.dest_yaml).resolve()

    os.chdir(root_directory)
    with open(dest_yaml, 'w') as yaml_file:
        # The new-format file starts with a version marker.
        print(yaml.dump({'__metadata__': {'version': '3.0.0'}}),
              file=yaml_file, end="")
        create_directory_structure(dest_directory)
        migrate_support_models(dest_directory)
        migrate_conversion_models(dest_directory)
        migrate_tuning_models(dest_directory)
        migrate_pipelines(dest_directory, yaml_file)
        migrate_checkpoints(dest_directory, yaml_file)
# Run the migration only when executed as a script, not when imported.
if __name__ == "__main__":
    main()