Add support for custom config files

This commit is contained in:
blessedcoolant 2023-02-11 23:34:24 +13:00
parent 6e52ca3307
commit 310501cd8a
5 changed files with 134 additions and 61 deletions

View File

@ -43,7 +43,8 @@ if not os.path.isabs(args.outdir):
# normalize the config directory relative to root # normalize the config directory relative to root
if not os.path.isabs(opt.conf): if not os.path.isabs(opt.conf):
opt.conf = os.path.normpath(os.path.join(Globals.root,opt.conf)) opt.conf = os.path.normpath(os.path.join(Globals.root, opt.conf))
class InvokeAIWebServer: class InvokeAIWebServer:
def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None: def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None:
@ -189,7 +190,8 @@ class InvokeAIWebServer:
(width, height) = pil_image.size (width, height) = pil_image.size
thumbnail_path = save_thumbnail( thumbnail_path = save_thumbnail(
pil_image, os.path.basename(file_path), self.thumbnail_image_path pil_image, os.path.basename(
file_path), self.thumbnail_image_path
) )
response = { response = {
@ -264,14 +266,16 @@ class InvokeAIWebServer:
# location for "finished" images # location for "finished" images
self.result_path = args.outdir self.result_path = args.outdir
# temporary path for intermediates # temporary path for intermediates
self.intermediate_path = os.path.join(self.result_path, "intermediates/") self.intermediate_path = os.path.join(
self.result_path, "intermediates/")
# path for user-uploaded init images and masks # path for user-uploaded init images and masks
self.init_image_path = os.path.join(self.result_path, "init-images/") self.init_image_path = os.path.join(self.result_path, "init-images/")
self.mask_image_path = os.path.join(self.result_path, "mask-images/") self.mask_image_path = os.path.join(self.result_path, "mask-images/")
# path for temp images e.g. gallery generations which are not committed # path for temp images e.g. gallery generations which are not committed
self.temp_image_path = os.path.join(self.result_path, "temp-images/") self.temp_image_path = os.path.join(self.result_path, "temp-images/")
# path for thumbnail images # path for thumbnail images
self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/") self.thumbnail_image_path = os.path.join(
self.result_path, "thumbnails/")
# txt log # txt log
self.log_path = os.path.join(self.result_path, "invoke_log.txt") self.log_path = os.path.join(self.result_path, "invoke_log.txt")
# make all output paths # make all output paths
@ -301,14 +305,16 @@ class InvokeAIWebServer:
try: try:
if not search_folder: if not search_folder:
socketio.emit( socketio.emit(
"foundModels", "foundModels",
{'search_folder': None, 'found_models': None}, {'search_folder': None, 'found_models': None},
) )
else: else:
search_folder, found_models = self.generate.model_manager.search_models(search_folder) search_folder, found_models = self.generate.model_manager.search_models(
search_folder)
socketio.emit( socketio.emit(
"foundModels", "foundModels",
{'search_folder': search_folder, 'found_models': found_models}, {'search_folder': search_folder,
'found_models': found_models},
) )
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.socketio.emit("error", {"message": (str(e))})
@ -396,7 +402,6 @@ class InvokeAIWebServer:
@socketio.on('convertToDiffusers') @socketio.on('convertToDiffusers')
def convert_to_diffusers(model_to_convert: dict): def convert_to_diffusers(model_to_convert: dict):
try: try:
if (model_info := self.generate.model_manager.model_info(model_name=model_to_convert['name'])): if (model_info := self.generate.model_manager.model_info(model_name=model_to_convert['name'])):
if 'weights' in model_info: if 'weights' in model_info:
ckpt_path = Path(model_info['weights']) ckpt_path = Path(model_info['weights'])
@ -404,15 +409,18 @@ class InvokeAIWebServer:
model_name = model_to_convert["name"] model_name = model_to_convert["name"]
model_description = model_info['description'] model_description = model_info['description']
else: else:
self.socketio.emit("error", {"message": "Model is not a valid checkpoint file"}) self.socketio.emit(
"error", {"message": "Model is not a valid checkpoint file"})
else: else:
self.socketio.emit("error", {"message": "Could not retrieve model info."}) self.socketio.emit(
"error", {"message": "Could not retrieve model info."})
if not ckpt_path.is_absolute(): if not ckpt_path.is_absolute():
ckpt_path = Path(Globals.root,ckpt_path) ckpt_path = Path(Globals.root, ckpt_path)
if original_config_file and not original_config_file.is_absolute(): if original_config_file and not original_config_file.is_absolute():
original_config_file = Path(Globals.root, original_config_file) original_config_file = Path(
Globals.root, original_config_file)
if model_to_convert['is_inpainting']: if model_to_convert['is_inpainting']:
original_config_file = Path( original_config_file = Path(
@ -420,19 +428,24 @@ class InvokeAIWebServer:
'stable-diffusion', 'stable-diffusion',
'v1-inpainting-inference.yaml' if model_to_convert['is_inpainting'] else 'v1-inference.yaml' 'v1-inpainting-inference.yaml' if model_to_convert['is_inpainting'] else 'v1-inference.yaml'
) )
diffusers_path = Path(f'{ckpt_path.parent.absolute()}\\{model_name}_diffusers') if model_to_convert['custom_config'] is not None:
original_config_file = Path(
model_to_convert['custom_config'])
diffusers_path = Path(
f'{ckpt_path.parent.absolute()}\\{model_name}_diffusers')
if diffusers_path.exists(): if diffusers_path.exists():
shutil.rmtree(diffusers_path) shutil.rmtree(diffusers_path)
self.generate.model_manager.convert_and_import( self.generate.model_manager.convert_and_import(
ckpt_path, ckpt_path,
diffusers_path, diffusers_path,
model_name=model_name, model_name=model_name,
model_description=model_description, model_description=model_description,
vae = None, vae=None,
original_config_file = original_config_file, original_config_file=original_config_file,
commit_to_conf=opt.conf, commit_to_conf=opt.conf,
) )
@ -440,7 +453,7 @@ class InvokeAIWebServer:
socketio.emit( socketio.emit(
"newModelAdded", "newModelAdded",
{"new_model_name": model_name, {"new_model_name": model_name,
"model_list": new_model_list, 'update': True}, "model_list": new_model_list, 'update': True},
) )
print(f">> Model Converted: {model_name}") print(f">> Model Converted: {model_name}")
except Exception as e: except Exception as e:
@ -448,7 +461,7 @@ class InvokeAIWebServer:
print("\n") print("\n")
traceback.print_exc() traceback.print_exc()
print("\n") print("\n")
@socketio.on("requestEmptyTempFolder") @socketio.on("requestEmptyTempFolder")
def empty_temp_folder(): def empty_temp_folder():
@ -463,7 +476,8 @@ class InvokeAIWebServer:
) )
os.remove(thumbnail_path) os.remove(thumbnail_path)
except Exception as e: except Exception as e:
socketio.emit("error", {"message": f"Unable to delete {f}: {str(e)}"}) socketio.emit(
"error", {"message": f"Unable to delete {f}: {str(e)}"})
pass pass
socketio.emit("tempFolderEmptied") socketio.emit("tempFolderEmptied")
@ -478,7 +492,8 @@ class InvokeAIWebServer:
def save_temp_image_to_gallery(url): def save_temp_image_to_gallery(url):
try: try:
image_path = self.get_image_path_from_url(url) image_path = self.get_image_path_from_url(url)
new_path = os.path.join(self.result_path, os.path.basename(image_path)) new_path = os.path.join(
self.result_path, os.path.basename(image_path))
shutil.copy2(image_path, new_path) shutil.copy2(image_path, new_path)
if os.path.splitext(new_path)[1] == ".png": if os.path.splitext(new_path)[1] == ".png":
@ -491,7 +506,8 @@ class InvokeAIWebServer:
(width, height) = pil_image.size (width, height) = pil_image.size
thumbnail_path = save_thumbnail( thumbnail_path = save_thumbnail(
pil_image, os.path.basename(new_path), self.thumbnail_image_path pil_image, os.path.basename(
new_path), self.thumbnail_image_path
) )
image_array = [ image_array = [
@ -554,7 +570,8 @@ class InvokeAIWebServer:
(width, height) = pil_image.size (width, height) = pil_image.size
thumbnail_path = save_thumbnail( thumbnail_path = save_thumbnail(
pil_image, os.path.basename(path), self.thumbnail_image_path pil_image, os.path.basename(
path), self.thumbnail_image_path
) )
image_array.append( image_array.append(
@ -572,7 +589,8 @@ class InvokeAIWebServer:
} }
) )
except Exception as e: except Exception as e:
socketio.emit("error", {"message": f"Unable to load {path}: {str(e)}"}) socketio.emit(
"error", {"message": f"Unable to load {path}: {str(e)}"})
pass pass
socketio.emit( socketio.emit(
@ -626,7 +644,8 @@ class InvokeAIWebServer:
(width, height) = pil_image.size (width, height) = pil_image.size
thumbnail_path = save_thumbnail( thumbnail_path = save_thumbnail(
pil_image, os.path.basename(path), self.thumbnail_image_path pil_image, os.path.basename(
path), self.thumbnail_image_path
) )
image_array.append( image_array.append(
@ -645,7 +664,8 @@ class InvokeAIWebServer:
) )
except Exception as e: except Exception as e:
print(f">> Unable to load {path}") print(f">> Unable to load {path}")
socketio.emit("error", {"message": f"Unable to load {path}: {str(e)}"}) socketio.emit(
"error", {"message": f"Unable to load {path}: {str(e)}"})
pass pass
socketio.emit( socketio.emit(
@ -683,7 +703,8 @@ class InvokeAIWebServer:
printable_parameters["init_mask"][:64] + "..." printable_parameters["init_mask"][:64] + "..."
) )
print(f'\n>> Image Generation Parameters:\n\n{printable_parameters}\n') print(
f'\n>> Image Generation Parameters:\n\n{printable_parameters}\n')
print(f'>> ESRGAN Parameters: {esrgan_parameters}') print(f'>> ESRGAN Parameters: {esrgan_parameters}')
print(f'>> Facetool Parameters: {facetool_parameters}') print(f'>> Facetool Parameters: {facetool_parameters}')
@ -726,9 +747,11 @@ class InvokeAIWebServer:
if postprocessing_parameters["type"] == "esrgan": if postprocessing_parameters["type"] == "esrgan":
progress.set_current_status("common:statusUpscalingESRGAN") progress.set_current_status("common:statusUpscalingESRGAN")
elif postprocessing_parameters["type"] == "gfpgan": elif postprocessing_parameters["type"] == "gfpgan":
progress.set_current_status("common:statusRestoringFacesGFPGAN") progress.set_current_status(
"common:statusRestoringFacesGFPGAN")
elif postprocessing_parameters["type"] == "codeformer": elif postprocessing_parameters["type"] == "codeformer":
progress.set_current_status("common:statusRestoringFacesCodeFormer") progress.set_current_status(
"common:statusRestoringFacesCodeFormer")
socketio.emit("progressUpdate", progress.to_formatted_dict()) socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0) eventlet.sleep(0)
@ -904,7 +927,8 @@ class InvokeAIWebServer:
init_img_url = generation_parameters["init_img"] init_img_url = generation_parameters["init_img"]
original_bounding_box = generation_parameters["bounding_box"].copy() original_bounding_box = generation_parameters["bounding_box"].copy(
)
initial_image = dataURL_to_image( initial_image = dataURL_to_image(
generation_parameters["init_img"] generation_parameters["init_img"]
@ -981,7 +1005,8 @@ class InvokeAIWebServer:
elif generation_parameters["generation_mode"] == "img2img": elif generation_parameters["generation_mode"] == "img2img":
init_img_url = generation_parameters["init_img"] init_img_url = generation_parameters["init_img"]
init_img_path = self.get_image_path_from_url(init_img_url) init_img_path = self.get_image_path_from_url(init_img_url)
generation_parameters["init_img"] = Image.open(init_img_path).convert('RGB') generation_parameters["init_img"] = Image.open(
init_img_path).convert('RGB')
def image_progress(sample, step): def image_progress(sample, step):
if self.canceled.is_set(): if self.canceled.is_set():
@ -1040,9 +1065,9 @@ class InvokeAIWebServer:
}, },
) )
if generation_parameters["progress_latents"]: if generation_parameters["progress_latents"]:
image = self.generate.sample_to_lowres_estimated_image(sample) image = self.generate.sample_to_lowres_estimated_image(
sample)
(width, height) = image.size (width, height) = image.size
width *= 8 width *= 8
height *= 8 height *= 8
@ -1061,7 +1086,8 @@ class InvokeAIWebServer:
}, },
) )
self.socketio.emit("progressUpdate", progress.to_formatted_dict()) self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0) eventlet.sleep(0)
def image_done(image, seed, first_seed, attention_maps_image=None): def image_done(image, seed, first_seed, attention_maps_image=None):
@ -1089,7 +1115,8 @@ class InvokeAIWebServer:
progress.set_current_status("common:statusGenerationComplete") progress.set_current_status("common:statusGenerationComplete")
self.socketio.emit("progressUpdate", progress.to_formatted_dict()) self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0) eventlet.sleep(0)
all_parameters = generation_parameters all_parameters = generation_parameters
@ -1100,7 +1127,8 @@ class InvokeAIWebServer:
and all_parameters["variation_amount"] > 0 and all_parameters["variation_amount"] > 0
): ):
first_seed = first_seed or seed first_seed = first_seed or seed
this_variation = [[seed, all_parameters["variation_amount"]]] this_variation = [
[seed, all_parameters["variation_amount"]]]
all_parameters["with_variations"] = ( all_parameters["with_variations"] = (
prior_variations + this_variation prior_variations + this_variation
) )
@ -1116,7 +1144,8 @@ class InvokeAIWebServer:
if esrgan_parameters: if esrgan_parameters:
progress.set_current_status("common:statusUpscaling") progress.set_current_status("common:statusUpscaling")
progress.set_current_status_has_steps(False) progress.set_current_status_has_steps(False)
self.socketio.emit("progressUpdate", progress.to_formatted_dict()) self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0) eventlet.sleep(0)
image = self.esrgan.process( image = self.esrgan.process(
@ -1139,12 +1168,15 @@ class InvokeAIWebServer:
if facetool_parameters: if facetool_parameters:
if facetool_parameters["type"] == "gfpgan": if facetool_parameters["type"] == "gfpgan":
progress.set_current_status("common:statusRestoringFacesGFPGAN") progress.set_current_status(
"common:statusRestoringFacesGFPGAN")
elif facetool_parameters["type"] == "codeformer": elif facetool_parameters["type"] == "codeformer":
progress.set_current_status("common:statusRestoringFacesCodeFormer") progress.set_current_status(
"common:statusRestoringFacesCodeFormer")
progress.set_current_status_has_steps(False) progress.set_current_status_has_steps(False)
self.socketio.emit("progressUpdate", progress.to_formatted_dict()) self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0) eventlet.sleep(0)
if facetool_parameters["type"] == "gfpgan": if facetool_parameters["type"] == "gfpgan":
@ -1174,7 +1206,8 @@ class InvokeAIWebServer:
all_parameters["facetool_type"] = facetool_parameters["type"] all_parameters["facetool_type"] = facetool_parameters["type"]
progress.set_current_status("common:statusSavingImage") progress.set_current_status("common:statusSavingImage")
self.socketio.emit("progressUpdate", progress.to_formatted_dict()) self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0) eventlet.sleep(0)
# restore the stashed URLS and discard the paths, we are about to send the result to client # restore the stashed URLS and discard the paths, we are about to send the result to client
@ -1185,12 +1218,14 @@ class InvokeAIWebServer:
) )
if "init_mask" in all_parameters: if "init_mask" in all_parameters:
all_parameters["init_mask"] = "" # TODO: store the mask in metadata # TODO: store the mask in metadata
all_parameters["init_mask"] = ""
if generation_parameters["generation_mode"] == "unifiedCanvas": if generation_parameters["generation_mode"] == "unifiedCanvas":
all_parameters["bounding_box"] = original_bounding_box all_parameters["bounding_box"] = original_bounding_box
metadata = self.parameters_to_generated_image_metadata(all_parameters) metadata = self.parameters_to_generated_image_metadata(
all_parameters)
command = parameters_to_command(all_parameters) command = parameters_to_command(all_parameters)
@ -1220,15 +1255,18 @@ class InvokeAIWebServer:
if progress.total_iterations > progress.current_iteration: if progress.total_iterations > progress.current_iteration:
progress.set_current_step(1) progress.set_current_step(1)
progress.set_current_status("common:statusIterationComplete") progress.set_current_status(
"common:statusIterationComplete")
progress.set_current_status_has_steps(False) progress.set_current_status_has_steps(False)
else: else:
progress.mark_complete() progress.mark_complete()
self.socketio.emit("progressUpdate", progress.to_formatted_dict()) self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0) eventlet.sleep(0)
parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"]) parsed_prompt, _ = get_prompt_structure(
generation_parameters["prompt"])
tokens = None if type(parsed_prompt) is Blend else \ tokens = None if type(parsed_prompt) is Blend else \
get_tokens_for_prompt(self.generate.model, parsed_prompt) get_tokens_for_prompt(self.generate.model, parsed_prompt)
attention_maps_image_base64_url = None if attention_maps_image is None \ attention_maps_image_base64_url = None if attention_maps_image is None \
@ -1402,7 +1440,8 @@ class InvokeAIWebServer:
self, parameters, original_image_path self, parameters, original_image_path
): ):
try: try:
current_metadata = retrieve_metadata(original_image_path)["sd-metadata"] current_metadata = retrieve_metadata(
original_image_path)["sd-metadata"]
postprocessing_metadata = {} postprocessing_metadata = {}
""" """
@ -1442,7 +1481,8 @@ class InvokeAIWebServer:
postprocessing_metadata postprocessing_metadata
) )
else: else:
current_metadata["image"]["postprocessing"] = [postprocessing_metadata] current_metadata["image"]["postprocessing"] = [
postprocessing_metadata]
return current_metadata return current_metadata
@ -1554,7 +1594,8 @@ class InvokeAIWebServer:
) )
elif "thumbnails" in url: elif "thumbnails" in url:
return os.path.abspath( return os.path.abspath(
os.path.join(self.thumbnail_image_path, os.path.basename(url)) os.path.join(self.thumbnail_image_path,
os.path.basename(url))
) )
else: else:
return os.path.abspath( return os.path.abspath(
@ -1723,10 +1764,12 @@ def dataURL_to_image(dataURL: str) -> ImageType:
) )
return image return image
""" """
Converts an image into a base64 image dataURL. Converts an image into a base64 image dataURL.
""" """
def image_to_dataURL(image: ImageType) -> str: def image_to_dataURL(image: ImageType) -> str:
buffered = io.BytesIO() buffered = io.BytesIO()
image.save(buffered, format="PNG") image.save(buffered, format="PNG")
@ -1736,7 +1779,6 @@ def image_to_dataURL(image: ImageType) -> str:
return image_base64 return image_base64
""" """
Converts a base64 image dataURL into bytes. Converts a base64 image dataURL into bytes.
The dataURL is split on the first commma. The dataURL is split on the first commma.

View File

@ -74,5 +74,6 @@
"convertToDiffusersHelpText6": "Do you wish to convert this model?", "convertToDiffusersHelpText6": "Do you wish to convert this model?",
"inpaintingModel": "Inpainting Model", "inpaintingModel": "Inpainting Model",
"customConfig": "Custom Config", "customConfig": "Custom Config",
"pathToCustomConfig": "Path To Custom Config" "pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting"
} }

View File

@ -222,6 +222,7 @@ export declare type InvokeDiffusersModelConfigProps = {
export declare type InvokeModelConversionProps = { export declare type InvokeModelConversionProps = {
name: string; name: string;
is_inpainting: boolean; is_inpainting: boolean;
custom_config: string | null;
}; };
/** /**

View File

@ -178,8 +178,10 @@ const makeSocketIOEmitters = (
emitDeleteModel: (modelName: string) => { emitDeleteModel: (modelName: string) => {
socketio.emit('deleteModel', modelName); socketio.emit('deleteModel', modelName);
}, },
emitConvertToDiffusers: (modelName: string) => { emitConvertToDiffusers: (
socketio.emit('convertToDiffusers', modelName); modelToConvert: InvokeAI.InvokeModelConversionProps
) => {
socketio.emit('convertToDiffusers', modelToConvert);
}, },
emitRequestModelChange: (modelName: string) => { emitRequestModelChange: (modelName: string) => {
dispatch(modelChangeRequested()); dispatch(modelChangeRequested());

View File

@ -25,6 +25,7 @@ export default function ModelConvert(props: ModelConvertProps) {
const [isInpainting, setIsInpainting] = useState<boolean>(false); const [isInpainting, setIsInpainting] = useState<boolean>(false);
const [customConfig, setIsCustomConfig] = useState<boolean>(false); const [customConfig, setIsCustomConfig] = useState<boolean>(false);
const [pathToConfig, setPathToConfig] = useState<string>('');
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -37,20 +38,40 @@ export default function ModelConvert(props: ModelConvertProps) {
(state: RootState) => state.system.isConnected (state: RootState) => state.system.isConnected
); );
useEffect(() => { // Need to manually handle local state reset because the component does not re-render.
const stateReset = () => {
setIsInpainting(false); setIsInpainting(false);
setIsCustomConfig(false); setIsCustomConfig(false);
setPathToConfig('');
};
// Reset local state when model changes
useEffect(() => {
stateReset();
}, [model]); }, [model]);
// Handle local state reset when user cancels input
const modelConvertCancelHandler = () => {
stateReset();
};
const modelConvertHandler = () => { const modelConvertHandler = () => {
const modelConvertData = {
name: model,
is_inpainting: isInpainting,
custom_config: customConfig && pathToConfig !== '' ? pathToConfig : null,
};
dispatch(setIsProcessing(true)); dispatch(setIsProcessing(true));
dispatch(convertToDiffusers({ name: model, is_inpainting: isInpainting })); dispatch(convertToDiffusers(modelConvertData));
stateReset(); // Edge case: Cancel local state when model convert fails
}; };
return ( return (
<IAIAlertDialog <IAIAlertDialog
title={`${t('modelmanager:convert')} ${model}`} title={`${t('modelmanager:convert')} ${model}`}
acceptCallback={modelConvertHandler} acceptCallback={modelConvertHandler}
cancelCallback={modelConvertCancelHandler}
acceptButtonText={`${t('modelmanager:convert')}`} acceptButtonText={`${t('modelmanager:convert')}`}
triggerComponent={ triggerComponent={
<IAIButton <IAIButton
@ -105,7 +126,13 @@ export default function ModelConvert(props: ModelConvertProps) {
> >
{t('modelmanager:pathToCustomConfig')} {t('modelmanager:pathToCustomConfig')}
</Text> </Text>
<IAIInput width="25rem" /> <IAIInput
value={pathToConfig}
onChange={(e) => {
if (e.target.value !== '') setPathToConfig(e.target.value);
}}
width="25rem"
/>
</Flex> </Flex>
)} )}
</Flex> </Flex>