install model if diffusers or single file, cleaned up backend logic to not mess with existing model install

This commit is contained in:
Jennifer Player 2024-03-11 08:42:04 -04:00 committed by psychedelicious
parent 4c0896e436
commit 3a5314f1ca
4 changed files with 63 additions and 24 deletions

View File

@ -241,12 +241,16 @@ class ModelInstallService(ModelInstallServiceBase):
metadata = HuggingFaceMetadataFetch(self._session).from_id(source)
assert isinstance(metadata, ModelMetadataWithFiles)
remote_files = metadata.download_urls(
session=self._session,
)
urls: List[AnyHttpUrl] = []
# return array of remote_files.url
return [x.url for x in remote_files]
for file in metadata.files:
if str(file.url).endswith(
(".safetensors", ".bin", ".onnx", ".xml", ".pth", ".pt", ".ckpt", ".msgpack", "model_index.json")
):
urls.append(file.url)
self._logger.info(f"here are the metadata files {metadata.files}")
return urls
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
return self._install_jobs

View File

@ -51,15 +51,7 @@ def filter_files(
(
"learned_embeds.bin",
"ip_adapter.bin",
# jennifer added a bunch of these, probably will break something
".safetensors",
".bin",
".onnx",
".xml",
".pth",
".pt",
".ckpt",
".msgpack",
"lora_weights.safetensors",
"weights.pb",
"onnx_data",
)
@ -79,8 +71,7 @@ def filter_files(
paths = [x for x in paths if x.parent == Path(subfolder)]
# _filter_by_variant uniquifies the paths and returns a set
# jennifer removed the filter since it removed models, probably will break something but i dont understand why it removes valid models :|
return sorted(paths)
return sorted(_filter_by_variant(paths, variant))
@dataclass

View File

@ -1,24 +1,69 @@
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { useInstallModelMutation, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { HuggingFaceResults } from './HuggingFaceResults';
export const HuggingFaceForm = () => {
const [huggingFaceRepo, setHuggingFaceRepo] = useState('');
const [displayResults, setDisplayResults] = useState(false);
const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery();
const [installModel] = useInstallModelMutation();
const handleInstallModel = useCallback((source: string) => {
installModel({ source })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [installModel, dispatch, t]);
const scanFolder = useCallback(async () => {
_getHuggingFaceModels(huggingFaceRepo).catch((error) => {
if (error) {
setErrorMessage(error.data.detail);
}
});
_getHuggingFaceModels(huggingFaceRepo)
.then((response) => {
if (response.data?.some((result) => result.endsWith('model_index.json'))) {
handleInstallModel(huggingFaceRepo);
setDisplayResults(false);
} else if (response.data?.length === 1 && response.data[0]) {
handleInstallModel(response.data[0]);
setDisplayResults(false);
} else {
setDisplayResults(true);
}
})
.catch((error) => {
if (error) {
setErrorMessage(error.data.detail);
}
});
}, [_getHuggingFaceModels, huggingFaceRepo]);
const handleSetHuggingFaceRepo: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
@ -43,7 +88,7 @@ export const HuggingFaceForm = () => {
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</Flex>
</FormControl>
{data && <HuggingFaceResults results={data} />}
{data && displayResults && <HuggingFaceResults results={data} />}
</Flex>
);
};

View File

@ -21,7 +21,6 @@ import { useInstallModelMutation } from 'services/api/endpoints/models';
import { HuggingFaceResultItem } from './HuggingFaceResultItem';
type HuggingFaceResultsProps = {
// results: HuggingFaceFolderResponse;
results: string[];
};