Merge branch 'main' into feat/merge-script

This commit is contained in:
Lincoln Stein 2023-01-23 09:05:38 -05:00 committed by GitHub
commit e18beaff9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -21,7 +21,7 @@ import os
import re
import torch
from pathlib import Path
from ldm.invoke.globals import Globals
from ldm.invoke.globals import Globals, global_cache_dir
from safetensors.torch import load_file
try:
@ -637,7 +637,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
def convert_ldm_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14",cache_dir=global_cache_dir('hub'))
keys = list(checkpoint.keys())
@ -677,7 +677,8 @@ textenc_pattern = re.compile("|".join(protected.keys()))
def convert_paint_by_example_checkpoint(checkpoint):
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
cache_dir = global_cache_dir('hub')
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
model = PaintByExampleImageEncoder(config)
keys = list(checkpoint.keys())
@ -744,7 +745,8 @@ def convert_paint_by_example_checkpoint(checkpoint):
def convert_open_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
cache_dir=global_cache_dir('hub')
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir)
keys = list(checkpoint.keys())
@ -795,6 +797,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
):
checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
cache_dir = global_cache_dir('hub')
# Sometimes models don't have the global_step item
if "global_step" in checkpoint:
@ -904,7 +907,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
if model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer",cache_dir=global_cache_dir('diffusers'))
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
@ -917,8 +920,8 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
)
elif model_type == "PaintByExample":
vision_model = convert_paint_by_example_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
pipe = PaintByExamplePipeline(
vae=vae,
image_encoder=vision_model,
@ -929,9 +932,9 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
)
elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']:
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
@ -944,7 +947,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
else:
text_config = create_ldm_bert_config(original_config)
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",cache_dir=cache_dir)
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
pipe.save_pretrained(