mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit ac6403f877 (parent 678bb4fe10)

    address some of ebr issues
@@ -95,8 +95,6 @@ class ModelInstall(object):
                  prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
                  access_token:str = None):
         self.config = config
-        with open('log.txt','w') as file:
-            print(config.model_conf_path,file=file)
         self.mgr = ModelManager(config.model_conf_path)
         self.datasets = OmegaConf.load(Dataset_path)
         self.prediction_helper = prediction_type_helper
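The two lines removed here were stray debugging that wrote the model config path to a `log.txt` file in the current working directory every time `ModelInstall` was constructed. If that value ever needs to be surfaced again, a sketch of the more conventional outlet, assuming the module-level `logger` used by the `logger.warning(...)` calls later in this diff also exposes the usual `debug()` method:

```python
# Hypothetical replacement for the deleted debug lines -- not part of the commit.
# Assumes `logger` is invokeai.backend.util.logging, as imported elsewhere in
# this module, and that it supports the standard debug() call.
import invokeai.backend.util.logging as logger

logger.debug(f'model config path: {config.model_conf_path}')
```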
@@ -271,27 +269,36 @@ class ModelInstall(object):
         # we try to figure out how to download this most economically
         # list all the files in the repo
         files = [x.rfilename for x in hinfo.siblings]
+        location = None
+
         with TemporaryDirectory(dir=self.config.models_path) as staging:
             staging = Path(staging)
             if 'model_index.json' in files:
                 location = self._download_hf_pipeline(repo_id, staging)   # pipeline
-            elif 'pytorch_lora_weights.bin' in files:
-                location = self._download_hf_model(repo_id, ['pytorch_lora_weights.bin'], staging) # LoRA
-            elif self.config.precision=='float16' and 'diffusion_pytorch_model.fp16.safetensors' in files: # vae, controlnet or some other standalone
-                files = ['config.json', 'diffusion_pytorch_model.fp16.safetensors']
-                location = self._download_hf_model(repo_id, files, staging)
-            elif 'diffusion_pytorch_model.safetensors' in files:
-                files = ['config.json', 'diffusion_pytorch_model.safetensors']
-                location = self._download_hf_model(repo_id, files, staging)
-            elif 'learned_embeds.bin' in files:
-                location = self._download_hf_model(repo_id, ['learned_embeds.bin'], staging)
+            else:
+                for suffix in ['safetensors','bin']:
+                    if f'pytorch_lora_weights.{suffix}' in files:
+                        location = self._download_hf_model(repo_id, [f'pytorch_lora_weights.{suffix}'], staging) # LoRA
+                        break
+                    elif self.config.precision=='float16' and f'diffusion_pytorch_model.fp16.{suffix}' in files: # vae, controlnet or some other standalone
+                        files = ['config.json', f'diffusion_pytorch_model.fp16.{suffix}']
+                        location = self._download_hf_model(repo_id, files, staging)
+                        break
+                    elif f'diffusion_pytorch_model.{suffix}' in files:
+                        files = ['config.json', f'diffusion_pytorch_model.{suffix}']
+                        location = self._download_hf_model(repo_id, files, staging)
+                        break
+                    elif f'learned_embeds.{suffix}' in files:
+                        location = self._download_hf_model(repo_id, [f'learned_embeds.{suffix}'], staging)
+                        break
+        if not location:
+            logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.')
+            return
+
         info = ModelProbe().heuristic_probe(location, self.prediction_helper)
+        if not info:
+            logger.warning(f'Could not probe {location}. Skipping install.')
+            return
         dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location)
         if dest.exists():
             shutil.rmtree(dest)
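Net effect of this hunk: the hard-coded `.bin`/`.safetensors` checks collapse into a single preference loop that tries `safetensors` before `bin`, each branch `break`s once it has chosen a download, and `location` is initialized to `None` and checked afterwards so unrecognized or unprobeable repos are skipped with a warning instead of failing further down. A minimal standalone sketch of the same selection idea (function and variable names here are illustrative, not the commit's API):

```python
from typing import Optional

def pick_weight_file(files: list[str], stem: str) -> Optional[str]:
    """Return the preferred weight file for `stem`, trying safetensors first."""
    for suffix in ['safetensors', 'bin']:   # preference order, as in the hunk
        candidate = f'{stem}.{suffix}'
        if candidate in files:
            return candidate
    return None                             # caller must handle "no match"

# pick_weight_file(repo_files, 'pytorch_lora_weights') returns the
# .safetensors variant when both are present, the .bin as a fallback,
# and None when neither exists -- mirroring the loop's early-exit breaks.
```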
@@ -1,118 +0,0 @@
-"""
-Routines for downloading and installing models.
-"""
-import json
-import safetensors
-import safetensors.torch
-import shutil
-import tempfile
-import torch
-import traceback
-from dataclasses import dataclass
-from diffusers import ModelMixin
-from enum import Enum
-from typing import Callable
-from pathlib import Path
-
-import invokeai.backend.util.logging as logger
-from invokeai.app.services.config import InvokeAIAppConfig
-from . import ModelManager
-from .models import BaseModelType, ModelType, VariantType
-from .model_probe import ModelProbe, ModelVariantInfo
-from .model_cache import SilenceWarnings
-
-class ModelInstall(object):
-    '''
-    This class is able to download and install several different kinds of
-    InvokeAI models. The helper function, if provided, is called on to distinguish
-    between v2-base and v2-768 stable diffusion pipelines. This usually involves
-    asking the user to select the proper type, as there is no way of distinguishing
-    the two type of v2 file programmatically (as far as I know).
-    '''
-    def __init__(self,
-                 config: InvokeAIAppConfig,
-                 model_base_helper: Callable[[Path],BaseModelType]=None,
-                 clobber:bool = False
-                 ):
-        '''
-        :param config: InvokeAI configuration object
-        :param model_base_helper: A function call that accepts the Path to a checkpoint model and returns a ModelType enum
-        :param clobber: If true, models with colliding names will be overwritten
-        '''
-        self.config = config
-        self.clogger = clobber
-        self.helper = model_base_helper
-        self.prober = ModelProbe()
-
-    def install_checkpoint_file(self, checkpoint: Path)->dict:
-        '''
-        Install the checkpoint file at path and return a
-        configuration entry that can be added to `models.yaml`.
-        Model checkpoints and VAEs will be converted into
-        diffusers before installation. Note that the model manager
-        does not hold entries for anything but diffusers pipelines,
-        and the configuration file stanzas returned from such models
-        can be safely ignored.
-        '''
-        model_info = self.prober.probe(checkpoint, self.helper)
-        if not model_info:
-            raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")
-
-        key = ModelManager.create_key(
-            model_name = checkpoint.stem,
-            base_model = model_info.base_type,
-            model_type = model_info.model_type,
-        )
-        destination_path = self._dest_path(model_info) / checkpoint
-        destination_path.parent.mkdir(parents=True, exist_ok=True)
-        self._check_for_collision(destination_path)
-        stanza = {
-            key: dict(
-                name = checkpoint.stem,
-                description = f'{model_info.model_type} model {checkpoint.stem}',
-                base = model_info.base_model.value,
-                type = model_info.model_type.value,
-                variant = model_info.variant_type.value,
-                path = str(destination_path),
-            )
-        }
-
-        # non-pipeline; no conversion needed, just copy into right place
-        if model_info.model_type != ModelType.Pipeline:
-            shutil.copyfile(checkpoint, destination_path)
-            stanza[key].update({'format': 'checkpoint'})
-
-        # pipeline - conversion needed here
-        else:
-            destination_path = self._dest_path(model_info) / checkpoint.stem
-            config_file = self._pipeline_type_to_config_file(model_info.model_type)
-
-            from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
-            with SilenceWarnings:
-                convert_ckpt_to_diffusers(
-                    checkpoint,
-                    destination_path,
-                    extract_ema=True,
-                    original_config_file=config_file,
-                    scan_needed=False,
-                )
-            stanza[key].update({'format': 'folder',
-                                'path': destination_path, # no suffix on this
-                                })
-
-        return stanza
-
-    def _check_for_collision(self, path: Path):
-        if not path.exists():
-            return
-        if self.clobber:
-            shutil.rmtree(path)
-        else:
-            raise ValueError(f"Destination {path} already exists. Won't overwrite unless clobber=True.")
-
-    def _staging_directory(self)->tempfile.TemporaryDirectory:
-        return tempfile.TemporaryDirectory(dir=self.config.root_path)
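For reference, the stanza that the deleted `install_checkpoint_file()` returned for a non-pipeline checkpoint had this shape. This mock-up is assembled from the dict literal above; the key format follows `ModelManager.create_key`, and every concrete value is invented:

```python
# Illustrative only -- mirrors the stanza built by the deleted method.
{
    'sd-1/vae/my-vae': {
        'name': 'my-vae',
        'description': 'ModelType.Vae model my-vae',   # f'{model_type} model {stem}'
        'base': 'sd-1',
        'type': 'vae',
        'variant': 'normal',
        'path': '/path/to/models/sd-1/vae/my-vae.safetensors',
        'format': 'checkpoint',                        # 'folder' after pipeline conversion
    }
}
```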
@@ -703,7 +703,7 @@ class ModelManager(object):
                 model_path = os.path.join(models_dir, entry_name)
                 if model_path not in loaded_files: # TODO: check
                     model_path = Path(model_path)
-                    model_name = model_path.name if model_path.is_dir else model_path.stem
+                    model_name = model_path.name if model_path.is_dir() else model_path.stem
                     model_key = self.create_key(model_name, base_model, model_type)

                     if model_key in self.models:
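The one-character fix here matters: without the parentheses, `model_path.is_dir` is a bound method object, which is always truthy, so every entry was treated as a directory and `model_name` never fell through to `model_path.stem`. A quick demonstration:

```python
from pathlib import Path

p = Path('some_model.safetensors')   # a file, not a directory

bool(p.is_dir)     # True  -- the method object itself is truthy
bool(p.is_dir())   # False -- the actual filesystem check
```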
@@ -401,8 +401,16 @@ class ControlNetFolderProbe(FolderProbeBase):
             else BaseModelType.StableDiffusion2

 class LoRAFolderProbe(FolderProbeBase):
-    # I've never seen one of these in the wild, so this is a noop
     def get_base_type(self)->BaseModelType:
-        pass
+        model_file = None
+        for suffix in ['safetensors','bin']:
+            base_file = self.folder_path / f'pytorch_lora_weights.{suffix}'
+            if base_file.exists():
+                model_file = base_file
+                break
+        if not model_file:
+            raise Exception('Unknown LoRA format encountered')
+        return LoRACheckpointProbe(model_file,None).get_base_type()

 ############## register probe classes ######
 ModelProbe.register_probe('diffusers', ModelType.Pipeline, PipelineFolderProbe)
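`LoRAFolderProbe` goes from a `pass`-bodied noop (which silently returned `None` for diffusers-format LoRA folders) to a real probe: it locates `pytorch_lora_weights.safetensors` or `.bin` in the folder and delegates base-type detection to `LoRACheckpointProbe`, raising when neither file exists. A hedged usage sketch, assuming `FolderProbeBase` takes the folder path in its constructor and exposes it as `self.folder_path` (consistent with the code above, but not shown in this diff):

```python
from pathlib import Path

# Hypothetical folder layout:
#   /models/sd-1/lora/my-lora/pytorch_lora_weights.safetensors
probe = LoRAFolderProbe(Path('/models/sd-1/lora/my-lora'))
base = probe.get_base_type()   # delegates to LoRACheckpointProbe on the weights file
```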
@@ -87,7 +87,5 @@ sd-1/embedding/ahx-beta-453407d:
   repo_id: sd-concepts-library/ahx-beta-453407d
 sd-1/lora/LowRA:
   path: https://civitai.com/api/download/models/63006
-sd-1/lora/Ink Scenery:
+sd-1/lora/Ink scenery:
   path: https://civitai.com/api/download/models/83390
-sd-1/lora/sd-model-finetuned-lora-t4:
-  repo_id: sayakpaul/sd-model-finetuned-lora-t4