Model Manager Backend Implementation

blessedcoolant 2022-12-25 09:28:46 +13:00
parent e66b1a685c
commit 3521557541
2 changed files with 106 additions and 1 deletion

View File

@@ -9,6 +9,7 @@ import io
import base64
import os
import json
import tkinter as tk
from werkzeug.utils import secure_filename
from flask import Flask, redirect, send_from_directory, request, make_response
@@ -17,6 +18,7 @@ from PIL import Image, ImageOps
from PIL.Image import Image as ImageType
from uuid import uuid4
from threading import Event
from tkinter import filedialog
from ldm.generate import Generate
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
@@ -297,6 +299,87 @@ class InvokeAIWebServer:
config["infill_methods"] = infill_methods()
socketio.emit("systemConfig", config)

        @socketio.on('searchForModels')
        def handle_search_models():
            try:
                # Using tkinter to get the filepath because JS in the browser
                # can't provide the full local path of a selected folder
                root = tk.Tk()
                root.iconify()  # for macOS
                root.withdraw()
                root.wm_attributes('-topmost', 1)
                root.focus_force()
                search_folder = filedialog.askdirectory(
                    parent=root, title='Select Checkpoint Folder')
                root.destroy()

                if not search_folder:
                    socketio.emit(
                        "foundModels",
                        {'search_folder': None, 'found_models': None},
                    )
                else:
                    search_folder, found_models = self.generate.model_cache.search_models(search_folder)
                    socketio.emit(
                        "foundModels",
                        {'search_folder': search_folder, 'found_models': found_models},
                    )
            except Exception as e:
                self.socketio.emit("error", {"message": (str(e))})
                print("\n")
                traceback.print_exc()
                print("\n")
@socketio.on("addNewModel")
def handle_add_model(new_model_config: dict):
try:
model_name = new_model_config['name']
del new_model_config['name']
model_attributes = new_model_config
update = False
current_model_list = self.generate.model_cache.list_models()
if model_name in current_model_list:
update = True
print(f">> Adding New Model: {model_name}")
self.generate.model_cache.add_model(
model_name=model_name, model_attributes=model_attributes, clobber=True)
self.generate.model_cache.commit(opt.conf)
new_model_list = self.generate.model_cache.list_models()
socketio.emit(
"newModelAdded",
{"new_model_name": model_name,
"model_list": new_model_list, 'update': update},
)
print(f">> New Model Added: {model_name}")
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
@socketio.on("deleteModel")
def handle_delete_model(model_name: str):
try:
print(f">> Deleting Model: {model_name}")
self.generate.model_cache.del_model(model_name)
self.generate.model_cache.commit(opt.conf)
updated_model_list = self.generate.model_cache.list_models()
socketio.emit(
"modelDeleted",
{"deleted_model_name": model_name,
"model_list": updated_model_list},
)
print(f">> Model Deleted: {model_name}")
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
@socketio.on("requestModelChange")
def handle_set_model(model_name: str):
try:

View File

@@ -23,6 +23,7 @@ from omegaconf.errors import ConfigAttributeError
from ldm.util import instantiate_from_config, ask_user
from ldm.invoke.globals import Globals
from picklescan.scanner import scan_file_path
from pathlib import Path
DEFAULT_MAX_MODELS=2
@@ -135,8 +136,10 @@ class ModelCache(object):
        for name in self.config:
            try:
                description = self.config[name].description
                weights = self.config[name].weights
            except ConfigAttributeError:
                description = '<no description>'
                weights = '<not found>'

            if self.current_model == name:
                status = 'active'
@@ -147,7 +150,8 @@ class ModelCache(object):
            models[name] = {
                'status': status,
                'description': description,
                'weights': weights,
            }
        return models
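
With the two additions above, every entry returned by list_models() now carries the weights path alongside the description. A hypothetical return value (the model name and paths are made up for illustration):

    {
        'stable-diffusion-1.5': {
            'status': 'active',
            'description': 'Stable Diffusion version 1.5',
            'weights': 'models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt',
        },
    }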
@@ -186,6 +190,8 @@ class ModelCache(object):
        config = omega[model_name] if model_name in omega else {}

        for field in model_attributes:
            if field == 'weights':
                # normalize Windows path separators in the stored weights path;
                # replacing on the key name would be a no-op
                model_attributes[field] = model_attributes[field].replace('\\', '/')
            config[field] = model_attributes[field]

        omega[model_name] = config
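
The normalization keeps the committed YAML config portable by storing forward slashes regardless of the client OS. An equivalent, arguably more explicit variant using the standard library's pathlib is sketched below; this is an alternative, not what the commit does.

    from pathlib import PureWindowsPath

    def normalize_weights_path(path: str) -> str:
        # PureWindowsPath accepts both separators; as_posix() emits forward slashes
        return PureWindowsPath(path).as_posix()

    assert normalize_weights_path("C:\\models\\foo.ckpt") == "C:/models/foo.ckpt"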
@@ -311,6 +317,22 @@ class ModelCache(object):
                sys.exit()
        else:
            print('>> Model Scanned. OK!!')

    def search_models(self, search_folder):
        print(f'>> Finding Models In: {search_folder}')
        models_folder = Path(search_folder).glob('**/*.ckpt')
        files = [x for x in models_folder if x.is_file()]

        found_models = []
        for file in files:
            found_models.append({
                'name': file.stem,
                'location': str(file.resolve()).replace('\\', '/')
            })

        return search_folder, found_models
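
A standalone equivalent of the walk above, for illustration: Path.rglob('*.ckpt') matches the same files as glob('**/*.ckpt'), and as_posix() covers the backslash replacement. The function name is hypothetical.

    from pathlib import Path

    def find_checkpoints(search_folder: str) -> list[dict]:
        # recursively collect .ckpt files, reporting POSIX-style absolute paths
        return [
            {'name': p.stem, 'location': p.resolve().as_posix()}
            for p in Path(search_folder).rglob('*.ckpt')
            if p.is_file()
        ]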

    def _make_cache_room(self) -> None:
        num_loaded_models = len(self.models)