Add support for custom config files

This commit is contained in:
blessedcoolant 2023-02-11 23:34:24 +13:00
parent 6e52ca3307
commit 310501cd8a
5 changed files with 134 additions and 61 deletions

View File

@ -43,7 +43,8 @@ if not os.path.isabs(args.outdir):
# normalize the config directory relative to root
if not os.path.isabs(opt.conf):
opt.conf = os.path.normpath(os.path.join(Globals.root,opt.conf))
opt.conf = os.path.normpath(os.path.join(Globals.root, opt.conf))
class InvokeAIWebServer:
def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None:
@ -189,7 +190,8 @@ class InvokeAIWebServer:
(width, height) = pil_image.size
thumbnail_path = save_thumbnail(
pil_image, os.path.basename(file_path), self.thumbnail_image_path
pil_image, os.path.basename(
file_path), self.thumbnail_image_path
)
response = {
@ -264,14 +266,16 @@ class InvokeAIWebServer:
# location for "finished" images
self.result_path = args.outdir
# temporary path for intermediates
self.intermediate_path = os.path.join(self.result_path, "intermediates/")
self.intermediate_path = os.path.join(
self.result_path, "intermediates/")
# path for user-uploaded init images and masks
self.init_image_path = os.path.join(self.result_path, "init-images/")
self.mask_image_path = os.path.join(self.result_path, "mask-images/")
# path for temp images e.g. gallery generations which are not committed
self.temp_image_path = os.path.join(self.result_path, "temp-images/")
# path for thumbnail images
self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
self.thumbnail_image_path = os.path.join(
self.result_path, "thumbnails/")
# txt log
self.log_path = os.path.join(self.result_path, "invoke_log.txt")
# make all output paths
@ -301,14 +305,16 @@ class InvokeAIWebServer:
try:
if not search_folder:
socketio.emit(
"foundModels",
{'search_folder': None, 'found_models': None},
)
"foundModels",
{'search_folder': None, 'found_models': None},
)
else:
search_folder, found_models = self.generate.model_manager.search_models(search_folder)
search_folder, found_models = self.generate.model_manager.search_models(
search_folder)
socketio.emit(
"foundModels",
{'search_folder': search_folder, 'found_models': found_models},
{'search_folder': search_folder,
'found_models': found_models},
)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
@ -396,7 +402,6 @@ class InvokeAIWebServer:
@socketio.on('convertToDiffusers')
def convert_to_diffusers(model_to_convert: dict):
try:
if (model_info := self.generate.model_manager.model_info(model_name=model_to_convert['name'])):
if 'weights' in model_info:
ckpt_path = Path(model_info['weights'])
@ -404,15 +409,18 @@ class InvokeAIWebServer:
model_name = model_to_convert["name"]
model_description = model_info['description']
else:
self.socketio.emit("error", {"message": "Model is not a valid checkpoint file"})
self.socketio.emit(
"error", {"message": "Model is not a valid checkpoint file"})
else:
self.socketio.emit("error", {"message": "Could not retrieve model info."})
self.socketio.emit(
"error", {"message": "Could not retrieve model info."})
if not ckpt_path.is_absolute():
ckpt_path = Path(Globals.root,ckpt_path)
ckpt_path = Path(Globals.root, ckpt_path)
if original_config_file and not original_config_file.is_absolute():
original_config_file = Path(Globals.root, original_config_file)
original_config_file = Path(
Globals.root, original_config_file)
if model_to_convert['is_inpainting']:
original_config_file = Path(
@ -420,19 +428,24 @@ class InvokeAIWebServer:
'stable-diffusion',
'v1-inpainting-inference.yaml' if model_to_convert['is_inpainting'] else 'v1-inference.yaml'
)
diffusers_path = Path(f'{ckpt_path.parent.absolute()}\\{model_name}_diffusers')
if model_to_convert['custom_config'] is not None:
original_config_file = Path(
model_to_convert['custom_config'])
diffusers_path = Path(
f'{ckpt_path.parent.absolute()}\\{model_name}_diffusers')
if diffusers_path.exists():
shutil.rmtree(diffusers_path)
self.generate.model_manager.convert_and_import(
ckpt_path,
diffusers_path,
model_name=model_name,
model_description=model_description,
vae = None,
original_config_file = original_config_file,
vae=None,
original_config_file=original_config_file,
commit_to_conf=opt.conf,
)
@ -440,7 +453,7 @@ class InvokeAIWebServer:
socketio.emit(
"newModelAdded",
{"new_model_name": model_name,
"model_list": new_model_list, 'update': True},
"model_list": new_model_list, 'update': True},
)
print(f">> Model Converted: {model_name}")
except Exception as e:
@ -448,7 +461,7 @@ class InvokeAIWebServer:
print("\n")
traceback.print_exc()
print("\n")
print("\n")
@socketio.on("requestEmptyTempFolder")
def empty_temp_folder():
@ -463,7 +476,8 @@ class InvokeAIWebServer:
)
os.remove(thumbnail_path)
except Exception as e:
socketio.emit("error", {"message": f"Unable to delete {f}: {str(e)}"})
socketio.emit(
"error", {"message": f"Unable to delete {f}: {str(e)}"})
pass
socketio.emit("tempFolderEmptied")
@ -478,7 +492,8 @@ class InvokeAIWebServer:
def save_temp_image_to_gallery(url):
try:
image_path = self.get_image_path_from_url(url)
new_path = os.path.join(self.result_path, os.path.basename(image_path))
new_path = os.path.join(
self.result_path, os.path.basename(image_path))
shutil.copy2(image_path, new_path)
if os.path.splitext(new_path)[1] == ".png":
@ -491,7 +506,8 @@ class InvokeAIWebServer:
(width, height) = pil_image.size
thumbnail_path = save_thumbnail(
pil_image, os.path.basename(new_path), self.thumbnail_image_path
pil_image, os.path.basename(
new_path), self.thumbnail_image_path
)
image_array = [
@ -554,7 +570,8 @@ class InvokeAIWebServer:
(width, height) = pil_image.size
thumbnail_path = save_thumbnail(
pil_image, os.path.basename(path), self.thumbnail_image_path
pil_image, os.path.basename(
path), self.thumbnail_image_path
)
image_array.append(
@ -572,7 +589,8 @@ class InvokeAIWebServer:
}
)
except Exception as e:
socketio.emit("error", {"message": f"Unable to load {path}: {str(e)}"})
socketio.emit(
"error", {"message": f"Unable to load {path}: {str(e)}"})
pass
socketio.emit(
@ -626,7 +644,8 @@ class InvokeAIWebServer:
(width, height) = pil_image.size
thumbnail_path = save_thumbnail(
pil_image, os.path.basename(path), self.thumbnail_image_path
pil_image, os.path.basename(
path), self.thumbnail_image_path
)
image_array.append(
@ -645,7 +664,8 @@ class InvokeAIWebServer:
)
except Exception as e:
print(f">> Unable to load {path}")
socketio.emit("error", {"message": f"Unable to load {path}: {str(e)}"})
socketio.emit(
"error", {"message": f"Unable to load {path}: {str(e)}"})
pass
socketio.emit(
@ -683,7 +703,8 @@ class InvokeAIWebServer:
printable_parameters["init_mask"][:64] + "..."
)
print(f'\n>> Image Generation Parameters:\n\n{printable_parameters}\n')
print(
f'\n>> Image Generation Parameters:\n\n{printable_parameters}\n')
print(f'>> ESRGAN Parameters: {esrgan_parameters}')
print(f'>> Facetool Parameters: {facetool_parameters}')
@ -726,9 +747,11 @@ class InvokeAIWebServer:
if postprocessing_parameters["type"] == "esrgan":
progress.set_current_status("common:statusUpscalingESRGAN")
elif postprocessing_parameters["type"] == "gfpgan":
progress.set_current_status("common:statusRestoringFacesGFPGAN")
progress.set_current_status(
"common:statusRestoringFacesGFPGAN")
elif postprocessing_parameters["type"] == "codeformer":
progress.set_current_status("common:statusRestoringFacesCodeFormer")
progress.set_current_status(
"common:statusRestoringFacesCodeFormer")
socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
@ -904,7 +927,8 @@ class InvokeAIWebServer:
init_img_url = generation_parameters["init_img"]
original_bounding_box = generation_parameters["bounding_box"].copy()
original_bounding_box = generation_parameters["bounding_box"].copy(
)
initial_image = dataURL_to_image(
generation_parameters["init_img"]
@ -981,7 +1005,8 @@ class InvokeAIWebServer:
elif generation_parameters["generation_mode"] == "img2img":
init_img_url = generation_parameters["init_img"]
init_img_path = self.get_image_path_from_url(init_img_url)
generation_parameters["init_img"] = Image.open(init_img_path).convert('RGB')
generation_parameters["init_img"] = Image.open(
init_img_path).convert('RGB')
def image_progress(sample, step):
if self.canceled.is_set():
@ -1040,9 +1065,9 @@ class InvokeAIWebServer:
},
)
if generation_parameters["progress_latents"]:
image = self.generate.sample_to_lowres_estimated_image(sample)
image = self.generate.sample_to_lowres_estimated_image(
sample)
(width, height) = image.size
width *= 8
height *= 8
@ -1061,7 +1086,8 @@ class InvokeAIWebServer:
},
)
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
def image_done(image, seed, first_seed, attention_maps_image=None):
@ -1089,7 +1115,8 @@ class InvokeAIWebServer:
progress.set_current_status("common:statusGenerationComplete")
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
all_parameters = generation_parameters
@ -1100,7 +1127,8 @@ class InvokeAIWebServer:
and all_parameters["variation_amount"] > 0
):
first_seed = first_seed or seed
this_variation = [[seed, all_parameters["variation_amount"]]]
this_variation = [
[seed, all_parameters["variation_amount"]]]
all_parameters["with_variations"] = (
prior_variations + this_variation
)
@ -1116,7 +1144,8 @@ class InvokeAIWebServer:
if esrgan_parameters:
progress.set_current_status("common:statusUpscaling")
progress.set_current_status_has_steps(False)
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
image = self.esrgan.process(
@ -1139,12 +1168,15 @@ class InvokeAIWebServer:
if facetool_parameters:
if facetool_parameters["type"] == "gfpgan":
progress.set_current_status("common:statusRestoringFacesGFPGAN")
progress.set_current_status(
"common:statusRestoringFacesGFPGAN")
elif facetool_parameters["type"] == "codeformer":
progress.set_current_status("common:statusRestoringFacesCodeFormer")
progress.set_current_status(
"common:statusRestoringFacesCodeFormer")
progress.set_current_status_has_steps(False)
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
if facetool_parameters["type"] == "gfpgan":
@ -1174,7 +1206,8 @@ class InvokeAIWebServer:
all_parameters["facetool_type"] = facetool_parameters["type"]
progress.set_current_status("common:statusSavingImage")
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
# restore the stashed URLS and discard the paths, we are about to send the result to client
@ -1185,12 +1218,14 @@ class InvokeAIWebServer:
)
if "init_mask" in all_parameters:
all_parameters["init_mask"] = "" # TODO: store the mask in metadata
# TODO: store the mask in metadata
all_parameters["init_mask"] = ""
if generation_parameters["generation_mode"] == "unifiedCanvas":
all_parameters["bounding_box"] = original_bounding_box
metadata = self.parameters_to_generated_image_metadata(all_parameters)
metadata = self.parameters_to_generated_image_metadata(
all_parameters)
command = parameters_to_command(all_parameters)
@ -1220,15 +1255,18 @@ class InvokeAIWebServer:
if progress.total_iterations > progress.current_iteration:
progress.set_current_step(1)
progress.set_current_status("common:statusIterationComplete")
progress.set_current_status(
"common:statusIterationComplete")
progress.set_current_status_has_steps(False)
else:
progress.mark_complete()
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
parsed_prompt, _ = get_prompt_structure(
generation_parameters["prompt"])
tokens = None if type(parsed_prompt) is Blend else \
get_tokens_for_prompt(self.generate.model, parsed_prompt)
attention_maps_image_base64_url = None if attention_maps_image is None \
@ -1402,7 +1440,8 @@ class InvokeAIWebServer:
self, parameters, original_image_path
):
try:
current_metadata = retrieve_metadata(original_image_path)["sd-metadata"]
current_metadata = retrieve_metadata(
original_image_path)["sd-metadata"]
postprocessing_metadata = {}
"""
@ -1442,7 +1481,8 @@ class InvokeAIWebServer:
postprocessing_metadata
)
else:
current_metadata["image"]["postprocessing"] = [postprocessing_metadata]
current_metadata["image"]["postprocessing"] = [
postprocessing_metadata]
return current_metadata
@ -1554,7 +1594,8 @@ class InvokeAIWebServer:
)
elif "thumbnails" in url:
return os.path.abspath(
os.path.join(self.thumbnail_image_path, os.path.basename(url))
os.path.join(self.thumbnail_image_path,
os.path.basename(url))
)
else:
return os.path.abspath(
@ -1723,10 +1764,12 @@ def dataURL_to_image(dataURL: str) -> ImageType:
)
return image
"""
Converts an image into a base64 image dataURL.
"""
def image_to_dataURL(image: ImageType) -> str:
buffered = io.BytesIO()
image.save(buffered, format="PNG")
@ -1736,7 +1779,6 @@ def image_to_dataURL(image: ImageType) -> str:
return image_base64
"""
Converts a base64 image dataURL into bytes.
The dataURL is split on the first commma.

View File

@ -74,5 +74,6 @@
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
"inpaintingModel": "Inpainting Model",
"customConfig": "Custom Config",
"pathToCustomConfig": "Path To Custom Config"
"pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting"
}

View File

@ -222,6 +222,7 @@ export declare type InvokeDiffusersModelConfigProps = {
export declare type InvokeModelConversionProps = {
name: string;
is_inpainting: boolean;
custom_config: string | null;
};
/**

View File

@ -178,8 +178,10 @@ const makeSocketIOEmitters = (
emitDeleteModel: (modelName: string) => {
socketio.emit('deleteModel', modelName);
},
emitConvertToDiffusers: (modelName: string) => {
socketio.emit('convertToDiffusers', modelName);
emitConvertToDiffusers: (
modelToConvert: InvokeAI.InvokeModelConversionProps
) => {
socketio.emit('convertToDiffusers', modelToConvert);
},
emitRequestModelChange: (modelName: string) => {
dispatch(modelChangeRequested());

View File

@ -25,6 +25,7 @@ export default function ModelConvert(props: ModelConvertProps) {
const [isInpainting, setIsInpainting] = useState<boolean>(false);
const [customConfig, setIsCustomConfig] = useState<boolean>(false);
const [pathToConfig, setPathToConfig] = useState<string>('');
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -37,20 +38,40 @@ export default function ModelConvert(props: ModelConvertProps) {
(state: RootState) => state.system.isConnected
);
useEffect(() => {
// Need to manually handle local state reset because the component does not re-render.
const stateReset = () => {
setIsInpainting(false);
setIsCustomConfig(false);
setPathToConfig('');
};
// Reset local state when model changes
useEffect(() => {
stateReset();
}, [model]);
// Handle local state reset when user cancels input
const modelConvertCancelHandler = () => {
stateReset();
};
const modelConvertHandler = () => {
const modelConvertData = {
name: model,
is_inpainting: isInpainting,
custom_config: customConfig && pathToConfig !== '' ? pathToConfig : null,
};
dispatch(setIsProcessing(true));
dispatch(convertToDiffusers({ name: model, is_inpainting: isInpainting }));
dispatch(convertToDiffusers(modelConvertData));
stateReset(); // Edge case: Cancel local state when model convert fails
};
return (
<IAIAlertDialog
title={`${t('modelmanager:convert')} ${model}`}
acceptCallback={modelConvertHandler}
cancelCallback={modelConvertCancelHandler}
acceptButtonText={`${t('modelmanager:convert')}`}
triggerComponent={
<IAIButton
@ -105,7 +126,13 @@ export default function ModelConvert(props: ModelConvertProps) {
>
{t('modelmanager:pathToCustomConfig')}
</Text>
<IAIInput width="25rem" />
<IAIInput
value={pathToConfig}
onChange={(e) => {
if (e.target.value !== '') setPathToConfig(e.target.value);
}}
width="25rem"
/>
</Flex>
)}
</Flex>