mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
new TUI is fully functional; needs some polishing
This commit is contained in:
@ -95,9 +95,12 @@ def install_requested_models(
|
||||
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
|
||||
model_manager.delete_controlnet_models(controlnet.remove_models)
|
||||
|
||||
model_manager.install_lora_models(lora.install_models)
|
||||
model_manager.install_lora_models(lora.install_models, access_token=access_token)
|
||||
model_manager.delete_lora_models(lora.remove_models)
|
||||
|
||||
model_manager.install_ti_models(ti.install_models, access_token=access_token)
|
||||
model_manager.delete_ti_models(ti.remove_models)
|
||||
|
||||
# TODO: Replace next three paragraphs with calls into new model manager
|
||||
if diffusers.remove_models and len(diffusers.remove_models) > 0:
|
||||
logger.info("DELETING UNCHECKED STARTER MODELS")
|
||||
@ -109,7 +112,7 @@ def install_requested_models(
|
||||
if diffusers.install_models and len(diffusers.install_models) > 0:
|
||||
logger.info("INSTALLING SELECTED STARTER MODELS")
|
||||
successfully_downloaded = download_weight_datasets(
|
||||
models=diffusers.install_initial_models,
|
||||
models=diffusers.install_models,
|
||||
access_token=None,
|
||||
precision=precision,
|
||||
) # FIX: for historical reasons, we don't use model manager here
|
||||
|
@ -1334,24 +1334,54 @@ class ModelManager(object):
|
||||
installed_models = dict()
|
||||
for root, dirs, files in os.walk(dir):
|
||||
for name in files:
|
||||
if Path(name).suffix in ['.safetensors','.ckpt','.pt']:
|
||||
installed_models.update({name: True})
|
||||
if Path(name).suffix not in ['.safetensors','.ckpt','.pt','.bin']:
|
||||
continue
|
||||
if name == 'pytorch_lora_weights.bin':
|
||||
name = Path(root,name).parent.stem #Path(root,name).stem
|
||||
else:
|
||||
name = Path(name).stem
|
||||
installed_models.update({name: True})
|
||||
|
||||
return installed_models
|
||||
|
||||
def install_lora_models(self, model_names: list[str]):
|
||||
def install_lora_models(self, model_names: list[str], access_token:str=None):
|
||||
'''Download list of LoRA/LyCORIS models'''
|
||||
|
||||
short_names = OmegaConf.load(Dataset_path).get('lora') or {}
|
||||
for name in model_names:
|
||||
url = short_names.get(name) or name
|
||||
download_with_resume(url, self.globals.lora_path)
|
||||
print(name)
|
||||
|
||||
name = short_names.get(name) or name
|
||||
|
||||
# HuggingFace style LoRA
|
||||
if re.match(r"^[\w.+-]+/([\w.+-]+)$", name):
|
||||
self.logger.info(f'Downloading LoRA/LyCORIS model {name}')
|
||||
_,dest_dir = name.split("/")
|
||||
|
||||
hf_download_with_resume(
|
||||
repo_id = name,
|
||||
model_dir = self.globals.lora_path / dest_dir,
|
||||
model_name = 'pytorch_lora_weights.bin',
|
||||
access_token = access_token,
|
||||
)
|
||||
|
||||
elif name.startswith(("http:", "https:", "ftp:")):
|
||||
download_with_resume(name, self.globals.lora_path)
|
||||
|
||||
else:
|
||||
self.logger.error(f"Unknown repo_id or URL: {name}")
|
||||
|
||||
def delete_lora_models(self, model_names: List[str]):
|
||||
'''Remove the list of lora models'''
|
||||
for name in model_names:
|
||||
path = self.globals.lora_path / name
|
||||
if path.exists():
|
||||
self.logger.info(f'Purging lora model {name}')
|
||||
path.unlink()
|
||||
file_or_directory = self.globals.lora_path / name
|
||||
if file_or_directory.is_dir():
|
||||
self.logger.info(f'Purging LoRA/LyCORIS {name}')
|
||||
shutil.rmtree(str(file_or_directory))
|
||||
else:
|
||||
for path in self.globals.lora_path.glob(f'{name}.*'):
|
||||
self.logger.info(f'Purging LoRA/LyCORIS {name}')
|
||||
path.unlink()
|
||||
|
||||
def list_ti_models(self)->Dict[str,bool]:
|
||||
'''Return a dict of installed textual models; key is either the shortname
|
||||
@ -1362,21 +1392,50 @@ class ModelManager(object):
|
||||
installed_models = {x: False for x in models.keys()}
|
||||
|
||||
dir = self.globals.embedding_path
|
||||
installed_models = dict()
|
||||
for root, dirs, files in os.walk(dir):
|
||||
for name in files:
|
||||
if not Path(name).suffix in ['.bin','.pt','.ckpt','.safetensors']:
|
||||
continue
|
||||
if name == 'learned_embeds.bin':
|
||||
name = str(Path(root,name).parent)
|
||||
name = Path(root,name).parent.stem #Path(root,name).stem
|
||||
else:
|
||||
name = Path(name).stem
|
||||
installed_models.update({name: True})
|
||||
return installed_models
|
||||
|
||||
def install_ti_models(self, model_names: list[str]):
|
||||
def install_ti_models(self, model_names: list[str], access_token: str=None):
|
||||
'''Download list of textual inversion embeddings'''
|
||||
|
||||
short_names = OmegaConf.load(Dataset_path).get('textual_inversion') or {}
|
||||
for name in model_names:
|
||||
url = short_names.get(name) or name
|
||||
download_with_resume(url, self.globals.embedding_path)
|
||||
name = short_names.get(name) or name
|
||||
|
||||
if re.match(r"^[\w.+-]+/([\w.+-]+)$", name):
|
||||
self.logger.info(f'Downloading Textual Inversion embedding {name}')
|
||||
_,dest_dir = name.split("/")
|
||||
hf_download_with_resume(
|
||||
repo_id = name,
|
||||
model_dir = self.globals.embedding_path / dest_dir,
|
||||
model_name = 'learned_embeds.bin',
|
||||
access_token = access_token
|
||||
)
|
||||
elif name.startswith(('http:','https:','ftp:')):
|
||||
download_with_resume(name, self.globals.embedding_path)
|
||||
else:
|
||||
self.logger.error(f'{name} does not look like either a HuggingFace repo_id or a downloadable URL')
|
||||
|
||||
def delete_ti_models(self, model_names: list[str]):
|
||||
'''Remove TI embeddings from disk'''
|
||||
for name in model_names:
|
||||
file_or_directory = self.globals.embedding_path / name
|
||||
if file_or_directory.is_dir():
|
||||
self.logger.info(f'Purging textual inversion embedding {name}')
|
||||
shutil.rmtree(str(file_or_directory))
|
||||
else:
|
||||
for path in self.globals.embedding_path.glob(f'{name}.*'):
|
||||
self.logger.info(f'Purging textual inversion embedding {name}')
|
||||
path.unlink()
|
||||
|
||||
def list_controlnet_models(self)->Dict[str,bool]:
|
||||
'''Return a dict of installed controlnet models; key is repo_id or short name
|
||||
of model (defined in INITIAL_MODELS), and value is True if installed'''
|
||||
|
@ -322,8 +322,8 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
||||
logger.warning("corrupt existing file found. re-downloading")
|
||||
os.remove(dest)
|
||||
exist_size = 0
|
||||
|
||||
if resp.status_code == 416 or exist_size == content_length:
|
||||
|
||||
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
|
||||
logger.warning(f"{dest}: complete file found. Skipping.")
|
||||
return dest
|
||||
elif resp.status_code == 206 or exist_size > 0:
|
||||
|
Reference in New Issue
Block a user