Compare commits

..

748 Commits

Author SHA1 Message Date
37c2b57791 simplify 2023-06-22 14:51:27 -04:00
bcd3cb645f use BASE and TOKEN from OpenAPI if they are set 2023-06-22 14:47:55 -04:00
22c337b1aa Update UI To Use New Model Manager (#3548)
PR for the Model Manager UI work related to 3.0

[DONE]

- Update ModelType Config names to be specific so that the front end can
parse them correctly.
- Rebuild frontend schema to reflect these changes.
- Update Linear UI Text To Image and Image to Image to work with the new
model loader.
- Updated the ModelInput component in the Node Editor to work with the
new changes.

[TODO REMEMBER]

- Add proper types for ModelLoaderType in `ModelSelect.tsx`

[TODO] 

- Everything else.
2023-06-22 22:06:26 +12:00
339e7ce213 feat(ui): initial implementation of model loading
- Update model listing code to use `rtk-query`
- Update all graph generation to use new `pipeline_model_loader` node
2023-06-22 17:48:57 +10:00
2a178f5a25 chore(ui): regen api client 2023-06-22 17:48:13 +10:00
1bc170727b tidy(nodes): rename sd_model_loader to pipeline_model_loader
this is more accurate bc it can do eg kandinsky also
2023-06-22 17:47:58 +10:00
3722cdf5d6 chore(ui): regen api client 2023-06-22 17:36:20 +10:00
42a59aa147 feat(nodes): add sd_model_loader node
Loads any pipeline model.

Also introduced is `PipelineModelField`, which includes a model name and base model.
2023-06-22 17:36:05 +10:00
b937b7da01 feat(models): update model manager service & route to return list of models 2023-06-22 17:34:12 +10:00
21245a0fb2 Set model type to const value in openapi schema, add model format enums to model schema(as they not not referenced in case of Literal definition) 2023-06-22 16:51:53 +10:00
da566b59e8 Update model format field to use enums 2023-06-22 16:51:53 +10:00
e4dc9c5a04 Rename format to model_format(still named format when work with config) 2023-06-22 16:51:53 +10:00
aceadacad4 Remove default model logic 2023-06-22 16:51:53 +10:00
d3dec59cc3 tweal: UI colors 2023-06-22 16:51:53 +10:00
6c98700740 fix: Adjust the Schedular select width
So the long names do not get cut off.
2023-06-22 16:51:53 +10:00
c4c3c96062 Revert "feat: Port Schedulers to Mantine"
This reverts commit e0c105f413.
2023-06-22 16:51:35 +10:00
6256be480c fix: Remove type from Model type name 2023-06-22 16:48:35 +10:00
7033071934 fix: Unserialization key issue 2023-06-22 16:48:35 +10:00
e48528bbef revert: getModels to receivedModels 2023-06-22 16:48:35 +10:00
6bdf68dd4c feat: Port Schedulers to Mantine 2023-06-22 16:48:35 +10:00
0c3616229e cleanup: Updated model slice names to be more descriptive
Basically updated all slices to be more descriptive in their names. Did so in order to make sure theres good naming scheme available for secondary models.
2023-06-22 16:43:14 +10:00
604cc1adcd wip: Move Model Selector to own file 2023-06-22 16:43:14 +10:00
4847212d5b feat: Enable 2.x Model Generation in Linear UI 2023-06-22 16:43:14 +10:00
727293d722 fix: 2.1 models breaking generation
Co-Authored-By: StAlKeR7779 <7768370+StAlKeR7779@users.noreply.github.com>
2023-06-22 16:42:59 +10:00
d2f3500e1b chore: Rebuild API - base_model and type added 2023-06-22 16:42:59 +10:00
ef83a2fffe Add name, base_mode, type fields to model info 2023-06-22 16:42:51 +10:00
f8d7477c7a wip: Add 2.x Models to the Model List 2023-06-22 16:42:51 +10:00
e374211313 chore: Rebuild API with new Model API names 2023-06-22 16:41:31 +10:00
01d17601b8 Generate config names for openapi 2023-06-22 16:41:19 +10:00
bf0d5f4cfc fix: Update missing name types to new names 2023-06-22 16:41:02 +10:00
663f4935f5 chore: Rebuild API 2023-06-22 16:41:02 +10:00
9838dda1b7 chore: Update model config type names 2023-06-22 16:40:40 +10:00
2d889e133d chore(ui): regen api client 2023-06-22 16:25:49 +10:00
6779f1a5ad fix(db): update models for boards w/ nullable deleted_at 2023-06-22 16:25:49 +10:00
19a6e5dad8 chore(ui): regen api client 2023-06-22 16:25:49 +10:00
285195bf72 feat(api): add get_board route 2023-06-22 16:25:49 +10:00
10008859a4 tidy(ui): remove all refs to boards thunks 2023-06-22 16:25:49 +10:00
3c04340f3f tidy(ui): tidy up update image board modal 2023-06-22 16:25:49 +10:00
79f0c4d3c4 feat(ui): add remove from board to image context menu 2023-06-22 16:25:49 +10:00
37d4e05838 fix(ui): fix board's image list not updating when image removed from board 2023-06-22 16:25:49 +10:00
a00ad6ac03 feat(ui): dropping image on All Images board removes it from board 2023-06-22 16:25:49 +10:00
2ffead000c tidy(ui): remove console.log() 2023-06-22 16:25:49 +10:00
922319cb84 fix(ui): fix first added board doesn't show until refresh
Had incorrect `invalidatesTags` array for the mutation.
2023-06-22 16:25:49 +10:00
6ee0e197bb feat(db): add deleted_at to board_images 2023-06-22 16:25:49 +10:00
d3e6f0130c fix(ui): fix issue with gallery not letting you load more images
To determine whether the Load More button should work, we need to keep track of how many images are left to load for a given board or category.

The Assets tab doesn't work, though. Need to figure out a better way to handle this.
2023-06-22 16:25:49 +10:00
421c23d3ea fix(ui): fix gallery image fetching for board categories 2023-06-22 16:25:49 +10:00
4545f3209f fix(ui): fix bug with image deletion not removing image from gallery 2023-06-22 16:25:49 +10:00
e2ee8102c2 tidy(db): tidy image_record_storage.py 2023-06-22 16:25:49 +10:00
083a0fc4cf tidy(ui): remove references to boardsAdapter 2023-06-22 16:25:49 +10:00
26b75b85f7 fix(ui): if deleting selected board, deselect it 2023-06-22 16:25:49 +10:00
f560a462a0 feat(ui): rudimentary categorized gallery image fetching 2023-06-22 16:25:49 +10:00
d501986610 chore(ui): regen api client 2023-06-22 16:25:49 +10:00
67a75f6895 feat(api, db): support board_id filter on images service get_many() 2023-06-22 16:25:49 +10:00
3c032c0767 feat(ui): only auto-add image to board if is not intermediate 2023-06-22 16:25:49 +10:00
abd6561140 feat(ui): just fetch all boards instead of paginating them 2023-06-22 16:25:49 +10:00
bd533426fc feat(ui): first pass at boards styling 2023-06-22 16:25:49 +10:00
2489d5459f chore(ui): regen api client 2023-06-22 16:25:49 +10:00
ac477cf5d6 fix(ui): improve image deletion handling 2023-06-22 16:25:49 +10:00
be3bdae847 fix: resolve rebase conflicts 2023-06-22 16:25:49 +10:00
3e0ee838cf fix(ui): add initial image dimensions to state
We need to access the initial image dimensions during the creation of the `ImageToImage` graph to determine if we need to resize the image.

Because the `initialImage` is now just an image name, we need to either store (easy) or dynamically retrieve its dimensions during graph creation (a bit less easy).

Took the easiest path. May need to revise this in the future.
2023-06-22 16:25:49 +10:00
8d3bec57d5 feat(ui): store only image name in parameters
Images that are used as parameters (e.g. init image, canvas images) are stored as full `ImageDTO` objects in state, separate from and duplicating any object representing those same objects in the `imagesSlice`.

We cannot store only image names as parameters, then pull the full `ImageDTO` from `imagesSlice`, because if an image is not on a loaded page, it doesn't exist in `imagesSlice`. For example, if you scroll down a few pages in the gallery and send that image to canvas, on reloading the app, the canvas will be unable to load that image.

We solved this temporarily by storing the full `ImageDTO` object wherever it was needed, but this is both inefficient and allows for stale `ImageDTO`s across the app.

One other possible solution was to just fetch the `ImageDTO` for all images at startup, and insert them into the `imagesSlice`, but then we run into an issue where we are displaying images in the gallery totally out of context.

For example, if an image from several pages into the gallery was sent to canvas, and the user refreshes, we'd display the first 20 images in gallery. Then to populate the canvas, we'd fetch that image we sent to canvas and add it to `imagesSlice`. Now we'd have 21 images in the gallery: 1 to 20 and whichever image we sent to canvas. Weird.

Using `rtk-query` solves this by allowing us to very easily fetch individual images in the components that need them, and not directly interact with `imagesSlice`.

This commit changes all references to images-as-parameters to store only the name of the image, and not the full `ImageDTO` object. Then, we use an `rtk-query` generated `useGetImageDTOQuery()` hook in each of those components to fetch the image.

We can use cache invalidation when we mutate any image to trigger automated re-running of the query and all the images are automatically kept up to date.

This also obviates the need for the convoluted URL fetching scheme for images that are used as parameters. The `imagesSlice` still need this handling unfortunately.
2023-06-22 16:25:49 +10:00
cfda128e06 feat(ui): wip boards via rtk-query 2023-06-22 16:25:49 +10:00
661a94b3de feat(db): add get_all() method for boards
This is needed to show the full list of boards in the update boards modal.
2023-06-22 16:25:49 +10:00
9ef64016c7 feat(db): sort board by created_at 2023-06-22 16:25:49 +10:00
21f0d0b0c1 fix(db): fix deserialize_board_record()
It was not adding `cover_image_name`
2023-06-22 16:25:49 +10:00
8bce234542 feat(db): update image-board relationships on add
Functionally, `add_image_to_board()` now moves images between boards.
2023-06-22 16:25:49 +10:00
daadf6ebfd feat(ui): add board image count badge 2023-06-22 16:25:49 +10:00
fe10a9f747 render cover image based on URL in image entities 2023-06-22 16:25:49 +10:00
7a2d3f628a add boardToAddTo state so that result can be added to board when generation is complete 2023-06-22 16:25:49 +10:00
4defb92105 handle long board names 2023-06-22 16:25:49 +10:00
f9f3c91a83 drag and drop to move image to board, a bit of board list UI 2023-06-22 16:25:49 +10:00
95b9c8e505 return cover_image_name since urls change, override one from db for now 2023-06-22 16:25:49 +10:00
49a02c157b feat(ui): fix UpdateImageBoardModal select 2023-06-22 16:25:49 +10:00
d604d986f9 feat(db, api): update get_board_for_image & service dependencies
- previously was `get_boards_for_image`, returning a list of `BoardDTO`, now returns a single `board_id`
2023-06-22 16:25:49 +10:00
70cc037a9c fix(ui): do not persist boards 2023-06-22 16:25:49 +10:00
e4893e4031 fix(db): return board records from CRUD methods 2023-06-22 16:25:49 +10:00
4a0a718b96 foiled by a comma 2023-06-22 16:25:49 +10:00
ca8f1a7828 (api) use most recently generated image for cover photo 2023-06-22 16:25:49 +10:00
2e41af2109 [half-baked] adding image to board modal 2023-06-22 16:25:49 +10:00
bd29e5e655 UI tweaks 2023-06-22 16:25:49 +10:00
dcfee2e1e4 add searching to boards list 2023-06-22 16:25:49 +10:00
8aac683319 can delete and rename boards 2023-06-22 16:25:49 +10:00
d306a84447 feat(ui): rough out boards UI 2023-06-22 16:25:49 +10:00
5865ecd530 feat(db): add FK for boards.cover_image_name 2023-06-22 16:25:49 +10:00
e1f9685b02 feat(db): add index for boards 2023-06-22 16:25:49 +10:00
498bf0d0ba feat(db): add indices for board_images 2023-06-22 16:25:49 +10:00
163ef2c941 feat(ui): remove refs to BoardRecord in UI
UI should only work w/ BoardDTO
2023-06-22 16:25:49 +10:00
48193b7fa7 chore(ui): regen api client 2023-06-22 16:25:49 +10:00
dd1b3c9f35 fix(api): update API models to use BoardDTOs 2023-06-22 16:25:49 +10:00
4b32322a58 feat(nodes): make board <> images a one-to-many relationship
we can extend this to many-to-many in the future if desired.
2023-06-22 16:25:49 +10:00
e06c43adc8 lint fix 2023-06-22 16:25:49 +10:00
c009f46b00 regenerate api schema 2023-06-22 16:25:49 +10:00
748016bdab routes working 2023-06-22 16:25:49 +10:00
72e9ced889 feat(nodes): add boards and board_images services 2023-06-22 16:25:49 +10:00
3833304f57 [WIP] board list endpoint w cover photos 2023-06-22 16:25:49 +10:00
4bfaae6617 fix type 2023-06-22 16:25:49 +10:00
499a174832 some more 2023-06-22 16:25:49 +10:00
6ca5ad9075 filter images by board_id 2023-06-22 16:25:49 +10:00
a121e6b3a0 add board_id association to image 2023-06-22 16:25:49 +10:00
207602f425 remove unused 2023-06-22 16:25:49 +10:00
a1671519d5 board CRUD 2023-06-22 16:25:49 +10:00
257e972599 fix failing pytest for config module 2023-06-20 13:26:01 -04:00
d339c8627f feat: Upgrade to Diffusers 0.17.1 (#3545)
Just syncing up with diffusers upstream.
2023-06-19 23:25:22 +12:00
a53e0dce6c Merge branch 'upgrade-diffusers' of https://github.com/blessedcoolant/InvokeAI into upgrade-diffusers 2023-06-19 23:21:06 +12:00
0ae6325353 chore: Add torchsde as a dependency for the SDE schedulers 2023-06-19 23:20:53 +12:00
12299120ab Merge branch 'main' into upgrade-diffusers 2023-06-19 23:16:39 +12:00
1a7fe172ca Fix inpaint node to new manager (#3550)
Inpaint node still used by canvas, so fixed it to new model manager api.
Other old generation code deleted.
2023-06-19 23:01:05 +12:00
4f5693040e Merge branch 'main' into fix/inpaint_new_manager 2023-06-19 22:55:00 +12:00
bb2df88c06 Add dpmpp_sde and dpmpp_2m_sde schedulers(with karras) (#3554)
Added sde schedulers.
Problem - they add random on each step, to get consistent image we need
to provide seed or generator.
I done it, but if you think that it better do in other way - feel free
to change.

Also made ancestral schedulers reproducible, this done same way as for
sde scheduler.
2023-06-19 22:52:33 +12:00
41442eb7f6 feat(ui): convert canvas txt2img & img2img to latents
- Add graph builders for canvas txt2img & img2img - they are mostly copy and paste from the linear graph builders but different in a few ways that are very tricky to work around. Just made totally new functions for them.
- Canvas txt2img and img2img support ControlNet (not inpaint/outpaint). There's no way to determine in real-time which mode the canvas is in just yet, so we cannot disable the ControlNet UI when the mode will be inpaint/outpaint - it will always display. It's possible to determine this in near-real-time, will add this at some point.
- Canvas inpaint/outpaint migrated to use model loader, though inpaint/outpaint are still using the non-latents nodes.
2023-06-19 15:57:28 +10:00
223a679ac1 chore(ui): regen api client 2023-06-19 15:57:28 +10:00
3c60616b4d feat(ui): simplify linear graph creation logic
Instead of manually creating every node and edge, we can simply copy/paste the base graph from node editor, then sub in parameters.

This is a much more intelligible process. We still need to handle seed, img2img fit and controlnet separately.
2023-06-19 15:57:28 +10:00
a01998d095 Remove more old logic 2023-06-19 15:57:28 +10:00
7b35162b9e Remove old logic except for inpaint, add support for lora and ti to inpaint node 2023-06-19 15:57:28 +10:00
c26e1a9271 Rewrite inpaint node to new model manager, remove TextToImage and ImageToImage nodes 2023-06-19 15:57:28 +10:00
9b32407744 Provide generator to all schedulers step function to make both ancestral and sde schedulers reproducible 2023-06-19 00:34:01 +03:00
f3d9797ebe Add dpmpp_sde and dpmpp_2m_sde schedulers(with karras) 2023-06-18 23:38:15 +03:00
f312e1448f Update index.md
fixed typo
2023-06-18 10:39:02 -04:00
a11946f0ad feat: Port Schedulers to Mantine (#3552)
- Ports Schedulers to use IAIMantineSelect.
- Adds ability to favorite schedulers in Settings. Favorited schedulers
show up at the top of the list.
- Adds IAIMantineMultiSelect component.
- Change SettingsSchedulers component to use IAIMantineMultiSelect
instead of Chakra Menus.
2023-06-18 22:22:03 +12:00
80a8d3ef28 style: Theme placeholder style for IAIMantineMultiSelect 2023-06-18 22:17:09 +12:00
f4ca9d0e09 Merge branch 'scheduler-select' of https://github.com/blessedcoolant/InvokeAI into scheduler-select 2023-06-18 22:05:12 +12:00
a960fa009d fix: Fix some styling issues with IAIMantineMultiSelect 2023-06-18 22:04:12 +12:00
b96b95bc95 feat(ui): enabledSchedulers -> favoriteSchedulers 2023-06-18 20:01:05 +10:00
450641c414 fix(ui): enable all schedulers by default 2023-06-18 19:39:31 +10:00
94cfcdc411 feat(ui): improve scheduler selection logic
- remove UI-specific state (the enabled schedulers) from redux, instead derive it in a selector
- simplify logic by putting schedulers in an object instead of an array
- rename `activeSchedulers` to `enabledSchedulers`
- remove need for `useEffect()` when `enabledSchedulers` changes by adding a listener for the `enabledSchedulersChanged` action/event to `generationSlice`
- increase type safety by making `enabledSchedulers` an array of `SchedulerParam`, which is created by the zod schema for scheduler
2023-06-18 19:34:37 +10:00
150059f704 fix(ui): create all scheduler constants up-front 2023-06-18 18:49:10 +10:00
f1a8b9daee fix(ui): clarify scheduler logic
- use full conditional syntax with `{}`
- do not mutate `action.payload` in a reducer
2023-06-18 18:47:59 +10:00
be8c0bb952 feat: Use Labels for Schedulers 2023-06-18 20:17:51 +12:00
dae5b9b259 fix: Minor styling fix to the IAIMantineMultiSelect component 2023-06-18 20:06:56 +12:00
06428fac67 fix: Revert scheduler back to zod validation 2023-06-18 20:02:36 +12:00
59b5dfc3e0 feat: Port Schedulers to Mantine 2023-06-18 19:47:27 +12:00
fd981a90be Add lms and dpmpp2_s karras scheduler (#3551)
Karras sigmas support added to lms and dpmpp2_s schedulers in 0.17.0
diffusers.
2023-06-18 17:36:47 +12:00
6b7cf3f3be Add lms and dpmpp2_s karras scheduler 2023-06-17 21:00:16 +03:00
9d4b84ef68 feat: Upgrade to Diffusers 0.17.1 2023-06-16 23:57:57 +12:00
4cbc802e36 Model manager fixes (#3541)
Fix lora import
Fix sd2 config - `variant` field not added
Fix list models api - `base_model` arg not provided, redundant assert
check
2023-06-16 06:43:00 +12:00
5f2d07917d Fix lora import, fix sd2 config, fix list models api 2023-06-15 21:30:15 +03:00
5c740452f6 Model Manager rewrite (#3335) 2023-06-14 08:44:04 -07:00
82c2498043 Merge branch 'main' into lstein/new-model-manager 2023-06-14 08:41:40 -07:00
0497bea264 fix: add dynamicprompts to pyproject.toml 2023-06-15 01:05:16 +10:00
b8e32fa459 chore(ui): regen api client 2023-06-15 01:05:16 +10:00
34ebee67b7 fix(nodes): fix revert conflict 2023-06-15 01:05:16 +10:00
e0c998d192 Revert "feat(ui): add warning socket event handling"
This reverts commit e7a61e631a42190e4b64e0d5e22771c669c5b30c.
2023-06-15 01:05:16 +10:00
b51e9a6bdb Revert "feat(nodes): add warning socket event"
This reverts commit cefdd9d634e515239bd85666c872a0d64bb9d772.
2023-06-15 01:05:16 +10:00
09f396ce84 feat(ui): add warning socket event handling 2023-06-15 01:05:16 +10:00
abee37eab3 feat(nodes): add warning socket event 2023-06-15 01:05:16 +10:00
42e48b2bef feat(nodes): add dynamic prompt node 2023-06-15 01:05:16 +10:00
70ece4364c refactor(minor): Image & Latent File Storage (#3538)
- `DiskImageStorage` and `DiskLatentsStorage` have now both been updated
to exclusively work with `Path` objects and not rely on the `os` lib to
handle pathing related functions.
- We now also validate the existence of the required image output
folders and latent output folders to ensure that the app does not break
in case the required folders get tampered with mid-session.
- Just overall general cleanup.

Tested it. Don't seem to be any thing breaking.
2023-06-15 02:43:27 +12:00
f9d5f9d52c fix(nodes): minor fixes for folder validation
- fix type for `__output_folder`
- prefix `validate_storage_folders()` with `__` to indicate private method
2023-06-15 00:40:39 +10:00
d0ee3558d1 Merge branch 'main' into lstein/new-model-manager 2023-06-14 17:29:01 +03:00
587297878a refactor(minor): Latent Disk Storage 2023-06-15 02:21:49 +12:00
b4c998a9ae refactor(minor): Image File Storage 2023-06-15 01:58:58 +12:00
88e8e3977b feat(ui): update UI to not use image_origin
see commit `8ad8de8: feat(nodes): remove `image_origin` from most places` for details.
2023-06-14 23:08:27 +10:00
24b86cffe9 chore(ui): regen api client & types 2023-06-14 23:08:27 +10:00
a1773197e9 feat(nodes): remove image_origin from most places
- remove `image_origin` from most places where we interact with images
- consolidate image file storage into a single `images/` dir

Images have an `image_origin` attribute but it is not actually used when retrieving images, nor will it ever be. It is still used when creating images and helps to differentiate between internally generated images and uploads.

It was included in eg API routes and image service methods as a holdover from the previous app implementation where images were not managed in a database. Now that we have images in a db, we can do away with this and simplify basically everything that touches images.

The one potentially controversial change is to no longer separate internal and external images on disk. If we retain this separation, we have to keep `image_origin` around in a number of spots and it getting image paths on disk painful.

So, I am have gotten rid of this organisation. Images are now all stored in `images`, regardless of their origin. As we improve the image management features, this change will hopefully become transparent.
2023-06-14 23:08:27 +10:00
1e08d865c9 chore: dummy commit to trigger actions 2023-06-14 14:14:24 +10:00
f8bb650cc1 revert: IAIScrollArea 2023-06-14 14:14:24 +10:00
2cee8bebb2 fix(ui): revert offset scrollbars
The wonky padding is too janky. Just overlay for now.
2023-06-14 14:14:24 +10:00
ade4ec5fd8 fix(ui): fix crash when toggling pinned parameters panel 2023-06-14 14:14:24 +10:00
70ffd6b03f fix(ui): fix controlnet selects data types 2023-06-14 14:14:24 +10:00
6c551df311 fix(ui): fix rebase conflicts 2023-06-14 14:14:24 +10:00
24f605629e cleanup: Remove OverlayScrollable component 2023-06-14 14:14:24 +10:00
2af1ec9d02 fix: Minor padding issue in unpinned drawer 2023-06-14 14:14:24 +10:00
79d53341de fix: Stretch scroll area so it retains parent width 2023-06-14 14:14:24 +10:00
e40b3506c4 fix: Options squishing on accordion collapse 2023-06-14 14:14:24 +10:00
33912382e3 feat: Introduce Mantine's ScrollArea 2023-06-14 14:14:24 +10:00
d282810e53 cleanup: Remove IAICustomSelect and port types 2023-06-14 14:14:24 +10:00
9df502fc77 fix(ui): fix mantine select props 2023-06-14 14:14:24 +10:00
705573f0a8 feat(ui): even more pedantic mantine select theming 2023-06-14 14:14:24 +10:00
1878ea94f6 feat: Port Canvas Layer Select to IAIMantineSelect 2023-06-14 14:14:24 +10:00
4ba5086b9a feat(ui): add tooltip to IAIMantineSelect 2023-06-14 14:14:24 +10:00
4a991b4daa feat(ui): more pedantic mantine select theming 2023-06-14 14:14:24 +10:00
80474d26f9 feat(ui): mantine scrollbar theming 2023-06-14 14:14:24 +10:00
9a77bd9140 feat: Port IAISelect's to IAIMantineSelect's
Ported everything except Model Manager selects and the Canvas Layer Select (this needs tooltip support)
2023-06-14 14:14:24 +10:00
14cdc800c3 feat(ui): pedantic mantine select theming 2023-06-14 14:14:24 +10:00
9cfbea4c25 feat: Match styling of Mantine Select with InvokeAI 2023-06-14 14:14:24 +10:00
5fe674e223 feat: Standardize IAIMantineSelect Component 2023-06-14 14:14:24 +10:00
32200efce8 feat: Change default font to Inter 2023-06-14 14:14:24 +10:00
68a02da990 feat: Use Mantine Select for Scheduler 2023-06-14 14:14:24 +10:00
5b20766ea3 chore: Move Mantine Theme Override to own file 2023-06-14 14:14:24 +10:00
9a914250a0 feat: Change Model Select To Mantine 2023-06-14 14:14:24 +10:00
0e3106f631 feat: Add Mantine Support 2023-06-14 14:14:24 +10:00
6c5954f9d1 Add controlnet to model manager, fixes 2023-06-14 04:26:21 +03:00
740c05a0bb Save models on rescan, uncache model on edit/delete, fixes 2023-06-14 03:12:12 +03:00
26090011c4 Fix conflict resolve, add model configs to type annotation 2023-06-14 00:26:37 +03:00
c9ae26a176 Merge branch 'main' into lstein/new-model-manager 2023-06-13 23:37:52 +03:00
e7db6d8120 Fix ckpt and vae conversion, migrate script, remove sd2-base 2023-06-13 18:05:12 +03:00
a6af7e8824 use format "diffusers" rather than format "folder" in models.yaml 2023-06-13 01:43:05 -04:00
87ba17a1f5 add migration script and update convert and face restoration paths 2023-06-13 01:27:51 -04:00
c7ea46a5da use latest version of transformers to avoid deprecation warnings 2023-06-12 16:07:39 -04:00
1439dc7712 Add SchedulerPredictionType and ModelVariantType enums 2023-06-12 16:07:04 -04:00
46cac6468e Upgrade to Diffusers 0.17.0 (#3514)
Diffusers is due for an update soon. #3512

Opening up a PR now with the required changes for when the new version
is live.

I've tested it out on Windows and nothing has broken from what I could
tell. I'd like someone to run some tests on Linux / Mac just to make
sure. Refer to the PR above on how to test it or install the release
branch.

```
pip install diffusers[torch]==0.17.0
```

Feel free to push any other changes to this PR you see fit.
2023-06-13 07:11:02 +12:00
2a814d886b Merge branch 'main' into diffusers-upgrade 2023-06-13 05:29:15 +12:00
60a2fbec41 feat(ui): improve controlnet-related config types 2023-06-13 00:04:21 +10:00
f15a328b80 fix(ui): allow controlnet with preprocessed control image 2023-06-13 00:04:21 +10:00
811d9ab55a fix(ui): disable shouldAutoConfig switch while processing 2023-06-13 00:04:21 +10:00
e00fed5c46 feat(ui): support disabling controlnet models & processors 2023-06-13 00:04:21 +10:00
a3fa38b353 fix(ui): revert IAICustomSelect usage to IAISelect
There are some bugs with it that I cannot figure out related to `floating-ui` and `downshift`'s handling of refs.

Will need to revisit this component in the future.
2023-06-13 00:04:21 +10:00
2e42a4bdd9 feat(ui): disable controlnets during processing 2023-06-13 00:04:21 +10:00
36f72b5a49 fix(ui): check for valid controlnets before adding to graph 2023-06-13 00:04:21 +10:00
af42d7d347 feat(ui): support negative controlnet weights 2023-06-13 00:04:21 +10:00
8607b1994c fix(ui): fix crash when controlnet enabled but no controlnets added 2023-06-13 00:04:21 +10:00
36eb1bd893 Fixes 2023-06-12 16:14:09 +03:00
9fa78443de Fixes, add sd variant detection 2023-06-12 05:52:30 +03:00
893f776f1d model_probe working; model_install incomplete 2023-06-11 19:51:53 -04:00
e051c450ed fix: git stash (#3528) 2023-06-12 08:55:36 +12:00
50135b726e fix: git stash 2023-06-12 08:53:09 +12:00
085ab54124 remove modified models.py and migrate code to models/base.py 2023-06-11 16:10:15 -04:00
8e1a56875e remove defunct code 2023-06-11 12:57:06 -04:00
000626ab2e move all installation code out of model_manager 2023-06-11 12:51:50 -04:00
694fd0c92f Fixes, first runable version 2023-06-11 16:42:40 +03:00
c647056287 Feat/easy param (#3504)
* Testing change to LatentsToText to allow setting different cfg_scale values per diffusion step.

* Adding first attempt at float param easing node, using Penner easing functions.

* Core implementation of ControlNet and MultiControlNet.

* Added support for ControlNet and MultiControlNet to legacy non-nodal Txt2Img in backend/generator. Although backend/generator will likely disappear by v3.x, right now they are very useful for testing core ControlNet and MultiControlNet functionality while node codebase is rapidly evolving.

* Added example of using ControlNet with legacy Txt2Img generator

* Resolving rebase conflict

* Added first controlnet preprocessor node for canny edge detection.

* Initial port of controlnet node support from generator-based TextToImageInvocation node to latent-based TextToLatentsInvocation node

* Switching to ControlField for output from controlnet nodes.

* Resolving conflicts in rebase to origin/main

* Refactored ControlNet nodes so they subclass from PreprocessedControlInvocation, and only need to override run_processor(image) (instead of reimplementing invoke())

* changes to base class for controlnet nodes

* Added HED, LineArt, and OpenPose ControlNet nodes

* Added an additional "raw_processed_image" output port to controlnets, mainly so could route ImageField to a ShowImage node

* Added more preprocessor nodes for:
      MidasDepth
      ZoeDepth
      MLSD
      NormalBae
      Pidi
      LineartAnime
      ContentShuffle
Removed pil_output options, ControlNet preprocessors should always output as PIL. Removed diagnostics and other general cleanup.

* Prep for splitting pre-processor and controlnet nodes

* Refactored controlnet nodes: split out controlnet stuff into separate node, stripped controlnet stuff form image processing/analysis nodes.

* Added resizing of controlnet image based on noise latent. Fixes a tensor mismatch issue.

* More rebase repair.

* Added support for using multiple control nets. Unfortunately this breaks direct usage of Control node output port  ==> TextToLatent control input port -- passing through a Collect node is now required. Working on fixing this...

* Fixed use of ControlNet control_weight parameter

* Fixed lint-ish formatting error

* Core implementation of ControlNet and MultiControlNet.

* Added first controlnet preprocessor node for canny edge detection.

* Initial port of controlnet node support from generator-based TextToImageInvocation node to latent-based TextToLatentsInvocation node

* Switching to ControlField for output from controlnet nodes.

* Refactored controlnet node to output ControlField that bundles control info.

* changes to base class for controlnet nodes

* Added more preprocessor nodes for:
      MidasDepth
      ZoeDepth
      MLSD
      NormalBae
      Pidi
      LineartAnime
      ContentShuffle
Removed pil_output options, ControlNet preprocessors should always output as PIL. Removed diagnostics and other general cleanup.

* Prep for splitting pre-processor and controlnet nodes

* Refactored controlnet nodes: split out controlnet stuff into separate node, stripped controlnet stuff form image processing/analysis nodes.

* Added resizing of controlnet image based on noise latent. Fixes a tensor mismatch issue.

* Cleaning up TextToLatent arg testing

* Cleaning up mistakes after rebase.

* Removed last bits of dtype and and device hardwiring from controlnet section

* Refactored ControNet support to consolidate multiple parameters into data struct. Also redid how multiple controlnets are handled.

* Added support for specifying which step iteration to start using
each ControlNet, and which step to end using each controlnet (specified as fraction of total steps)

* Cleaning up prior to submitting ControlNet PR. Mostly turning off diagnostic printing. Also fixed error when there is no controlnet input.

* Added dependency on controlnet-aux v0.0.3

* Commented out ZoeDetector. Will re-instate once there's a controlnet-aux release that supports it.

* Switched CotrolNet node modelname input from free text to default list of popular ControlNet model names.

* Fix to work with current stable release of controlnet_aux (v0.0.3). Turned of pre-processor params that were added post v0.0.3. Also change defaults for shuffle.

* Refactored most of controlnet code into its own method to declutter TextToLatents.invoke(), and make upcoming integration with LatentsToLatents easier.

* Cleaning up after ControlNet refactor in TextToLatentsInvocation

* Extended node-based ControlNet support to LatentsToLatentsInvocation.

* chore(ui): regen api client

* fix(ui): add value to conditioning field

* fix(ui): add control field type

* fix(ui): fix node ui type hints

* fix(nodes): controlnet input accepts list or single controlnet

* Moved to controlnet_aux v0.0.4, reinstated Zoe controlnet preprocessor. Also in pyproject.toml  had to specify downgrade of timm to 0.6.13 _after_ controlnet-aux installs timm >= 0.9.2, because timm >0.6.13 breaks Zoe preprocessor.

* Core implementation of ControlNet and MultiControlNet.

* Added first controlnet preprocessor node for canny edge detection.

* Switching to ControlField for output from controlnet nodes.

* Resolving conflicts in rebase to origin/main

* Refactored ControlNet nodes so they subclass from PreprocessedControlInvocation, and only need to override run_processor(image) (instead of reimplementing invoke())

* changes to base class for controlnet nodes

* Added HED, LineArt, and OpenPose ControlNet nodes

* Added more preprocessor nodes for:
      MidasDepth
      ZoeDepth
      MLSD
      NormalBae
      Pidi
      LineartAnime
      ContentShuffle
Removed pil_output options, ControlNet preprocessors should always output as PIL. Removed diagnostics and other general cleanup.

* Prep for splitting pre-processor and controlnet nodes

* Refactored controlnet nodes: split out controlnet stuff into separate node, stripped controlnet stuff form image processing/analysis nodes.

* Added resizing of controlnet image based on noise latent. Fixes a tensor mismatch issue.

* Added support for using multiple control nets. Unfortunately this breaks direct usage of Control node output port  ==> TextToLatent control input port -- passing through a Collect node is now required. Working on fixing this...

* Fixed use of ControlNet control_weight parameter

* Core implementation of ControlNet and MultiControlNet.

* Added first controlnet preprocessor node for canny edge detection.

* Initial port of controlnet node support from generator-based TextToImageInvocation node to latent-based TextToLatentsInvocation node

* Switching to ControlField for output from controlnet nodes.

* Refactored controlnet node to output ControlField that bundles control info.

* changes to base class for controlnet nodes

* Added more preprocessor nodes for:
      MidasDepth
      ZoeDepth
      MLSD
      NormalBae
      Pidi
      LineartAnime
      ContentShuffle
Removed pil_output options, ControlNet preprocessors should always output as PIL. Removed diagnostics and other general cleanup.

* Prep for splitting pre-processor and controlnet nodes

* Refactored controlnet nodes: split out controlnet stuff into separate node, stripped controlnet stuff form image processing/analysis nodes.

* Added resizing of controlnet image based on noise latent. Fixes a tensor mismatch issue.

* Cleaning up TextToLatent arg testing

* Cleaning up mistakes after rebase.

* Removed last bits of dtype and and device hardwiring from controlnet section

* Refactored ControNet support to consolidate multiple parameters into data struct. Also redid how multiple controlnets are handled.

* Added support for specifying which step iteration to start using
each ControlNet, and which step to end using each controlnet (specified as fraction of total steps)

* Cleaning up prior to submitting ControlNet PR. Mostly turning off diagnostic printing. Also fixed error when there is no controlnet input.

* Commented out ZoeDetector. Will re-instate once there's a controlnet-aux release that supports it.

* Switched CotrolNet node modelname input from free text to default list of popular ControlNet model names.

* Fix to work with current stable release of controlnet_aux (v0.0.3). Turned of pre-processor params that were added post v0.0.3. Also change defaults for shuffle.

* Refactored most of controlnet code into its own method to declutter TextToLatents.invoke(), and make upcoming integration with LatentsToLatents easier.

* Cleaning up after ControlNet refactor in TextToLatentsInvocation

* Extended node-based ControlNet support to LatentsToLatentsInvocation.

* chore(ui): regen api client

* fix(ui): fix node ui type hints

* fix(nodes): controlnet input accepts list or single controlnet

* Added Mediapipe image processor for use as ControlNet preprocessor.
Also hacked in ability to specify HF subfolder when loading ControlNet models from string.

* Fixed bug where MediapipFaceProcessorInvocation was ignoring max_faces and min_confidence params.

* Added nodes for float params: ParamFloatInvocation and FloatCollectionOutput. Also added FloatOutput.

* Added mediapipe install requirement. Should be able to remove once controlnet_aux package adds mediapipe to its requirements.

* Added float to FIELD_TYPE_MAP ins constants.ts

* Progress toward improvement in fieldTemplateBuilder.ts  getFieldType()

* Fixed controlnet preprocessors and controlnet handling in TextToLatents to work with revised Image services.

* Cleaning up from merge, re-adding cfg_scale to FIELD_TYPE_MAP

* Making sure cfg_scale of type list[float] can be used in image metadata, to support param easing for cfg_scale

* Fixed math for per-step param easing.

* Added option to show plot of param value at each step

* Just cleaning up after adding param easing plot option, removing vestigial code.

* Modified control_weight ControlNet param to be polistmorphic --
can now be either a single float weight applied for all steps, or a list of floats of size total_steps, that specifies weight for each step.

* Added more informative error message when _validat_edge() throws an error.

* Just improving parm easing bar chart title to include easing type.

* Added requirement for easing-functions package

* Taking out some diagnostic prints.

* Added option to use both easing function and mirror of easing function together.

* Fixed recently introduced problem (when pulled in main), triggered by num_steps in StepParamEasingInvocation not having a default value -- just added default.

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-06-11 16:27:44 +10:00
738ba40f51 Fixes 2023-06-11 06:12:21 +03:00
3ce3a7ee72 Rewrite model configs, separate models 2023-06-11 04:49:09 +03:00
74b43c9bdf fix incorrect variable/typenames in model_cache 2023-06-10 10:41:48 -04:00
3d2ff7755e resolve conflicts 2023-06-10 10:13:54 -04:00
a87d52a389 resolve conflicts between lstein & sttalker changes 2023-06-10 09:59:19 -04:00
959e64c9b3 start removing repo_id support 2023-06-10 09:57:23 -04:00
2c056ead42 New models structure draft 2023-06-10 03:14:10 +03:00
30f20b55d5 fix logger behavior so that it is initialized after command line parsed (#3509)
In some cases the command-line was getting parsed before the logger was
initialized, causing the logger not to pick up custom logging
instructions from `--log_handlers`. This PR fixes the issue.
2023-06-09 08:24:47 -07:00
1bca32ed16 Merge branch 'main' into lstein/fix-logger-reconfiguration 2023-06-09 06:27:26 -07:00
7f91139e21 fix(ui): fix crash when using dropdown on certain device resolutions 2023-06-09 22:19:30 +10:00
c53b7c7389 ui: misc fixes (#3525)
[fix(ui): blur tab on
click](93f3658a4a)

Fixes issue where after clicking a tab, using the arrow keys changes tab
instead of changing selected image

[fix(ui): fix canvas not filling screen on first
load](68be95acbb)

[feat(ui): remove clear temp folder canvas
button](813f79f0f9)

This button is nonfunctional.

Soon we will introduce a different way to handle clearing out
intermediate images (likely automated).
2023-06-09 23:44:21 +12:00
93f3658a4a fix(ui): blur tab on click
Fixes issue where after clicking a tab, using the arrow keys changes tab instead of changing selected image
2023-06-09 18:20:52 +10:00
68be95acbb fix(ui): fix canvas not filling screen on first load 2023-06-09 17:55:11 +10:00
813f79f0f9 feat(ui): remove clear temp folder canvas button
This button is nonfunctional.

Soon we will introduce a different way to handle clearing out intermediate images (likely automated).
2023-06-09 17:33:17 +10:00
c3ec86bc70 feat(ui): enhance IAICustomSelect (#3523)
Now accepts an array of strings or array of `IAICustomSelectOption`s.
This supports custom labels and tooltips within the select component.
2023-06-09 18:26:20 +12:00
05a19753c6 feat(ui): remove controlnet model descriptions
These are not yet exposed on the UI - somebody who understands what they do better can add them when we have a place to expose them
2023-06-09 16:20:30 +10:00
a33327c651 feat(ui): enhance IAICustomSelect
Now accepts an array of strings or array of `IAICustomSelectOption`s. This supports custom labels and tooltips within the select component.
2023-06-09 16:00:17 +10:00
6ad7cc4f2a feat(ui): decrease delay on dnd to 150ms (#3522) 2023-06-09 17:54:24 +12:00
c506355b8b feat(ui): decrease delay on dnd to 150ms 2023-06-09 15:53:17 +10:00
d54168b8fb feat(nodes): add tests for depth-first execution 2023-06-09 14:53:45 +10:00
c91b071c47 fix(nodes): use DFS with preorder traversal 2023-06-09 14:53:45 +10:00
9c57b18008 fix(nodes): update Invoker.invoke() docstring 2023-06-09 14:53:45 +10:00
69539a0472 feat(nodes): depth-first execution
There was an issue where for graphs w/ iterations, your images were output all at once, at the very end of processing. So if you canceled halfway through an execution of 10 nodes, you wouldn't get any images - even though you'd completed 5 images' worth of inference.

## Cause

Because graphs executed breadth-first (i.e. depth-by-depth), leaf nodes were necessarily processed last. For image generation graphs, your `LatentsToImage` will be leaf nodes, and be the last depth to be executed.

For example, a `TextToLatents` graph w/ 3 iterations would execute all 3 `TextToLatents` nodes fully before moving to the next depth, where the `LatentsToImage` nodes produce output images, resulting in a node execution order like this:

1. TextToLatents
2. TextToLatents
3. TextToLatents
4. LatentsToImage
5. LatentsToImage
6. LatentsToImage

## Solution

This PR makes a two changes to graph execution to execute as deeply as it can along each branch of the graph.

### Eager node preparation

We now prepare as many nodes as possible, instead of just a single node at a time.

We also need to change the conditions in which nodes are prepared. Previously, nodes were prepared only when all of their direct ancestors were executed.

The updated logic prepares nodes that:
- are *not* `Iterate` nodes whose inputs have *not* been executed
- do *not* have any unexecuted `Iterate` ancestor nodes

This results in graphs always being maximally prepared.

### Always execute the deepest prepared node

We now choose the next node to execute by traversing from the bottom of the graph instead of the top, choosing the first node whose inputs are all executed.

This means we always execute the deepest node possible.

## Result

Graphs now execute depth-first, so instead of an execution order like this:

1. TextToLatents
2. TextToLatents
3. TextToLatents
4. LatentsToImage
5. LatentsToImage
6. LatentsToImage

... we get an execution order like this:

1. TextToLatents
2. LatentsToImage
3. TextToLatents
4. LatentsToImage
5. TextToLatents
6. LatentsToImage

Immediately after inference, the image is decoded and sent to the gallery.

fixes #3400
2023-06-09 14:53:45 +10:00
7bce455d16 Merge branch 'main' into diffusers-upgrade 2023-06-09 16:27:52 +12:00
3f45294c61 feat(ui): restore reset button for init image (#3521) 2023-06-09 16:02:26 +12:00
fd03c7eebe feat(ui): restore reset button for init image 2023-06-09 14:00:23 +10:00
07c49a5726 feat(ui): skip resize on img2img if not needed (#3520) 2023-06-09 15:56:22 +12:00
8c688f8e29 feat(ui): skip resize on img2img if not needed 2023-06-09 13:54:23 +10:00
887576d217 add directory scanning for loras, controlnets and textual_inversions 2023-06-08 23:11:53 -04:00
6652f3405b merge with main 2023-06-08 21:08:43 -04:00
3d13167d32 Merge branch 'main' into lstein/fix-logger-reconfiguration 2023-06-08 13:41:24 -07:00
f2bb507ebb allow logger to be reconfigured after startup 2023-06-08 09:23:11 -04:00
fe8f3381fc create databases directory on startup (#3518)
This PR creates the databases directory at app startup time. It also
removes a couple of debugging statements that were inadvertently left in
the model manager.
2023-06-08 23:40:32 +12:00
2a6d11e645 create databases directory on startup 2023-06-08 07:17:54 -04:00
01f46d3c7d Merge branch 'main' into lstein/fix-logger-reconfiguration 2023-06-07 19:50:44 -07:00
5f76b62553 Update installer support for main (#3448)
#  Make InvokeAI package installable by mere mortals
    
This commit makes InvokeAI 3.0 to be installable via PyPi.org and/or the
installer script. The install process is now pretty much identical to
the 2.3 process, including creating launcher scripts `invoke.sh` and
`invoke.bat`.
    
Main changes:
    
1. Moved static web pages into `invokeai/frontend/web` and modified the
API to look for them there. This allows pip to copy the files into the
distribution directory so that user no longer has to be in repo root to
launch, and enables PyPi installations with `pip install invokeai`
    
2. Update invoke.sh and invoke.bat to launch the new web application
properly. This also changes the wording for launching the CLI from
"generate images" to "explore the InvokeAI node system," since I would
not recommend using the CLI to generate images routinely.
    
3. Fix a bug in the checkpoint converter script that was identified
during testing.
    
4. Better error reporting when checkpoint converter fails.
    
5. Rebuild front end.

# Major improvements to the model installer.

1. The text user interface for `invokeai-model-install` has been
expanded to allow the user to install controlnet, LoRA, textual
inversion, diffusers and checkpoint models. The user can install
interactively (without leaving the TUI), or in batch mode after exiting
the application.
 

![image](https://github.com/invoke-ai/InvokeAI/assets/111189/f8f7ac23-3e18-4973-b7fe-729864c703a0)

2. The `invokeai-model-install` command now lets you list, add and
delete models from the command line:

## Listing models
```
$ invokeai-model-install --list diffusers
Diffuser models:
analog-diffusion-1.0      not loaded  diffusers  An SD-1.5 model trained on diverse analog photographs (2.13 GB)
d&d-diffusion-1.0         not loaded  diffusers  Dungeons & Dragons characters (2.13 GB)
deliberate-1.0            not loaded  diffusers  Versatile model that produces detailed images up to 768px (4.27 GB)
DreamShaper               not loaded  diffusers  Imported diffusers model DreamShaper
sd-inpainting-1.5         not loaded  diffusers  RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
sd-inpainting-2.0         not loaded  diffusers  Stable Diffusion version 2.0 inpainting model (5.21 GB)
stable-diffusion-1.5      not loaded  diffusers  Stable Diffusion version 1.5 diffusers model (4.27 GB)
stable-diffusion-2.1      not loaded  diffusers  Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
```

```
$ invokeai-model-install --list tis
Loading Python libraries...

Installed Textual Inversion Embeddings:
   EasyNegative
   ahx-beta-453407d
```

## Installing models

(this example shows correct handling of a server side error at Civitai)
```
$ invokeai-model-install --diffusers https://civitai.com/api/download/models/46259 Linaqruf/anything-v3.0
Loading Python libraries...

[2023-06-05 22:17:23,556]::[InvokeAI]::INFO --> INSTALLING EXTERNAL MODELS
[2023-06-05 22:17:23,557]::[InvokeAI]::INFO --> Probing https://civitai.com/api/download/models/46259 for import
[2023-06-05 22:17:23,557]::[InvokeAI]::INFO --> https://civitai.com/api/download/models/46259 appears to be a URL
[2023-06-05 22:17:23,763]::[InvokeAI]::ERROR --> An error occurred during downloading /home/lstein/invokeai-test/models/ldm/stable-diffusion-v1/46259: Internal Server Error
[2023-06-05 22:17:23,763]::[InvokeAI]::ERROR --> ERROR DOWNLOADING https://civitai.com/api/download/models/46259: {"error":"Invalid database operation","cause":{"clientVersion":"4.12.0"}}
[2023-06-05 22:17:23,764]::[InvokeAI]::INFO --> Probing Linaqruf/anything-v3.0 for import
[2023-06-05 22:17:23,764]::[InvokeAI]::DEBUG --> Linaqruf/anything-v3.0 appears to be a HuggingFace diffusers repo_id
[2023-06-05 22:17:23,768]::[InvokeAI]::INFO --> Loading diffusers model from Linaqruf/anything-v3.0
[2023-06-05 22:17:23,769]::[InvokeAI]::DEBUG --> Using faster float16 precision
[2023-06-05 22:17:23,883]::[InvokeAI]::ERROR --> An unexpected error occurred while downloading the model: 404 Client Error. (Request ID: Root=1-647e9733-1b0ee3af67d6ac3456b1ebfc)

Revision Not Found for url: https://huggingface.co/Linaqruf/anything-v3.0/resolve/fp16/model_index.json.
Invalid rev id: fp16)
Downloading (…)ain/model_index.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 511/511 [00:00<00:00, 2.57MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 472/472 [00:00<00:00, 6.13MB/s]
Downloading (…)cheduler_config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 341/341 [00:00<00:00, 3.30MB/s]
Downloading (…)okenizer_config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 807/807 [00:00<00:00, 11.3MB/s]
```

## Deleting models

```
 invokeai-model-install --delete --diffusers anything-v3
Loading Python libraries...

[2023-06-05 22:19:45,927]::[InvokeAI]::INFO --> Processing requested deletions
[2023-06-05 22:19:45,927]::[InvokeAI]::INFO --> anything-v3...
[2023-06-05 22:19:45,927]::[InvokeAI]::INFO --> Deleting the cached model directory for Linaqruf/anything-v3.0
[2023-06-05 22:19:45,948]::[InvokeAI]::WARNING --> Deletion of this model is expected to free 4.3G

```
2023-06-07 19:25:07 -07:00
4bbe3b0d00 Merge branch 'main' into release/make-web-dist-startable 2023-06-07 19:21:01 -07:00
9ed86a08f1 multiple small fixes
1. Contents of autoscan directory field are restored after doing an installation.
2. Activate dialogue to choose V2 parameterization when importing from a directory.
3. Remove autoscan directory from init file when its checkbox is unselected.
4. Add widget cycling behavior to install models form.
2023-06-07 17:32:00 -04:00
68405910ba Upgrade to Diffusers 0.17.0 2023-06-08 04:42:52 +12:00
0a50e2638c fix(ui): default controlnet autoprocess to true (#3513)
I had accidentally defaulted it to false
2023-06-08 01:56:53 +12:00
fc7c5da4dd fix(ui): default controlnet autoprocess to true
I had accidentally defaulted it to false
2023-06-07 23:55:24 +10:00
a3357e073c refactor exception handling 2023-06-07 07:35:34 -04:00
d114833a12 pause after printing exception 2023-06-07 07:26:14 -04:00
96038bd075 print exception on TUI crash 2023-06-07 07:23:14 -04:00
2f383c2598 docs(nodes): update INVOCATIONS.md (#3511) 2023-06-07 20:47:57 +12:00
702a8d1f72 docs(nodes): update INVOCATIONS.md 2023-06-07 18:44:43 +10:00
0a8390356f feat(ui): enhance autoprocessing
The processor is automatically selected when model is changed.

But if the user manually changes the processor, processor settings, or disables the new `Auto configure processor` switch, auto processing is disabled.

The user can enable auto configure by turning the switch back on.

When auto configure is enabled, a small dot is overlaid on the expand button to remind the user that the system is not auto configuring the processor for them.

If auto configure is enabled, the processor settings are reset to the default for the selected model.
2023-06-07 18:25:30 +10:00
844058c0a5 feat(ui): make prompt not required
- also change the placeholder text
2023-06-07 18:25:30 +10:00
7d74cbe29c fix(ui): make progress image not draggable 2023-06-07 18:25:30 +10:00
62ac0ed2dc feat(ui): tweak cnet model change
If there is no control image, and the model does not have a default processor, set the processor to `none`.
2023-06-07 18:25:30 +10:00
ae14adec2a feat(ui): add reset button for control image 2023-06-07 18:25:30 +10:00
6c2b39d1df feat(ui): improve controlnet image style
css is terrible
2023-06-07 18:25:30 +10:00
0843028e6e fix(ui): improve dragging activation
- delay of 250ms
- prevent gallery images from accidentally activating native drag and drop
2023-06-07 18:25:30 +10:00
de0fd87035 fix(ui): when a session errors, reset controlnet processing spinner 2023-06-07 18:25:30 +10:00
8b6c0be259 feat(ui): fix IAIDndImage button styles when upload disabled 2023-06-07 18:25:30 +10:00
58fec84858 feat(ui): add upload to IAIDndImage
Add uploading to IAIDndImage
- add `postUploadAction` arg to `imageUploaded` thunk, with several current valid options (set control image, set init, set nodes image, set canvas, or toast)
- updated IAIDndImage to optionally allow click to upload
2023-06-07 18:25:30 +10:00
f223ad7776 fix(ui): only show loading indicator on processing control images 2023-06-07 18:25:30 +10:00
00eabf630d fix(ui): fix control image not used if processor type is none 2023-06-07 18:25:30 +10:00
6245a27650 feat(ui): auto-select controlnet processor
- when the controlnet model is changed, if there is a default processor for the model set, the processor is changed.
- once a control image is selected (and processed), changing the model does not change the processor - must be manually changed
2023-06-07 18:25:30 +10:00
fa1ac57c90 Graph overlay was expanding off the screen to the size of the prompt line (#3510)
sure this isn't really important at the moment

just limited the size of the width and gave it a shadow

![image](https://github.com/invoke-ai/InvokeAI/assets/115216705/96e2db0a-9edb-48b8-9040-56ce054b5ecf)
2023-06-07 18:01:35 +12:00
0f16b1c98d Remove Shadow 2023-06-07 15:51:37 +10:00
08e66c5451 Update NodeGraphOverlay.tsx
Graph overlay was expanding off the screen to the size of the prompt
2023-06-07 14:49:03 +10:00
563bf70c95 fix CI failure in configure non-interactive mode; merged with main 2023-06-06 23:24:40 -04:00
49d29420c4 Merge branch 'main' into release/make-web-dist-startable 2023-06-06 23:24:16 -04:00
ae9d0c6c1b fix logger behavior so that it is initialized after command line parsed 2023-06-06 23:19:10 -04:00
04f9757f8d prevent crash when trying to calculate size of missing safety_checker
- Also fixed up order in which logger is created in invokeai-web
  so that handlers are installed after command-line options are
  parsed (and not before!)
2023-06-06 22:57:49 -04:00
1f9e1eb964 merge with main 2023-06-06 22:18:41 -04:00
d8d11f9bbb quench fp16 rev id not found warning 2023-06-06 22:01:05 -04:00
13fa0d3bc0 make log message textbox deeper 2023-06-06 17:23:13 -04:00
5eeb4b8e06 allow user to abort conversion of V2 models from within TUI 2023-06-06 17:21:50 -04:00
f5044c290d fix crash during model conversion 2023-06-06 17:05:29 -04:00
1b43276e5d make widget selection wrap around 2023-06-06 13:53:11 -07:00
294f086857 configure/install working correctly on windows11 2023-06-06 12:51:34 -07:00
e5024bf5e9 fix conhost launch-with args 2023-06-06 15:17:15 -04:00
79198b4bba feat(ui): fix bugs with image deletion (#3506)
- `imageUsage` object was always stale due to react component lifecycle,
fixed this
- cleaned up the deletion listener and context
2023-06-07 05:33:05 +12:00
1a2f0984db Merge branch 'main' into feat/ui/fix-stale-imageUsage 2023-06-07 04:35:16 +12:00
454683e6eb feat(ui): update image urls on connect (#3507)
* feat(ui): update image urls on connect

Add `updateImageUrlsOnConnect` RTK listener:
- requests URLs for *every* image the app knows about, on connect: gallery, selectedImage, initialImage, canvas images, nodes images, controlnet images
- only fires when `shouldUpdateImagesOnConnect` config is enabled

* remove prop

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2023-06-06 10:23:51 -04:00
bbb2a08e8f feat(ui): fix bugs with image deletion
- `imageUsage` object was always stale due to react component lifecycle, fixed this
- cleaned up the deletion listener and context
2023-06-06 20:01:27 +10:00
bf116927e1 feat(ui): clear features if image used by them is deleted
This handles the case when an image is deleted but is still in use in as eg an init image on canvas, or a control image. If we just delete the image, canvas/controlnet/etc may break (the image would just fail to load).

When an image is deleted, the app checks to see if it is in use in:
- Image to Image
- ControlNet
- Unified Canvas
- Node Editor

The delete dialog will always open if the image is in use anywhere, and the user is advised that deleting the image will reset the feature(s).

Even if the user has ticked the box to not confirm on delete, the dialog will still show if the image is in use somewhere.
2023-06-06 14:35:07 +10:00
3d249c4fa3 feat(ui): refactor image deletion
Add `DeleteImageContext`:
- provide a single function to delete an image
- opens the modal or immediately deletes, if confirm is off
2023-06-06 14:35:07 +10:00
fa338ddb6a feat(ui): add useGetIsImageInUse
Checks if an image is currently being used eg in canvas, nodes, controlnet, init image.
2023-06-06 14:35:07 +10:00
b200451330 feat(ui): add nodesSelector 2023-06-06 14:35:07 +10:00
8283d23b74 feat(ui): remove shouldTransformUrls
This is no longer used.
2023-06-06 14:35:07 +10:00
2fc0a4d53b feat(ui): improve handling for urls/metadata received
Update images everywhere when urls or metadata is received:
- control images
- init images
- canvas
- nodes
- init image

Also renamed the variable.
2023-06-06 14:35:07 +10:00
3ff732d583 feat(ui): clear controlnet image when image deleted 2023-06-06 14:35:07 +10:00
840c632c0a feat(ui): sort images by updated_at instead of created_at
fixes issue where saved staging area images are sorted as expected in gallery.
2023-06-06 14:30:53 +10:00
40d6e4f287 fix(ui): fix canvas auto-save not working 2023-06-06 14:30:53 +10:00
fc5f9c30a6 fix(ui): fix metadata viewer not working for canvas images 2023-06-06 14:30:53 +10:00
229de2dbb8 feat(ui): fix canvas saving
- fix "bounding box region only" not being respected when saving
- add toasts for each action
- improve workflow `take()` predicates to use the requestId
2023-06-06 14:30:53 +10:00
cc22427f25 feat(ui): improve UI on smaller screens
- responsive changes were causing a lot of weird layout issues, had to remove the rest of them
- canvas (non-beta) toolbar now wraps
- reduces minH for prompt boxes a bit
2023-06-06 14:29:57 +10:00
90333c0074 merge with main 2023-06-05 22:03:44 -04:00
54e5301b35 Multiple fixes
1. Model installer works correctly under Windows 11 Terminal
2. Fixed crash when configure script hands control off to installer
3. Kill install subprocess on keyboard interrupt
4. Command-line functionality for --yes configuration and model installation
   restored.
5. New command-line features:
   - install/delete lists of diffusers, LoRAS, controlnets and textual inversions
     using repo ids, paths or URLs.

Help:

```
usage: invokeai-model-install [-h] [--diffusers [DIFFUSERS ...]] [--loras [LORAS ...]] [--controlnets [CONTROLNETS ...]] [--textual-inversions [TEXTUAL_INVERSIONS ...]] [--delete] [--full-precision | --no-full-precision]
                              [--yes] [--default_only] [--list-models {diffusers,loras,controlnets,tis}] [--config_file CONFIG_FILE] [--root_dir ROOT]

InvokeAI model downloader

options:
  -h, --help            show this help message and exit
  --diffusers [DIFFUSERS ...]
                        List of URLs or repo_ids of diffusers to install/delete
  --loras [LORAS ...]   List of URLs or repo_ids of LoRA/LyCORIS models to install/delete
  --controlnets [CONTROLNETS ...]
                        List of URLs or repo_ids of controlnet models to install/delete
  --textual-inversions [TEXTUAL_INVERSIONS ...]
                        List of URLs or repo_ids of textual inversion embeddings to install/delete
  --delete              Delete models listed on command line rather than installing them
  --full-precision, --no-full-precision
                        use 32-bit weights instead of faster 16-bit weights (default: False)
  --yes, -y             answer "yes" to all prompts
  --default_only        only install the default model
  --list-models {diffusers,loras,controlnets,tis}
                        list installed models
  --config_file CONFIG_FILE, -c CONFIG_FILE
                        path to configuration file to create
  --root_dir ROOT       path to root of install directory
```
2023-06-05 21:45:35 -04:00
b31fc43bfa Fix potential race condition in config system (#3466)
There was a potential gotcha in the config system that was previously
merged with main. The `InvokeAIAppConfig` object was configuring itself
from the command line and configuration file within its initialization
routine. However, this could cause it to read `argv` from the command
line at unexpected times. This PR fixes the object so that it only reads
from the init file and command line when its `parse_args()` method is
explicitly called, which should be done at startup time in any top level
script that uses it.

In addition, using the `get_invokeai_config()` function to get a global
version of the config object didn't feel pythonic to me, so I have
changed this to `InvokeAIAppConfig.get_config()` throughout.

## Updated Usage

In the main script, at startup time, do the following:

```
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
config.parse_args()
```

In non-main scripts, it is not necessary (or recommended) to call
`parse_args()`:
```
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
```

The configuration object properties can be overridden when
`get_config()` is called by passing initialization values in the usual
way. If a property is set this way, then it will not be changed by
subsequent calls to `parse_args()`, but can only be changed by
explicitly setting the property.

```
config = InvokeAIAppConfig.get_config(nsfw_checker=True)
config.parse_args(argv=['--no-nsfw_checker'])
config.nsfw_checker
# True
```

You may specify alternative argv lists and configuration files in
`parse_args()`:

```
config.parse_args(argv=['--no-nsfw_checker'],
                             conf = OmegaConf.load('/tmp/test.yaml')
)
```

For backward compatibility, the `get_invokeai_config()` function is
still available from the module, but has been removed from the rest of
the source tree.
2023-06-05 15:26:50 -07:00
9bcf0b2251 Merge branch 'main' into lstein/config-management-fixes 2023-06-05 15:10:33 -07:00
d4bc98c383 revert to conhost method 2023-06-05 11:46:01 -07:00
bc892c535c feat(ui): fix image fit (#3501)
- Prevent init, current & control images from overflowing
2023-06-05 20:48:55 +12:00
099e1e7c08 feat(ui): fix image fit
- Prevent init, current & control images from overflowing
2023-06-05 17:16:30 +10:00
b1000e30c1 feat(ui): disable keyboard dnd
Need to fix a bug w/ collision detection before enabling it. Will pursue later.
2023-06-05 15:24:24 +10:00
7bd94eac0e feat(ui): support image dnd to canvas 2023-06-05 15:24:24 +10:00
2c77563dcc feat(ui): move DropOverlay into its own IAIDropOverlay component 2023-06-05 15:24:24 +10:00
603c9a587e open Windows Terminal maximized 2023-06-05 00:24:13 -04:00
1a5a2dfda9 increased window size 2023-06-04 23:54:52 -04:00
090b7eeaf3 workaround to get adequate window size on Windows Terminal 2023-06-04 23:44:07 -04:00
117536324c the "restore" env variable in .bat launcher confuses pydantic 2023-06-04 22:53:46 -04:00
999c092b6a fix mouse and window resizing issues 2023-06-04 22:00:11 -04:00
9e31b1f387 Merge branch 'main' into lstein/config-management-fixes 2023-06-04 18:17:43 -04:00
cb157ea530 fix crash when install-models launched from config script 2023-06-04 14:55:51 -04:00
5f6f38074d merge with main 2023-06-04 13:59:31 -04:00
25b8dd340a Prompting: enable long prompts and compel's new .and() concatenating feature (#3497)
this PR adds long prompt support and enables compel's new `.and()`
concatenation feature which improves image quality especially with SD2.1

example of a long prompt:
> a moist sloppy pindlesackboy sloppy hamblin' bogomadong, Clem Fandango
is pissed-off, Wario's Woods in background, making a noise like
ga-woink-a
![000075 6dfd7adf
466129594](https://github.com/invoke-ai/InvokeAI/assets/144366/051608b6-8d52-463b-af10-04b695cda9c1)

the same prompt broken into fragments and concatenated using `.and()`
(syntax works like `.blend()`):
```
("a moist sloppy pindlesackboy sloppy hamblin' bogomadong", 
"Clem Fandango is pissed-off", 
"Wario's Woods in background", 
"making a noise like ga-woink-a").and()
```
![000076 68b1c320
466129594](https://github.com/invoke-ai/InvokeAI/assets/144366/3fee291f-5562-40f9-9c3c-a73765fc893a)


and a less silly example:

> A dream of a distant galaxy, by Caspar David Friedrich, matte
painting, trending on artstation, HQ
![000129 1b33b559
2793529321](https://github.com/invoke-ai/InvokeAI/assets/144366/d4113756-ed0d-49cd-bb2e-a2fc4a09e0af)

the same prompt broken into two fragments and concatenated:
```
("A dream of a distant galaxy, by Caspar David Friedrich, matte painting", 
"trending on artstation, HQ").and()
```
![000128 b5d5cd62
2793529321](https://github.com/invoke-ai/InvokeAI/assets/144366/c373c009-05db-4c42-8a1d-c89fbdb334ec)

as with `.blend()` you can also weight the parts eg `("a man eating an
apple", "sitting on the roof of a car", "high quality, trending on
artstation, 8K UHD").and(1, 0.5, 0.5)` which will assign weight `1` to
`a man eating an apple` and `0.5` to `sitting on the roof of a car` and
`high quality, trending on artstation, 8K UHD`.
2023-06-05 04:53:08 +12:00
fb06f5b892 Merge branch 'main' into feat_compel_longprompts_and_concat 2023-06-05 04:34:39 +12:00
1a7fb601dc ask user for v2 variant when model manager can't infer it 2023-06-04 11:27:44 -04:00
cdcfda164d enable long prompts, upgrade compel to enable .and() (concatenating prompts) 2023-06-04 15:30:54 +02:00
966b154a1f Update web README.md (#3496) 2023-06-05 00:56:00 +12:00
95fa66661c dummy commit to make github actions run 2023-06-04 22:55:35 +10:00
6247b79111 docs(ui): update API_CLIENT 2023-06-04 22:46:53 +10:00
5831364f9c Update web README.md 2023-06-04 22:44:18 +10:00
919b81cff1 fix(ui): fix rebase issue 2023-06-04 22:34:58 +10:00
065fff7db5 fix(ui): fix wonkiness with image dnd 2023-06-04 22:34:58 +10:00
a664ee30a2 feat(ui): do not change images if the dropped image is the same image 2023-06-04 22:34:58 +10:00
03f3ad435a feat(ui): updated controlnet logic/ui 2023-06-04 22:34:58 +10:00
2270c270ef feat(ui): add tooltip to IAISwitch 2023-06-04 22:34:58 +10:00
4f7820719b feat(ui): add ellipsis direction to IAICustomSelect 2023-06-04 22:34:58 +10:00
fa285883ad feat(ui): make OverlayDragImage translucent 2023-06-04 22:34:58 +10:00
474fca8e6a feat(ui): add controlNetDenylist 2023-06-04 22:34:58 +10:00
5dc0250b00 feat(ui): ControlNet layout tweaks 2023-06-04 22:34:58 +10:00
f269377a01 feat(ui): "ProcessorOptionsContainer" -> "ProcessorWrapper", organise 2023-06-04 22:34:58 +10:00
d0406024e3 feat(ui): IAICustomSelect tweak styles 2023-06-04 22:34:58 +10:00
aa3a969bd2 feat: Update ControlNet Model List & Map 2023-06-04 22:34:58 +10:00
73a95973a8 wip: Add Wrapper Container for Preprocessor Options
For fast altering of the layout across all pre-preocessors.
2023-06-04 22:34:58 +10:00
bf4fe3c1ac wip: Fixing layout shifts with the ControlNet tab 2023-06-04 22:34:58 +10:00
d6c08ba469 feat(ui): add mini/advanced controlnet ui 2023-06-04 22:34:58 +10:00
69f0ba65f1 chore(ui): bump react-icons 2023-06-04 22:34:58 +10:00
828c86964d feat(ui): IAICustomSelect prevent label wrap 2023-06-04 22:34:58 +10:00
54b7ddd63f feat(ui): IAIDndImage cursor: 'grab' 2023-06-04 22:34:58 +10:00
a0dde66b5d feat(ui): more work on controlnet mini 2023-06-04 22:34:58 +10:00
b6b3b9f99c feat(ui): make scrollbar less bright 2023-06-04 22:34:58 +10:00
faa69f8a47 feat(ui): add alpha colors 2023-06-04 22:34:58 +10:00
d92c7f5483 feat(ui): organize IAIDndImage component 2023-06-04 22:34:58 +10:00
6b824eb112 feat(ui): initial mini controlnet UI, dnd improvements 2023-06-04 22:34:58 +10:00
72b4371804 feat(ui): control image auto-process 2023-06-04 22:34:58 +10:00
fa290aff8d feat(ui): add defaults for all processors 2023-06-04 22:34:58 +10:00
3d99d7ae8b feat(ui): update handling of inProgess, do not allow cnet process when processing 2023-06-04 22:34:58 +10:00
2eb367969c feat(ui): do not autoprocess control if invocation in progress 2023-06-04 22:34:58 +10:00
9cdad95f48 feat(ui): add rest of controlnet processors 2023-06-04 22:34:58 +10:00
707ed39300 chore(ui): regen api client 2023-06-04 22:34:58 +10:00
6bbb5f061a feat(nodes): update controlnet names/descriptions 2023-06-04 22:34:58 +10:00
6896e69e95 fix(ui): fix multiple controlnets 2023-06-04 22:34:58 +10:00
b17f4c1650 feat(ui): more tweaking controlnet ui 2023-06-04 22:34:58 +10:00
98493ed9e2 feat(ui): reorg parameter panel to make room for controlnet 2023-06-04 22:34:58 +10:00
94c953deab feat(ui): get processed images back into controlnet ui 2023-06-04 22:34:58 +10:00
fa4d88e163 feat(ui): improve drag and drop ux 2023-06-04 22:34:58 +10:00
b1e1e3efc7 fix(ui): fix IAISelectableImage fallback 2023-06-04 22:34:58 +10:00
3b9426eb72 feat(ui): controlnet/image dnd wip
Implement `dnd-kit` for image drag and drop
- vastly simplifies logic bc we can drag and drop non-serializable data (like an `ImageDTO`)
- also much prettier
- also will fix conflicts with file upload via OS drag and drop, bc `dnd-kit` does not use native HTML drag and drop API
- Implemented for Init image, controlnet, and node editor so far

More progress on the ControlNet UI
2023-06-04 22:34:58 +10:00
e2e07696fc feat(ui): wip controlnet ui 2023-06-04 22:34:58 +10:00
d6a959b000 feat(nodes): tidy controlnet processor nodes & improve descriptions 2023-06-04 22:34:58 +10:00
c3935d3849 feat(nodes): add separate scripts to launch cli and web (#3495) 2023-06-04 08:13:14 -04:00
383e3d77cb feat(nodes): add separate scripts to launch cli and web 2023-06-04 22:02:47 +10:00
31e97ead2a move invokeai.db to ~/invokeai/databases
- The invokeai.db database file has now been moved into
  `INVOKEAIROOT/databases`. Using plural here for possible
  future with more than one database file.

- Removed a few dangling debug messages that appeared during
  testing.

- Rebuilt frontend to test web.
2023-06-03 20:25:34 -04:00
0b49995659 merge with main 2023-06-03 20:06:27 -04:00
ff204db6b2 Add logging configuration (#3460)
This PR provides a number of options for controlling how InvokeAI logs
messages, including options to log to a file, syslog and a web server.
Several logging handlers can be configured simultaneously.

## Controlling How InvokeAI Logs Status Messages

InvokeAI logs status messages using a configurable logging system. You
can log to the terminal window, to a designated file on the local
machine, to the syslog facility on a Linux or Mac, or to a properly
configured web server. You can configure several logs at the same time,
and control the level of message logged and the logging format (to a
limited extent).

Three command-line options control logging:

### `--log_handlers <handler1> <handler2> ...`

This option activates one or more log handlers. Options are "console",
"file", "syslog" and "http". To specify more than one, separate them by
spaces:

```bash
invokeai-web --log_handlers console syslog=/dev/log file=C:\Users\fred\invokeai.log
```

The format of these options is described below.

### `--log_format {plain|color|legacy|syslog}`

This controls the format of log messages written to the console. Only
the "console" log handler is currently affected by this setting.

* "plain" provides formatted messages like this:

```bash

[2023-05-24 23:18:2[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message
[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational messages
[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning
[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error
[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error
```

* "color" produces similar output, but the text will be color coded to
indicate the severity of the message.

* "legacy" produces output similar to InvokeAI versions 2.3 and earlier:

```bash
### this is a critical error
*** this is an error
** this is a warning
>> this is an informational messages
   | this is a debug message
```

* "syslog" produces messages suitable for syslog entries:

```bash
InvokeAI [2691178] <CRITICAL> this is a critical error
InvokeAI [2691178] <ERROR> this is an error
InvokeAI [2691178] <WARNING> this is a warning
InvokeAI [2691178] <INFO> this is an informational messages
InvokeAI [2691178] <DEBUG> this is a debug message
```

(note that the date, time and hostname will be added by the syslog
system)

### `--log_level {debug|info|warning|error|critical}`

Providing this command-line option will cause only messages at the
specified level or above to be emitted.

## Console logging

When "console" is provided to `--log_handlers`, messages will be written
to the command line window in which InvokeAI was launched. By default,
the color formatter will be used unless overridden by `--log_format`.

## File logging

When "file" is provided to `--log_handlers`, entries will be written to
the file indicated in the path argument. By default, the "plain" format
will be used:

```bash
invokeai-web --log_handlers file=/var/log/invokeai.log
```

## Syslog logging

When "syslog" is requested, entries will be sent to the syslog system.
There are a variety of ways to control where the log message is sent:

* Send to the local machine using the `/dev/log` socket:

```
invokeai-web --log_handlers syslog=/dev/log
```

* Send to the local machine using a UDP message:

```
invokeai-web --log_handlers syslog=localhost
```

* Send to the local machine using a UDP message on a nonstandard port:

```
invokeai-web --log_handlers syslog=localhost:512
```

* Send to a remote machine named "loghost" on the local LAN using
facility LOG_USER and UDP packets:

```
invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM
```

This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM are
defaults.

* Send to a remote machine named "loghost" using the facility LOCAL0 and
using a TCP socket:

```
invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM
```

If no arguments are specified (just a bare "syslog"), then the logging
system will look for a UNIX socket named `/dev/log`, and if not found
try to send a UDP message to `localhost`. The Macintosh OS used to
support logging to a socket named `/var/run/syslog`, but this feature
has since been disabled.

## Web logging

If you have access to a web server that is configured to log messages
when a particular URL is requested, you can log using the "http" method:

```
invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST
```

The optional [,method=] part can be used to specify whether the URL
accepts GET (default) or POST messages.

Currently password authentication and SSL are not supported.

## Using the configuration file

You can set and forget logging options by adding a "Logging" section to
`invokeai.yaml`:

```
InvokeAI:
  [... other settings...]
  Logging:
    log_handlers:
       - console
       - syslog=/dev/log
    log_level: info
    log_format: color
```
2023-06-03 20:03:40 -04:00
f74f3d6a3a many TUI improvements:
1. Separated the "starter models" and "more models" sections. This
   gives us room to list all installed diffuserse models, not just
   those that are on the starter list.

2. Support mouse-based paste into the textboxes with either middle
   or right mouse buttons.

3. Support terminal-style cursor movement:
     ^A to move to beginning of line
     ^E to move to end of line
     ^K kill text to right and put in killring
     ^Y yank text back

4. Internal code cleanup.
2023-06-03 16:17:53 -04:00
713fb061e8 Merge branch 'main' into release/make-web-dist-startable 2023-06-02 23:19:33 -04:00
77b7680b32 slight refactoring of code; configure --yes should work now 2023-06-02 23:19:14 -04:00
ff63433591 Merge branch 'main' into lstein/config-management-fixes 2023-06-02 22:56:43 -04:00
31281d7181 Merge branch 'main' into lstein/logging-improvements 2023-06-02 22:56:13 -04:00
8285fbb0b1 Merge branch 'lstein/new-model-manager' of github.com:invoke-ai/InvokeAI into lstein/new-model-manager 2023-06-02 22:48:00 -04:00
951e6b746c remove model cache test; should be replaced with something else 2023-06-02 22:47:48 -04:00
44a6623094 Merge branch 'main' into lstein/new-model-manager 2023-06-02 22:40:51 -04:00
72d1e4e404 fix bug in model_manager that prevented import of inpainting models 2023-06-02 22:39:26 -04:00
91918e648b dynamic display of log messages now working 2023-06-02 22:24:46 -04:00
1390b65a9c new TUI is fully functional; needs some polishing 2023-06-02 17:20:50 -04:00
82231369d3 Make Invoke Button also the progress bar (#3492)
Find on some screens the progress bar at top is hard to see, Bar should
only show when in progress


![Animation](https://github.com/invoke-ai/InvokeAI/assets/115216705/04f945d3-377b-4646-b125-1355e74b6b09)
2023-06-02 19:30:45 +12:00
7620bacc01 feat: Add temporary NodeInvokeButton 2023-06-02 17:55:15 +12:00
ea9cf04765 fix: Remove progress bg instead of altering button bg 2023-06-02 17:36:14 +12:00
47301e6f85 fix: Do the same without zIndex 2023-06-02 17:33:38 +12:00
f143fb7254 feat: Make Invoke Button also the progress bar 2023-06-02 17:24:40 +12:00
2bdb655375 Change to absolute 2023-06-02 14:59:10 +10:00
41f7758977 listing, downloading and deleting LoRAs working; TI support pending 2023-06-02 00:40:15 -04:00
8ae1eaaccc Add Progress bar under invoke Button
Find on some screens the progress bar at top of screen gets cut off
2023-06-02 14:19:02 +10:00
98773b20ac merge with main 2023-06-01 18:09:49 -04:00
d66979073b add optional config for settings modal 2023-06-02 00:36:45 +10:00
c9e621093e fix(ui): fix looping gallery images fetch
The gallery could get in a state where it thought it had just reached the end of the list and endlessly fetches more images, if there are no more images to fetch (weird I know).

Add some logic to remove the `end reached` handler when there are no more images to load.
2023-06-02 00:34:03 +10:00
e06ba40795 fix(ui): do not allow dpmpp_2s to be used ever
it doesn't work for the img2img pipelines, but the implemented conditional display could break the scheduler selection dropdown.

simple fix until diffusers merges the fix - never use this scheduler.
2023-06-02 00:30:01 +10:00
6571e4c2fd feat(ui): refactor parameter recall
- use zod to validate parameters before recalling
- update recall params hook to handle all validation and UI feedback
2023-06-02 00:30:01 +10:00
ff9240b51d slight code cleanup 2023-06-01 00:45:07 -04:00
18466e01fd tab selection seems very natural; not wired to backend yet 2023-06-01 00:43:28 -04:00
e9821ab711 implemented tabbed model selection; not wired to backend yet 2023-06-01 00:31:46 -04:00
d6530df635 rename invokeai.backend.config to invokeai.backend.install 2023-05-31 21:34:20 -04:00
b47786e846 First working TI draft 2023-05-31 02:12:27 +03:00
062b2cf46f fix(ui): fix width and height not working on txt2img tab
I missed a spot when working on the graph logic yesterday.
2023-05-30 18:41:09 -04:00
082ecf6f25 minor formatting improvements 2023-05-30 13:59:32 -04:00
1632ac6b9f add controlnet model downloading 2023-05-30 13:49:43 -04:00
69ccd3a0b5 Fixes for checkpoint models 2023-05-30 19:12:47 +03:00
877959b413 fix(ui): ensure download image opens in new tab 2023-05-30 09:22:54 -04:00
6e60f7517b feat(ui): add model description tooltips 2023-05-30 09:06:13 -04:00
296ee6b7ea feat(ui): tidy ParamScheduler component 2023-05-30 09:06:13 -04:00
7c7ffddb2b feat(ui): upgrade IAICustomSelect to optionally display tooltips for each item 2023-05-30 09:06:13 -04:00
e1ae7842ff feat(ui): add defaultModel to config 2023-05-30 09:06:13 -04:00
9687fe7bac fix(ui): set default model to first model (alpha sort) 2023-05-30 09:06:13 -04:00
a9a2bd90c2 fix(nodes): set min and max for l2l strength 2023-05-30 09:06:13 -04:00
47ca71a7eb fix(nodes): set cfg_scale min to 1 in latents 2023-05-30 09:06:13 -04:00
a9c47237b1 fix(ui): mark img2img resize node intermediate 2023-05-30 09:06:13 -04:00
33bbae2f47 fix(ui): fix missing init image when fit disabled 2023-05-30 09:06:13 -04:00
fab7a1d337 fix(ui): fix bug with staging bbox not resetting 2023-05-30 09:06:13 -04:00
cffcf80977 fix(ui): remove w/h from canvas params, add bbox w/h 2023-05-30 09:06:13 -04:00
1a3fd05b81 fix(ui): fix canvas bbox autoscale 2023-05-30 09:06:13 -04:00
c22c6ca135 fix(ui): fix img2img fit 2023-05-30 09:06:13 -04:00
3afb6a387f chore(ui): regen api 2023-05-30 09:06:13 -04:00
33e5ed7180 fix(ui): fix edge case in nodes graph building
Inputs with explicit values are validated by pydantic even if they also
have a connection (which is the actual value that is used).

Fix this by omitting explicit values for inputs that have a connection.
2023-05-30 09:06:13 -04:00
2067757fab feat(ui): enable progress images by default 2023-05-30 09:06:13 -04:00
b1b94a3d56 Fixed problem with inpainting after controlnet support added to main.
Problem was that controlnet support involved adding **kwargs to method calls down in denoising loop, and AddsMaskLatents didn't accept **kwarg arg. So just changed to accept and pass on **kwargs.
2023-05-30 08:01:21 -04:00
c9ee42450e added controlnet models to frontend; backend needs to be done 2023-05-30 00:38:37 -04:00
10fe31c2a1 Merge branch 'main' into lstein/config-management-fixes 2023-05-29 21:03:03 -04:00
420a76ecdd Add lora loader node 2023-05-30 02:12:33 +03:00
79de9047b5 First working lora implementation 2023-05-30 01:11:00 +03:00
dc54cbb1fc Merge branch 'main' into release/make-web-dist-startable 2023-05-29 14:16:10 -04:00
070218aba7 feat(ui): add progress image toggle to current image buttons 2023-05-29 09:07:46 -04:00
f1c226b171 fix(ui): remove console.log() 2023-05-29 09:07:46 -04:00
7004430380 feat(ui): gallery filter dropdown -> Images/Assets toggle 2023-05-29 09:07:46 -04:00
1ddc620192 feat(ui): only cancel on staging commit if processing 2023-05-29 09:07:46 -04:00
a7cebbd970 feat(ui): cancel session when staging image accepted 2023-05-29 09:07:46 -04:00
d97438b0b3 fix(ui): fix typo in actionsDenylist 2023-05-29 09:07:46 -04:00
4522f3f4c9 fix(ui): fix progress images in canvas 2023-05-29 09:07:46 -04:00
6fe28980b0 feat(ui): revert in-gallery progress
wasn't fully baked. will revisist in the future.
2023-05-29 09:07:46 -04:00
4aec5d8ffc fix(ui): typo 2023-05-29 09:07:46 -04:00
bbb4e8f5ef feat(nodes): add resize image and scale image nodes 2023-05-29 09:07:46 -04:00
bce33ea62e fix(ui): when session is complete, null out progress image
This may cause minor gallery jumpiness at the very end of processing, but is necessary to prevent the progress image from sticking around if the last node in a session did not have an image output.
2023-05-29 09:07:46 -04:00
e4705d5ce7 fix(ui): add additional socket event layer to gate handling socket events
Some socket events should not be handled by the slice reducers. For example generation progress should not be handled for a canceled session.

Added another layer of socket actions.

Example:
- `socketGeneratorProgress` is dispatched when the actual socket event is received
- Listener middleware exclusively handles this event and determines if the application should also handle it
- If so, it dispatches `appSocketGeneratorProgress`, which the slices can handle

Needed to fix issues related to canceling invocations.
2023-05-29 09:07:46 -04:00
6764b2a854 fix(ui): fix save to gallery without bounding box 2023-05-29 09:07:46 -04:00
970340cf62 fix(ui): infill and scaling options label 2023-05-29 09:07:46 -04:00
043f9d9ba4 fix(ui): fix auto-switch to new images 2023-05-29 09:07:46 -04:00
6f82801d07 fix(ui): fix canvas save to gallery incorrect is_intermediate flag 2023-05-28 20:19:56 -04:00
3e3dd39ae4 fix(nodes): fix images service update() for is_intermediate 2023-05-28 20:19:56 -04:00
89aa06e014 feat(ui): consolidate images slice
Now that images are in a database and we can make filtered queries, we can do away with the cumbersome `resultsSlice` and `uploadsSlice`.

- Remove `resultsSlice` and `uploadsSlice` entirely
- Add `imagesSlice` fills the same role
- Convert the application to use `imagesSlice`, reducing a lot of messy logic where we had to check which category was selected
- Add a simple filter popover to the gallery, which lets you select any number of image categories
2023-05-28 20:19:56 -04:00
6cc00ef4b7 chore(ui): regen api client 2023-05-28 20:19:56 -04:00
f31e62afad feat(nodes): make list images route use offset pagination
Because we dynamically insert images into the DB and UI's images state, `page`/`per_page` pagination makes loading the images awkward.

Using `offset`/`limit` pagination lets us query for images with an offset equal to the number of images already loaded (which match the query parameters).

The result is that we always get the correct next page of images when loading more.
2023-05-28 20:19:56 -04:00
38fd2ad45d fix(ui): fix metadata viewer crash 2023-05-28 20:19:56 -04:00
05b99b5377 fix(ui): fix erroneously displays is_intermediate field on nodes 2023-05-28 20:19:56 -04:00
08a14ee6d5 fix(nodes): fix conflicts with controlnet 2023-05-28 20:19:56 -04:00
29fcc92da9 feat(ui): handle new image origin/category setup
- Update all thunks & network related things
- Update gallery

What I have not done yet is rename the gallery tabs and the relevant slices, but I believe the functionality is all there.

Also I fixed several bugs along the way but couldn't really commit them separately bc I was refactoring. Can't remember what they were, but related to the gallery image switching.
2023-05-28 20:19:56 -04:00
d78e3572e3 chore(ui): regen api client 2023-05-28 20:19:56 -04:00
160267c71a feat(nodes): refactor image types
- Remove `ImageType` entirely, it is confusing
- Create `ResourceOrigin`, may be `internal` or `external`
- Revamp `ImageCategory`, may be `general`, `mask`, `control`, `user`, `other`. Expect to add more as time goes on
- Update images `list` route to accept `include_categories` OR `exclude_categories` query parameters to afford finer-grained querying. All services are updated to accomodate this change.

The new setup should account for our types of images, including the combinations we couldn't really handle until now:
- Canvas init and masks
- Canvas when saved-to-gallery or merged
2023-05-28 20:19:56 -04:00
fd47e70c92 feat(nodes): use higher precision timestamps in db 2023-05-28 20:19:56 -04:00
9317b42e5f feat(nodes, ui): wip image types 2023-05-28 20:19:56 -04:00
bdab73701f fix(ui): canvas images not added to staging 2023-05-28 20:19:56 -04:00
3ea5e78322 fix(nodes): fix list images route param descriptions 2023-05-28 20:19:56 -04:00
f609ee21a2 fix(ui): handle intermediates when fetching gallery 2023-05-28 20:19:56 -04:00
f51defeeb3 chore(ui): regen api client 2023-05-28 20:19:56 -04:00
ee0225f4ba fix(nodes): handle intermediates during images.get_many() 2023-05-28 20:19:56 -04:00
33a0af4637 feat(nodes): add nameservice
Currenly only used to make names for images, but when latents, conditioning, etc are managed in DB, will do the same for them.

Intended to eventually support custom naming schemes.
2023-05-28 20:19:56 -04:00
d37b08a7dd Merge branch 'main' into release/make-web-dist-startable 2023-05-28 19:46:09 -04:00
9a796364da Fixed controlnet preprocessors and controlnet handling in TextToLatents to work with revised Image services. 2023-05-26 21:44:00 -04:00
1ad4eb3a7b Progress toward improvement in fieldTemplateBuilder.ts getFieldType() 2023-05-26 21:44:00 -04:00
3767a453bb Added float to FIELD_TYPE_MAP ins constants.ts 2023-05-26 21:44:00 -04:00
b0892d30a4 Added mediapipe install requirement. Should be able to remove once controlnet_aux package adds mediapipe to its requirements. 2023-05-26 21:44:00 -04:00
d9b1e4a98c Added nodes for float params: ParamFloatInvocation and FloatCollectionOutput. Also added FloatOutput. 2023-05-26 21:44:00 -04:00
a4dec8c1d6 Fixed bug where MediapipFaceProcessorInvocation was ignoring max_faces and min_confidence params. 2023-05-26 21:44:00 -04:00
8960ceb98b Added Mediapipe image processor for use as ControlNet preprocessor.
Also hacked in ability to specify HF subfolder when loading ControlNet models from string.
2023-05-26 21:44:00 -04:00
be79d088c0 fix(nodes): controlnet input accepts list or single controlnet 2023-05-26 21:44:00 -04:00
009407ea3f fix(ui): fix node ui type hints 2023-05-26 21:44:00 -04:00
6999d28c7f chore(ui): regen api client 2023-05-26 21:44:00 -04:00
324e9eb74b Extended node-based ControlNet support to LatentsToLatentsInvocation. 2023-05-26 21:44:00 -04:00
56cff40362 Cleaning up after ControlNet refactor in TextToLatentsInvocation 2023-05-26 21:44:00 -04:00
2ba40c5e52 Refactored most of controlnet code into its own method to declutter TextToLatents.invoke(), and make upcoming integration with LatentsToLatents easier. 2023-05-26 21:44:00 -04:00
3ab147204c Fix to work with current stable release of controlnet_aux (v0.0.3). Turned of pre-processor params that were added post v0.0.3. Also change defaults for shuffle. 2023-05-26 21:44:00 -04:00
e4c89cba9c Switched CotrolNet node modelname input from free text to default list of popular ControlNet model names. 2023-05-26 21:44:00 -04:00
322ea84c4e Commented out ZoeDetector. Will re-instate once there's a controlnet-aux release that supports it. 2023-05-26 21:44:00 -04:00
f2b41c60ff Cleaning up prior to submitting ControlNet PR. Mostly turning off diagnostic printing. Also fixed error when there is no controlnet input. 2023-05-26 21:44:00 -04:00
754acec92f Added support for specifying which step iteration to start using
each ControlNet, and which step to end using each controlnet (specified as fraction of total steps)
2023-05-26 21:44:00 -04:00
11fc7e40a5 Refactored ControNet support to consolidate multiple parameters into data struct. Also redid how multiple controlnets are handled. 2023-05-26 21:44:00 -04:00
d15bb88eb2 Removed last bits of dtype and and device hardwiring from controlnet section 2023-05-26 21:44:00 -04:00
70ba36eefc Cleaning up mistakes after rebase. 2023-05-26 21:44:00 -04:00
7e70391c2b Cleaning up TextToLatent arg testing 2023-05-26 21:44:00 -04:00
e2a94be336 Added resizing of controlnet image based on noise latent. Fixes a tensor mismatch issue. 2023-05-26 21:44:00 -04:00
63a86eefb4 Refactored controlnet nodes: split out controlnet stuff into separate node, stripped controlnet stuff form image processing/analysis nodes. 2023-05-26 21:44:00 -04:00
b0727b9d47 Prep for splitting pre-processor and controlnet nodes 2023-05-26 21:44:00 -04:00
d96e727dd5 Added more preprocessor nodes for:
MidasDepth
      ZoeDepth
      MLSD
      NormalBae
      Pidi
      LineartAnime
      ContentShuffle
Removed pil_output options, ControlNet preprocessors should always output as PIL. Removed diagnostics and other general cleanup.
2023-05-26 21:44:00 -04:00
fe480886dc changes to base class for controlnet nodes 2023-05-26 21:44:00 -04:00
8031d1827b Refactored controlnet node to output ControlField that bundles control info. 2023-05-26 21:44:00 -04:00
b5acdb322d Switching to ControlField for output from controlnet nodes. 2023-05-26 21:44:00 -04:00
a4d1fe8819 Initial port of controlnet node support from generator-based TextToImageInvocation node to latent-based TextToLatentsInvocation node 2023-05-26 21:44:00 -04:00
10b7a58887 Added first controlnet preprocessor node for canny edge detection. 2023-05-26 21:44:00 -04:00
901a277959 Core implementation of ControlNet and MultiControlNet. 2023-05-26 21:44:00 -04:00
aaa093bef1 Fixed use of ControlNet control_weight parameter 2023-05-26 21:44:00 -04:00
bb96543d66 Added support for using multiple control nets. Unfortunately this breaks direct usage of Control node output port ==> TextToLatent control input port -- passing through a Collect node is now required. Working on fixing this... 2023-05-26 21:44:00 -04:00
a2a2cfa765 Added resizing of controlnet image based on noise latent. Fixes a tensor mismatch issue. 2023-05-26 21:44:00 -04:00
18e6a2b410 Refactored controlnet nodes: split out controlnet stuff into separate node, stripped controlnet stuff form image processing/analysis nodes. 2023-05-26 21:44:00 -04:00
db27263bc2 Prep for splitting pre-processor and controlnet nodes 2023-05-26 21:44:00 -04:00
0e027ec3ef Added more preprocessor nodes for:
MidasDepth
      ZoeDepth
      MLSD
      NormalBae
      Pidi
      LineartAnime
      ContentShuffle
Removed pil_output options, ControlNet preprocessors should always output as PIL. Removed diagnostics and other general cleanup.
2023-05-26 21:44:00 -04:00
5acbbeecaa Added HED, LineArt, and OpenPose ControlNet nodes 2023-05-26 21:44:00 -04:00
6ef2168b67 changes to base class for controlnet nodes 2023-05-26 21:44:00 -04:00
6d958a214c Refactored ControlNet nodes so they subclass from PreprocessedControlInvocation, and only need to override run_processor(image) (instead of reimplementing invoke()) 2023-05-26 21:44:00 -04:00
4ae4bf4ff9 Resolving conflicts in rebase to origin/main 2023-05-26 21:44:00 -04:00
fdef53b2de Switching to ControlField for output from controlnet nodes. 2023-05-26 21:44:00 -04:00
11bd038b9d Added first controlnet preprocessor node for canny edge detection. 2023-05-26 21:44:00 -04:00
768cfe3aab Core implementation of ControlNet and MultiControlNet. 2023-05-26 21:44:00 -04:00
c4277b0662 Moved to controlnet_aux v0.0.4, reinstated Zoe controlnet preprocessor. Also in pyproject.toml had to specify downgrade of timm to 0.6.13 _after_ controlnet-aux installs timm >= 0.9.2, because timm >0.6.13 breaks Zoe preprocessor. 2023-05-26 21:44:00 -04:00
020f3ccf07 fix(nodes): controlnet input accepts list or single controlnet 2023-05-26 21:44:00 -04:00
7467fa5e57 fix(ui): fix node ui type hints 2023-05-26 21:44:00 -04:00
e19ef7ed2f fix(ui): add control field type 2023-05-26 21:44:00 -04:00
71003be6b8 fix(ui): add value to conditioning field 2023-05-26 21:44:00 -04:00
c1dbafc2df chore(ui): regen api client 2023-05-26 21:44:00 -04:00
dcebd71381 Extended node-based ControlNet support to LatentsToLatentsInvocation. 2023-05-26 21:44:00 -04:00
d855a65e73 Cleaning up after ControlNet refactor in TextToLatentsInvocation 2023-05-26 21:44:00 -04:00
a9007c7e0f Refactored most of controlnet code into its own method to declutter TextToLatents.invoke(), and make upcoming integration with LatentsToLatents easier. 2023-05-26 21:44:00 -04:00
af60304f97 Fix to work with current stable release of controlnet_aux (v0.0.3). Turned of pre-processor params that were added post v0.0.3. Also change defaults for shuffle. 2023-05-26 21:44:00 -04:00
6de241eead Switched CotrolNet node modelname input from free text to default list of popular ControlNet model names. 2023-05-26 21:44:00 -04:00
51032dc0b2 Commented out ZoeDetector. Will re-instate once there's a controlnet-aux release that supports it. 2023-05-26 21:44:00 -04:00
9ec3d2bc0c Added dependency on controlnet-aux v0.0.3 2023-05-26 21:44:00 -04:00
297931f5d9 Cleaning up prior to submitting ControlNet PR. Mostly turning off diagnostic printing. Also fixed error when there is no controlnet input. 2023-05-26 21:44:00 -04:00
f613c073c1 Added support for specifying which step iteration to start using
each ControlNet, and which step to end using each controlnet (specified as fraction of total steps)
2023-05-26 21:44:00 -04:00
63d248622c Refactored ControNet support to consolidate multiple parameters into data struct. Also redid how multiple controlnets are handled. 2023-05-26 21:44:00 -04:00
48485fe92f Removed last bits of dtype and and device hardwiring from controlnet section 2023-05-26 21:44:00 -04:00
07726af703 Cleaning up mistakes after rebase. 2023-05-26 21:44:00 -04:00
ad1004b485 Cleaning up TextToLatent arg testing 2023-05-26 21:44:00 -04:00
0096fb2790 Added resizing of controlnet image based on noise latent. Fixes a tensor mismatch issue. 2023-05-26 21:44:00 -04:00
9c8c2e49d6 Refactored controlnet nodes: split out controlnet stuff into separate node, stripped controlnet stuff form image processing/analysis nodes. 2023-05-26 21:44:00 -04:00
2005a96847 Prep for splitting pre-processor and controlnet nodes 2023-05-26 21:44:00 -04:00
00a8d60c1b Added more preprocessor nodes for:
MidasDepth
      ZoeDepth
      MLSD
      NormalBae
      Pidi
      LineartAnime
      ContentShuffle
Removed pil_output options, ControlNet preprocessors should always output as PIL. Removed diagnostics and other general cleanup.
2023-05-26 21:44:00 -04:00
3aa182390a changes to base class for controlnet nodes 2023-05-26 21:44:00 -04:00
e44f1d6d4e Refactored controlnet node to output ControlField that bundles control info. 2023-05-26 21:44:00 -04:00
dfdf8e2ead Switching to ControlField for output from controlnet nodes. 2023-05-26 21:44:00 -04:00
3a645c4e80 Initial port of controlnet node support from generator-based TextToImageInvocation node to latent-based TextToLatentsInvocation node 2023-05-26 21:44:00 -04:00
113129daf9 Added first controlnet preprocessor node for canny edge detection. 2023-05-26 21:44:00 -04:00
940e3b6635 Core implementation of ControlNet and MultiControlNet. 2023-05-26 21:44:00 -04:00
7fb29dabff Fixed lint-ish formatting error 2023-05-26 21:44:00 -04:00
714ad6dbb8 Fixed use of ControlNet control_weight parameter 2023-05-26 21:44:00 -04:00
c0863fa20f Added support for using multiple control nets. Unfortunately this breaks direct usage of Control node output port ==> TextToLatent control input port -- passing through a Collect node is now required. Working on fixing this... 2023-05-26 21:44:00 -04:00
78b0b37ba6 More rebase repair. 2023-05-26 21:44:00 -04:00
5d5cdc7716 Added resizing of controlnet image based on noise latent. Fixes a tensor mismatch issue. 2023-05-26 21:44:00 -04:00
93cd818f6a Refactored controlnet nodes: split out controlnet stuff into separate node, stripped controlnet stuff form image processing/analysis nodes. 2023-05-26 21:44:00 -04:00
598a628790 Prep for splitting pre-processor and controlnet nodes 2023-05-26 21:44:00 -04:00
f3666eda63 Added more preprocessor nodes for:
MidasDepth
      ZoeDepth
      MLSD
      NormalBae
      Pidi
      LineartAnime
      ContentShuffle
Removed pil_output options, ControlNet preprocessors should always output as PIL. Removed diagnostics and other general cleanup.
2023-05-26 21:44:00 -04:00
754017b59e Added an additional "raw_processed_image" output port to controlnets, mainly so could route ImageField to a ShowImage node 2023-05-26 21:44:00 -04:00
21251ce12c Added HED, LineArt, and OpenPose ControlNet nodes 2023-05-26 21:44:00 -04:00
dc12fa6cd6 changes to base class for controlnet nodes 2023-05-26 21:44:00 -04:00
f2f4c37f19 Refactored ControlNet nodes so they subclass from PreprocessedControlInvocation, and only need to override run_processor(image) (instead of reimplementing invoke()) 2023-05-26 21:44:00 -04:00
0864fca641 Resolving conflicts in rebase to origin/main 2023-05-26 21:44:00 -04:00
5e4c0217c7 Switching to ControlField for output from controlnet nodes. 2023-05-26 21:44:00 -04:00
78cd106c23 Initial port of controlnet node support from generator-based TextToImageInvocation node to latent-based TextToLatentsInvocation node 2023-05-26 21:44:00 -04:00
6ed0efa938 Added first controlnet preprocessor node for canny edge detection. 2023-05-26 21:44:00 -04:00
ca0669c337 Resolving rebase conflict 2023-05-26 21:44:00 -04:00
b59a749627 Added example of using ControlNet with legacy Txt2Img generator 2023-05-26 21:44:00 -04:00
a91dee87d0 Added support for ControlNet and MultiControlNet to legacy non-nodal Txt2Img in backend/generator. Although backend/generator will likely disappear by v3.x, right now they are very useful for testing core ControlNet and MultiControlNet functionality while node codebase is rapidly evolving. 2023-05-26 21:44:00 -04:00
5ff98a4179 Core implementation of ControlNet and MultiControlNet. 2023-05-26 21:44:00 -04:00
36b2f12219 Merge branch 'main' into release/make-web-dist-startable 2023-05-26 12:56:24 -04:00
5569f205ee Update CODEOWNERS 2023-05-26 08:59:10 -04:00
a76cf8aab2 Update CODEOWNERS 2023-05-26 08:59:10 -04:00
5c0f0d1808 Merge branch 'main' into lstein/logging-improvements 2023-05-26 08:57:17 -04:00
951900a86a Merge branch 'main' into lstein/config-management-fixes 2023-05-26 08:56:41 -04:00
582f516fef Merge branch 'main' into release/make-web-dist-startable 2023-05-26 18:06:38 +10:00
a25bae2545 fix(ui): tweak log levels 2023-05-26 18:06:08 +10:00
0ea35b1e3d feat(ui): improve session canceled handling 2023-05-26 18:06:08 +10:00
c6f935bf1a feat(ui): improve gallery page handling 2023-05-26 18:06:08 +10:00
96b4d35d43 fix(ui): fix uploads not loading more images correctly after generation 2023-05-26 18:06:08 +10:00
7b0938e7e4 feat(ui): add comments for weird stuff 2023-05-26 18:06:08 +10:00
249522b568 fix(ui): fix gallery not loading more images correctly after generation 2023-05-26 18:06:08 +10:00
39088e42cc fix(ui): remove console logs 2023-05-26 18:06:08 +10:00
30e0033ebe fix(ui): fix results not added to gallery 2023-05-26 18:06:08 +10:00
b599c40099 feat(ui): improve session invoked handling 2023-05-26 18:06:08 +10:00
8f190169db feat(ui): improve session creation handling 2023-05-26 18:06:08 +10:00
1d4d705795 feat(ui): improve image urls handling 2023-05-26 18:06:08 +10:00
b3f71b3078 feat(ui): improve image metadata handling 2023-05-26 18:06:08 +10:00
6059db4f15 feat(ui): improve image delete handling 2023-05-26 18:06:08 +10:00
0d5f44b153 feat(ui): improve image upload handling 2023-05-26 18:06:08 +10:00
17164a37a8 fix(ui): fix gallery auto switch 2023-05-26 18:06:08 +10:00
f88ccabe30 fix(ui): gallery not loading on page load 2023-05-26 18:06:08 +10:00
e1c85f1234 Merge branch 'main' into release/make-web-dist-startable 2023-05-26 18:04:09 +10:00
f50293920e correct typo in tiled_vae field definition 2023-05-25 23:29:16 -04:00
1e2db3a17f hook tiled_decode up to configuration 2023-05-25 23:28:15 -04:00
57a3eb3652 feat(ui): unset progress image inside invocationComplete listener 2023-05-26 13:25:50 +10:00
82a8972bde create listener for imageMetdataReceived to swap our progressImage 2023-05-26 13:25:50 +10:00
497a885c85 Merge branch 'main' into release/make-web-dist-startable 2023-05-25 22:49:18 -04:00
4d9f55d0f6 replace deleted get_root() 2023-05-25 22:48:50 -04:00
5f8f51436a merge with main; fix conflicts 2023-05-25 22:40:45 -04:00
0c3b4bb70d chore(ui): regen api client 2023-05-25 22:17:14 -04:00
33e13820fc feat(nodes): remove meta node field; use individual is_intermediate field instead
as suggested by @Kyle0654
2023-05-25 22:17:14 -04:00
43d991cfdb fix(ui): fix incorrect comment 2023-05-25 22:17:14 -04:00
291e9cf14b fix(nodes): add is_intermediate to all image-outputting nodes 2023-05-25 22:17:14 -04:00
a2de5c9963 feat(ui): change intermediates handling
- Update the canvas graph generation to flag its uploaded init and mask images as `intermediate`.
- During canvas setup, hit the update route to associate the uploaded images with the session id.
- Organize the socketio and RTK listener middlware better. Needed to facilitate the updated canvas logic.
- Add a new action `sessionReadyToInvoke`. The `sessionInvoked` action is *only* ever run in response to this event. This lets us do whatever complicated setup (eg canvas) and explicitly invoking. Previously, invoking was tied to the socket subscribe events.
- Some minor tidying.
2023-05-25 22:17:14 -04:00
5025f84627 chore(ui): regen api client 2023-05-25 22:17:14 -04:00
d2c8a53c55 feat(nodes): change intermediates handling
- `ImageType` is now restricted to `results` and `uploads`.
- Add a reserved `meta` field to nodes to hold the `is_intermediate` boolean. We can extend it in the future to support other node `meta`.
- Add a `is_intermediate` column to the `images` table to hold this. (When `latents`, `conditioning` etc are added to the DB, they will also have this column.)
- All nodes default to `*not* intermediate`. Nodes must explicitly be marked `intermediate` for their outputs to be `intermediate`.
- When building a graph, you can set `node.meta.is_intermediate=True` and it will be handled as an intermediate.
- Add a new `update()` method to the `ImageService`, and a route to call it. Updates have a strict model, currently only `session_id` and `image_category` may be updated.
- Add a new `update()` method to the `ImageRecordStorageService` to update the image record using the model.
2023-05-25 22:17:14 -04:00
5659d10778 remove unused function get_root() 2023-05-25 22:06:37 -04:00
46cab81d6f fix missing web_dir 2023-05-25 22:01:48 -04:00
dd157bce85 Merge branch 'main' into release/make-web-dist-startable 2023-05-25 21:52:05 -04:00
2f25dd7d0d Merge branch 'main' into lstein/config-management-fixes 2023-05-25 21:10:12 -04:00
e56965ad76 documentation tweaks; fixed initialization in a couple more places 2023-05-25 21:10:00 -04:00
2273b3a8c8 fix potential race condition in config system 2023-05-25 20:41:26 -04:00
05fb0ac2b2 Update latent.py 2023-05-26 10:27:33 +10:00
d4acd49ee3 Update generate.py 2023-05-26 10:27:33 +10:00
d98868e524 Update generationSlice.ts to change Default Scheduler 2023-05-26 10:27:33 +10:00
93bb27f2c7 fix gallery navigation 2023-05-26 10:01:06 +10:00
a4c44edf8d more use parameter fixes 2023-05-26 10:01:06 +10:00
1e94d7739a fix metadata references, add support for negative_conditioning syntax 2023-05-26 10:01:06 +10:00
9110838fe4 Merge branch 'main' into release/make-web-dist-startable 2023-05-25 19:06:09 -04:00
ca7b267326 raise error if syslogging requested and syslog lib not available 2023-05-25 10:10:46 -04:00
7f5992d6a5 Merge branch 'lstein/logging-improvements' of github.com:invoke-ai/InvokeAI into lstein/logging-improvements 2023-05-25 09:39:56 -04:00
88776fb2de get invokeai_configure working again 2023-05-25 09:39:45 -04:00
34f567abd4 Merge branch 'main' into lstein/logging-improvements 2023-05-25 08:48:47 -04:00
b87f3043ae add logging configuration 2023-05-24 23:57:15 -04:00
3829ffbe66 fix(tests): add --use_memory_db flag; use it in tests 2023-05-25 12:12:31 +10:00
ad619ae880 fix(tests): log db_location 2023-05-25 12:12:31 +10:00
d22ebe08be fix(tests): log db_location 2023-05-25 12:12:31 +10:00
ee0c6ad86e fix(cli): fix invocation services for cli 2023-05-25 12:12:31 +10:00
96adb56633 fix(tests): fix missing services in tests; fix ImageField instantiation 2023-05-25 12:12:31 +10:00
3000436121 chore(nodes): remove unused imports 2023-05-25 12:12:31 +10:00
37cdd91f5d fix(nodes): use forward declarations for InvocationServices
Also use `TYPE_CHECKING` to get IDE hints.
2023-05-25 12:12:31 +10:00
6f3c6ddf3f Update 020_INSTALL_MANUAL.md
Corrected a markdown formatting error (missing backtick).
2023-05-24 11:33:32 -04:00
0bfbda512d build(nodes): remove references to metadata service in tests 2023-05-24 11:30:47 -04:00
295b98a13c build(nodes): remove outdated metadata test
I will add tests for the new service soon
2023-05-24 11:30:47 -04:00
ff6b345d45 fix(nodes): rebase fixes 2023-05-24 11:30:47 -04:00
1fb307abf4 feat(nodes): restore canvas functionality (non-latents) 2023-05-24 11:30:47 -04:00
29c952dcf6 feat(ui): restore canvas functionality 2023-05-24 11:30:47 -04:00
010f63a50d feat(ui): misc tidy 2023-05-24 11:30:47 -04:00
068bbe3a39 fix(ui): fix uploads tab in gallery 2023-05-24 11:30:47 -04:00
ad39680feb feat(nodes): wip inpainting nodes prep 2023-05-24 11:30:47 -04:00
1e0ae8404c feat(nodes): comment out seamless
this will be a model config feature when model manager is ready
2023-05-24 11:30:47 -04:00
460d555a3d feat(nodes): add image mul, channel, convert nodes
also make img node names consistent
2023-05-24 11:30:47 -04:00
66ad04fcfc feat(nodes): add mask image category 2023-05-24 11:30:47 -04:00
c7c0836721 feat(ui): migrate linear workflows to latents 2023-05-24 11:30:47 -04:00
d2c223de8f feat(nodes): move fully* to new images service
* except i haven't rebuilt inpaint in latents
2023-05-24 11:30:47 -04:00
dd16f788ed fix(nodes): fix RangeOfSizeInvocation off-by-one error 2023-05-24 11:30:47 -04:00
b25c1af018 feat(nodes): add RangeOfSizeInvocation
The `RangeInvocation` is a simple wrapper around `range()`, but you must provide `stop > start`.

`RangeOfSizeInvocation` replaces the `stop` parameter with `size`, so that you can just provide the `start` and `step` and get a range of `size` length.
2023-05-24 11:30:47 -04:00
8f393b64b8 feat(nodes): add seed validator
If `seed>SEED_MAX`, we can still continue if we parse the seed as `seed % SEED_MAX`.
2023-05-24 11:30:47 -04:00
55b3193629 fix(nodes): add RangeInvocation validator
`stop` must be greater than `start`.
2023-05-24 11:30:47 -04:00
6f78c073ed fix(ui): fix uploads & other bugs 2023-05-24 11:30:47 -04:00
c406be6f4f fix(ui): fix image deletion 2023-05-24 11:30:47 -04:00
aeaf3737aa fix(ui): fix gallery bugs 2023-05-24 11:30:47 -04:00
23d9d58c08 fix(nodes): fix bugs with serving images
When returning a `FileResponse`, we must provide a valid path, else an exception is raised outside the route handler.

Add the `validate_path` method back to the service so we can validate paths before returning the file.

I don't like this but apparently this is just how `starlette` and `fastapi` works with `FileResponse`.
2023-05-24 11:30:47 -04:00
4c331a5d7e chore(ui): regen api client 2023-05-24 11:30:47 -04:00
035425ef24 feat(nodes): address feedback
- Address database feedback:
  - Remove all the extraneous tables. Only an `images` table now:
  - `image_type` and `image_category` are unrestricted strings. When creating images, the provided values are checked to ensure they are a valid type and category.
  - Add `updated_at` and `deleted_at` columns. `deleted_at` is currently unused.
  - Use SQLite's built-in timestamp features to populate these. Add a trigger to update `updated_at` when the row is updated. Currently no way to update a row.
  - Rename the `id` column in `images` to `image_name`
- Rename `ImageCategory.IMAGE` to `ImageCategory.GENERAL`
- Move all exceptions outside their base classes to make them more portable.
- Add `width` and `height` columns to the database. These store the actual dimensions of the image file, whereas the metadata's `width` and `height` refer to the respective generation parameters and are nullable.
- Make `deserialize_image_record` take a `dict` instead of `sqlite3.Row`
- Improve comments throughout
- Tidy up unused code/files and some minor organisation
2023-05-24 11:30:47 -04:00
021e5a2aa3 feat(nodes): improve metadata service comments 2023-05-24 11:30:47 -04:00
7a1de3887e feat(ui): wip update UI for migration 2023-05-24 11:30:47 -04:00
4a7a5234df fix(ui): fix image nodes losing image 2023-05-24 11:30:47 -04:00
6aebe1614d feat(ui): wip use new images service 2023-05-24 11:30:47 -04:00
74292eba28 chore(ui): regen api client 2023-05-24 11:30:47 -04:00
c31ff364ab fix(nodes): tidy images service 2023-05-24 11:30:47 -04:00
f310a39381 feat(nodes): finalize image routes 2023-05-24 11:30:47 -04:00
5a7e611e0a fix(nodes): fix image url 2023-05-24 11:30:47 -04:00
4e29a751d8 feat(ui): add POC image record fetching 2023-05-24 11:30:47 -04:00
3f94f81acd chore(ui): regen api client 2023-05-24 11:30:47 -04:00
5de3c41d19 feat(nodes): add metadata handling 2023-05-24 11:30:47 -04:00
f071b03ceb chore(ui): regen api client 2023-05-24 11:30:47 -04:00
b9375186a5 feat(nodes): consolidate image routers 2023-05-24 11:30:47 -04:00
11bd932cba feat(nodes): revert invocation_complete url hack 2023-05-24 11:30:47 -04:00
b77ccfaf32 chore(ui): regen api client 2023-05-24 11:30:47 -04:00
96653eebb6 build(ui): do not export schemas on api client generation 2023-05-24 11:30:47 -04:00
60d25f105f fix(nodes): restore metadata traverser 2023-05-24 11:30:47 -04:00
734b653a5f fix(nodes): add base images router 2023-05-24 11:30:47 -04:00
52c9e6ec91 feat(nodes): organise/tidy 2023-05-24 11:30:47 -04:00
c0f132e41a hack(nodes): hack to get image urls in the invocation complete event 2023-05-24 11:30:47 -04:00
cc1160a43a feat(nodes): streamline urlservice 2023-05-24 11:30:47 -04:00
adde8450bc fix(nodes): remove bad import 2023-05-24 11:30:47 -04:00
5bf9891553 feat(nodes): it works 2023-05-24 11:30:47 -04:00
22c34c343a feat(nodes): fix types for InvocationServices 2023-05-24 11:30:47 -04:00
f7804f6126 feat(nodes): add logger to images service 2023-05-24 11:30:47 -04:00
d14b02e93f feat(logger): fix logger type issues 2023-05-24 11:30:47 -04:00
1b75d899ae feat(nodes): wip image storage implementation 2023-05-24 11:30:47 -04:00
d4aa79acd7 fix(nodes): use save instead of set
`set` is a python builtin
2023-05-24 11:30:47 -04:00
33d199c007 feat(nodes): image records router 2023-05-24 11:30:47 -04:00
9c89d3452c feat(nodes): add high-level images service
feat(nodes): add ResultsServiceABC & SqliteResultsService

**Doesn't actually work bc of circular imports. Can't even test it.**

- add a base class for ResultsService and SQLite implementation
- use `graph_execution_manager` `on_changed` callback to keep `results` table in sync

fix(nodes): fix results service bugs

chore(ui): regen api

fix(ui): fix type guards

feat(nodes): add `result_type` to results table, fix types

fix(nodes): do not shadow `list` builtin

feat(nodes): add results router

It doesn't work due to circular imports still

fix(nodes): Result class should use outputs classes, not fields

feat(ui): crude results router

fix(ui): send to canvas in currentimagebuttons not working

feat(nodes): add core metadata builder

feat(nodes): add design doc

feat(nodes): wip latents db stuff

feat(nodes): images_db_service and resources router

feat(nodes): wip images db & router

feat(nodes): update image related names

feat(nodes): update urlservice

feat(nodes): add high-level images service
2023-05-24 11:30:47 -04:00
fb0b63c580 fix(nodes): fix seam painting
The problem was the same seed was getting used for the seam painting pass, causing the fried look.

Same issue as if you do img2img on a txt2img with the same seed/prompt.

Thanks to @hipsterusername for teaming up to debug this. We got pretty deep into the weeds.
2023-05-25 00:58:03 +10:00
bb2c6e5925 Merge branch 'main' into release/make-web-dist-startable 2023-05-24 10:55:51 -04:00
928caff2a6 fix: attempt to fix actions (#3454)
i think this conditional needs to be removed.
2023-05-25 02:37:39 +12:00
670c79f2c7 fix: attempt to fix actions
i think this conditional needs to be removed.
2023-05-25 00:31:48 +10:00
d6efb98953 build: fix test-invoke-pip.yml
- Restore conditional which ensures tests are only run on `main`
- Fix `yaml` syntax error
2023-05-24 21:48:12 +10:00
19da795274 fix(ui): send to canvas in currentimagebuttons not working 2023-05-24 21:46:58 +10:00
454ba9b893 add crossOrigin = anonymous attribute to konva image 2023-05-24 10:32:41 +10:00
8e419a4f97 Revert weak references as can be done without it 2023-05-23 04:29:40 +03:00
2533209326 Rewrite cache to weak references 2023-05-23 03:48:22 +03:00
d2dc1ed26f make InvokeAI package installable
This commit makes InvokeAI 3.0 to be installable via PyPi.org and the
installer script.

Main changes.

1. Move static web pages into `invokeai/frontend/web` and modify the
API to look for them there. This allows pip to copy the files into the
distribution directory so that user no longer has to be in repo root
to launch.

2. Update invoke.sh and invoke.bat to launch the new web application
properly. This also changes the wording for launching the CLI from
"generate images" to "explore the InvokeAI node system," since I would
not recommend using the CLI to generate images routinely.

3. Fix a bug in the checkpoint converter script that was identified
during testing.

4. Better error reporting when checkpoint converter fails.

5. Rebuild front end.
2023-05-22 17:51:47 -04:00
d4fb16825e move static into invokeai.frontend.web directory for dist install 2023-05-22 16:48:17 -04:00
165c1adcf8 Merge branch 'main' into lstein/new-model-manager 2023-05-22 21:51:07 +03:00
650d69ef5b added optional middleware prop and new actions needed (#3437)
* added optional middleware prop and new actions needed

* accidental import

* make middleware an array

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2023-05-22 08:16:11 -04:00
ff0e79fa9a add id for invoke button 2023-05-19 21:44:31 +10:00
127b54f812 add some IDs 2023-05-19 21:44:31 +10:00
bdf33f13b3 fix bad merge in compel 2023-05-18 18:08:45 -04:00
27241cdde1 port more globals changes over 2023-05-18 17:17:45 -04:00
259d6ec90d fixup cachedir call 2023-05-18 14:52:16 -04:00
a77c4c87b2 fixed logic error in resolution of model path 2023-05-18 14:35:34 -04:00
d96175d127 resolve some undefined symbols in model_cache 2023-05-18 14:31:47 -04:00
b1a99d772c added method to convert vaes 2023-05-18 13:31:11 -04:00
fd82763412 Model manager draft 2023-05-18 03:56:52 +03:00
e971a7f35c when migrating models.yaml, rename original models.yaml.orig 2023-05-16 22:37:53 -04:00
6ab84741a0 fix(nodes): make ModelsList an enum-keyed dict
The `ModelsList` OpenAPI schema is generated as being keyed by plain strings. This means that API consumers do not know the shape of the dict. It _should_ be keyed by the `SDModelType` enum.

Unfortunately, `fastapi` does not actually handle this correctly yet; it still generates the schema with plain string keys.

Adding this anyways though in hopes that it will be resolved upstream and we can get the correct schema. Until then, I'll implement the (simple but annoying) logic on the frontend.

https://github.com/pydantic/pydantic/issues/4393
2023-05-16 15:02:58 +10:00
cd16857f38 fix None in model_type 2023-05-16 00:13:44 -04:00
1442f1cb8d change model filter to None in second place 2023-05-16 00:03:57 -04:00
eea0d6f7bc default to no filter in list_models() 2023-05-15 23:52:29 -04:00
4fe94a9315 list_models() now returns a dict of {type,{name: info}} 2023-05-15 23:44:08 -04:00
c8f765cc06 improve debugging messages 2023-05-14 18:29:55 -04:00
b9e9087dbe do not manage GPU for pipelines if sequential_offloading is True 2023-05-14 18:09:38 -04:00
63e465eb5c tweaks to get_model() behavior
1. If an external VAE is specified in config file, then
   get_model(submodel=vae) will return the external VAE, not the one
   burnt into the parent diffusers pipeline.

2. The mechanism in (1) is generalized such that you can now have
   "unet:", "text_encoder:" and similar stanzas in the config file.
   Valid formats of these subsections:

       unet:
          repo_id: foo/bar

       unet:
          path: /path/to/local/folder

       unet:
          repo_id: foo/bar
	  subfolder: unet

    In the near future, these will also be used to attach external
    parts to the pipeline, generalizing VAE behavior.

3. Accommodate callers (i.e. the WebUI) that are passing the
   model key ("diffusers/stable-diffusion-1.5") to get_model()
   instead of the tuple of model_name and model_type.

4. Fixed bug in VAE model attaching code.

5. Rebuilt web front end.
2023-05-14 16:50:59 -04:00
426f4eaf7e adjusted regression tests to work with new SDModelTypes 2023-05-13 22:29:33 -04:00
baf5451fa0 Merge branch 'main' into lstein/new-model-manager 2023-05-13 22:01:34 -04:00
b31a6ff605 fix reversed args in _model_key() call 2023-05-13 21:11:06 -04:00
1f602e6143 Fix - apply precision to text_encoder 2023-05-14 03:46:13 +03:00
039fa73269 Change SDModelType enum to string, fixes(model unload negative locks count, scheduler load error, saftensors convert, wrong logic in del_model, wrong parse metadata in web) 2023-05-14 03:06:26 +03:00
2204e47596 allow submodels to be fetched independent of parent pipeline 2023-05-13 16:54:47 -04:00
d8b1f29066 proxy SDModelInfo so that it can be used directly as context 2023-05-13 16:29:18 -04:00
b23c9f1da5 get Tuple type hint syntax right 2023-05-13 14:59:21 -04:00
5e8e3cf464 correct typos in model_manager_service 2023-05-13 14:55:59 -04:00
72967bf118 convert add_model(), del_model(), list_models() etc to use bifurcated names 2023-05-13 14:44:44 -04:00
bc96727cbe Rewrite latent nodes to new model manager 2023-05-13 16:08:03 +03:00
3b2a054f7a Add model loader node; unet, clip, vae fields; change compel node to clip field 2023-05-13 04:37:20 +03:00
131145eab1 A big refactor of model manager(according to IMHO) 2023-05-12 23:13:34 +03:00
4492044d29 Redo compel node to separate model loading 2023-05-12 23:09:33 +03:00
5431dd5f50 Fix event args 2023-05-12 23:08:03 +03:00
79fecba274 Fix model manager initialization in web ui 2023-05-12 23:05:08 +03:00
2ef79b8bf3 fix bug in persistent model scheme 2023-05-12 00:14:56 -04:00
11ecf438f5 latents.py converted to use model manager service; events emitted 2023-05-11 23:33:24 -04:00
df5b968954 model manager now running as a service 2023-05-11 21:24:29 -04:00
8ad8c5c67a resolve conflicts with main 2023-05-11 00:19:20 -04:00
590942edd7 Merge branch 'main' into lstein/new-model-manager 2023-05-11 00:16:03 -04:00
4627910c5d added a wrapper model_manager_service and model events 2023-05-11 00:09:19 -04:00
fa6a580452 merge with main 2023-05-10 00:03:32 -04:00
99c692f397 check that model name matches format 2023-05-09 23:46:59 -04:00
3d85e769ce clean up ckpt handling
- remove legacy ckpt loading code from model_cache
- added placeholders for lora and textual inversion model loading
2023-05-09 22:44:58 -04:00
9cb962cad7 ckpt model conversion now done in ModelCache 2023-05-08 23:39:44 -04:00
a108155544 added StALKeR779's great model size calculating routine 2023-05-08 21:47:03 -04:00
c15b49c805 implement StALKeR7779 requested API for fetching submodels 2023-05-07 23:18:17 -04:00
fd63e36822 optimize subfolder so that it returns submodel if parent is in RAM 2023-05-07 21:39:11 -04:00
4649920074 adjust t2i to work with new model structure 2023-05-07 19:06:49 -04:00
667171ed90 cap model cache size using bytes, not # models 2023-05-07 18:07:28 -04:00
647ffb2a0f defined abstract baseclass for model manager service 2023-05-06 22:41:19 -04:00
05a27bda5e generalize model loading support, include loras/embeds 2023-05-06 15:58:44 -04:00
a8cfa3565c Merge branch 'lstein/new-model-manager' of github.com:invoke-ai/InvokeAI into lstein/new-model-manager 2023-05-06 08:14:15 -04:00
e0214a32bc mostly ported to new manager API; needs testing 2023-05-06 00:44:12 -04:00
af8c7c7d29 model manager rewritten to use model_cache; API changed! 2023-05-05 19:32:28 -04:00
a4e36bc02a when model is forcibly moved into RAM update loaded_models set 2023-05-04 23:28:03 -04:00
2e9bec15e7 Merge branch 'main' into lstein/new-model-manager 2023-05-04 23:19:38 -04:00
68bc0112fa implement lazy GPU offloading and ref counting 2023-05-04 23:15:32 -04:00
a273bdbdc1 Merge branch 'main' into lstein/new-model-manager 2023-05-03 18:09:29 -04:00
8a0ec0fa0f Merge branch 'main' into lstein/new-model-manager 2023-05-03 13:30:50 -04:00
e1fed52c66 work on model cache and its regression test finished 2023-05-03 12:38:18 -04:00
bb959448c1 implement hashing for local & remote models 2023-05-02 16:52:27 -04:00
2e2abf6ea6 caching of subparts working 2023-05-01 22:57:30 -04:00
956ad6bcf5 add redesigned model cache for diffusers & transformers 2023-04-28 00:41:52 -04:00
723 changed files with 29293 additions and 24065 deletions

14
.github/CODEOWNERS vendored
View File

@ -2,7 +2,7 @@
/.github/workflows/ @lstein @blessedcoolant
# documentation
/docs/ @lstein @tildebyte @blessedcoolant
/docs/ @lstein @blessedcoolant @hipsterusername
/mkdocs.yml @lstein @blessedcoolant
# nodes
@ -18,17 +18,17 @@
/invokeai/version @lstein @blessedcoolant
# web ui
/invokeai/frontend @blessedcoolant @psychedelicious @lstein
/invokeai/backend @blessedcoolant @psychedelicious @lstein
/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
# generation, model management, postprocessing
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779
# front ends
/invokeai/frontend/CLI @lstein
/invokeai/frontend/install @lstein @ebr
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/web @psychedelicious @blessedcoolant
/invokeai/frontend/merge @lstein @blessedcoolant
/invokeai/frontend/training @lstein @blessedcoolant
/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp

View File

@ -80,7 +80,7 @@ jobs:
uses: actions/checkout@v3
- name: set test prompt to main branch validation
run:echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
- name: setup python
uses: actions/setup-python@v4
@ -125,6 +125,7 @@ jobs:
--no-nsfw_checker
--precision=float32
--always_use_cpu
--use_memory_db
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
--from_file ${{ env.TEST_PROMPTS }}

View File

@ -43,6 +43,23 @@ _Note: InvokeAI is rapidly evolving. Please use the
[Issues](https://github.com/invoke-ai/InvokeAI/issues) tab to report bugs and make feature
requests. Be sure to use the provided templates. They will help us diagnose issues faster._
## FOR DEVELOPERS - MIGRATING TO THE 3.0.0 MODELS FORMAT
The models directory and models.yaml have changed. To migrate to the
new layout, please follow this recipe:
1. Run `python scripts/migrate_models_to_3.0.py <path_to_root_directory>
2. This will create a new models directory named `models-3.0` and a
new config directory named `models.yaml-3.0`, both in the current
working directory. If you prefer to name them something else, pass
the `--dest-directory` and/or `--dest-yaml` arguments.
3. Check that the new models directory and yaml file look ok.
4. Replace the existing directory and file, keeping backup copies just in
case.
<div align="center">
![canvas preview](https://github.com/invoke-ai/InvokeAI/raw/main/docs/assets/canvas_preview.png)

View File

@ -1,164 +0,0 @@
@echo off
@rem This script will install git (if not found on the PATH variable)
@rem using micromamba (an 8mb static-linked single-file binary, conda replacement).
@rem For users who already have git, this step will be skipped.
@rem Next, it'll download the project's source code.
@rem Then it will download a self-contained, standalone Python and unpack it.
@rem Finally, it'll create the Python virtual environment and preload the models.
@rem This enables a user to install this project without manually installing git or Python
@rem change to the script's directory
PUSHD "%~dp0"
set "no_cache_dir=--no-cache-dir"
if "%1" == "use-cache" (
set "no_cache_dir="
)
echo ***** Installing InvokeAI.. *****
@rem Config
set INSTALL_ENV_DIR=%cd%\installer_files\env
@rem https://mamba.readthedocs.io/en/latest/installation.html
set MICROMAMBA_DOWNLOAD_URL=https://github.com/cmdr2/stable-diffusion-ui/releases/download/v1.1/micromamba.exe
set RELEASE_URL=https://github.com/invoke-ai/InvokeAI
set RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
set PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
set PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-x86_64-pc-windows-msvc-shared-install_only.tar.gz
set PACKAGES_TO_INSTALL=
call git --version >.tmp1 2>.tmp2
if "%ERRORLEVEL%" NEQ "0" set PACKAGES_TO_INSTALL=%PACKAGES_TO_INSTALL% git
@rem Cleanup
del /q .tmp1 .tmp2
@rem (if necessary) install git into a contained environment
if "%PACKAGES_TO_INSTALL%" NEQ "" (
@rem download micromamba
echo ***** Downloading micromamba from %MICROMAMBA_DOWNLOAD_URL% to micromamba.exe *****
call curl -L "%MICROMAMBA_DOWNLOAD_URL%" > micromamba.exe
@rem test the mamba binary
echo ***** Micromamba version: *****
call micromamba.exe --version
@rem create the installer env
if not exist "%INSTALL_ENV_DIR%" (
call micromamba.exe create -y --prefix "%INSTALL_ENV_DIR%"
)
echo ***** Packages to install:%PACKAGES_TO_INSTALL% *****
call micromamba.exe install -y --prefix "%INSTALL_ENV_DIR%" -c conda-forge %PACKAGES_TO_INSTALL%
if not exist "%INSTALL_ENV_DIR%" (
echo ----- There was a problem while installing "%PACKAGES_TO_INSTALL%" using micromamba. Cannot continue. -----
pause
exit /b
)
)
del /q micromamba.exe
@rem For 'git' only
set PATH=%INSTALL_ENV_DIR%\Library\bin;%PATH%
@rem Download/unpack/clean up InvokeAI release sourceball
set err_msg=----- InvokeAI source download failed -----
echo Trying to download "%RELEASE_URL%%RELEASE_SOURCEBALL%"
curl -L %RELEASE_URL%%RELEASE_SOURCEBALL% --output InvokeAI.tgz
if %errorlevel% neq 0 goto err_exit
set err_msg=----- InvokeAI source unpack failed -----
tar -zxf InvokeAI.tgz
if %errorlevel% neq 0 goto err_exit
del /q InvokeAI.tgz
set err_msg=----- InvokeAI source copy failed -----
cd InvokeAI-*
xcopy . .. /e /h
if %errorlevel% neq 0 goto err_exit
cd ..
@rem cleanup
for /f %%i in ('dir /b InvokeAI-*') do rd /s /q %%i
rd /s /q .dev_scripts .github docker-build tests
del /q requirements.in requirements-mkdocs.txt shell.nix
echo ***** Unpacked InvokeAI source *****
@rem Download/unpack/clean up python-build-standalone
set err_msg=----- Python download failed -----
curl -L %PYTHON_BUILD_STANDALONE_URL%/%PYTHON_BUILD_STANDALONE% --output python.tgz
if %errorlevel% neq 0 goto err_exit
set err_msg=----- Python unpack failed -----
tar -zxf python.tgz
if %errorlevel% neq 0 goto err_exit
del /q python.tgz
echo ***** Unpacked python-build-standalone *****
@rem create venv
set err_msg=----- problem creating venv -----
.\python\python -E -s -m venv .venv
if %errorlevel% neq 0 goto err_exit
call .venv\Scripts\activate.bat
echo ***** Created Python virtual environment *****
@rem Print venv's Python version
set err_msg=----- problem calling venv's python -----
echo We're running under
.venv\Scripts\python --version
if %errorlevel% neq 0 goto err_exit
set err_msg=----- pip update failed -----
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location --upgrade pip wheel
if %errorlevel% neq 0 goto err_exit
echo ***** Updated pip and wheel *****
set err_msg=----- requirements file copy failed -----
copy binary_installer\py3.10-windows-x86_64-cuda-reqs.txt requirements.txt
if %errorlevel% neq 0 goto err_exit
set err_msg=----- main pip install failed -----
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location -r requirements.txt
if %errorlevel% neq 0 goto err_exit
echo ***** Installed Python dependencies *****
set err_msg=----- InvokeAI setup failed -----
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location -e .
if %errorlevel% neq 0 goto err_exit
copy binary_installer\invoke.bat.in .\invoke.bat
echo ***** Installed invoke launcher script ******
@rem more cleanup
rd /s /q binary_installer installer_files
@rem preload the models
call .venv\Scripts\python ldm\invoke\config\invokeai_configure.py
set err_msg=----- model download clone failed -----
if %errorlevel% neq 0 goto err_exit
deactivate
echo ***** Finished downloading models *****
echo All done! Execute the file invoke.bat in this directory to start InvokeAI
pause
exit
:err_exit
echo %err_msg%
pause
exit

View File

@ -1,235 +0,0 @@
#!/usr/bin/env bash
# ensure we're in the correct folder in case user's CWD is somewhere else
scriptdir=$(dirname "$0")
cd "$scriptdir"
set -euo pipefail
IFS=$'\n\t'
function _err_exit {
if test "$1" -ne 0
then
echo -e "Error code $1; Error caught was '$2'"
read -p "Press any key to exit..."
exit
fi
}
# This script will install git (if not found on the PATH variable)
# using micromamba (an 8mb static-linked single-file binary, conda replacement).
# For users who already have git, this step will be skipped.
# Next, it'll download the project's source code.
# Then it will download a self-contained, standalone Python and unpack it.
# Finally, it'll create the Python virtual environment and preload the models.
# This enables a user to install this project without manually installing git or Python
echo -e "\n***** Installing InvokeAI into $(pwd)... *****\n"
export no_cache_dir="--no-cache-dir"
if [ $# -ge 1 ]; then
if [ "$1" = "use-cache" ]; then
export no_cache_dir=""
fi
fi
OS_NAME=$(uname -s)
case "${OS_NAME}" in
Linux*) OS_NAME="linux";;
Darwin*) OS_NAME="darwin";;
*) echo -e "\n----- Unknown OS: $OS_NAME! This script runs only on Linux or macOS -----\n" && exit
esac
OS_ARCH=$(uname -m)
case "${OS_ARCH}" in
x86_64*) ;;
arm64*) ;;
*) echo -e "\n----- Unknown system architecture: $OS_ARCH! This script runs only on x86_64 or arm64 -----\n" && exit
esac
# https://mamba.readthedocs.io/en/latest/installation.html
MAMBA_OS_NAME=$OS_NAME
MAMBA_ARCH=$OS_ARCH
if [ "$OS_NAME" == "darwin" ]; then
MAMBA_OS_NAME="osx"
fi
if [ "$OS_ARCH" == "linux" ]; then
MAMBA_ARCH="aarch64"
fi
if [ "$OS_ARCH" == "x86_64" ]; then
MAMBA_ARCH="64"
fi
PY_ARCH=$OS_ARCH
if [ "$OS_ARCH" == "arm64" ]; then
PY_ARCH="aarch64"
fi
# Compute device ('cd' segment of reqs files) detect goes here
# This needs a ton of work
# Suggestions:
# - lspci
# - check $PATH for nvidia-smi, gtt CUDA/GPU version from output
# - Surely there's a similar utility for AMD?
CD="cuda"
if [ "$OS_NAME" == "darwin" ] && [ "$OS_ARCH" == "arm64" ]; then
CD="mps"
fi
# config
INSTALL_ENV_DIR="$(pwd)/installer_files/env"
MICROMAMBA_DOWNLOAD_URL="https://micro.mamba.pm/api/micromamba/${MAMBA_OS_NAME}-${MAMBA_ARCH}/latest"
RELEASE_URL=https://github.com/invoke-ai/InvokeAI
RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
if [ "$OS_NAME" == "darwin" ]; then
PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-apple-darwin-install_only.tar.gz
elif [ "$OS_NAME" == "linux" ]; then
PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-unknown-linux-gnu-install_only.tar.gz
fi
echo "INSTALLING $RELEASE_SOURCEBALL FROM $RELEASE_URL"
PACKAGES_TO_INSTALL=""
if ! hash "git" &>/dev/null; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL git"; fi
# (if necessary) install git and conda into a contained environment
if [ "$PACKAGES_TO_INSTALL" != "" ]; then
# download micromamba
echo -e "\n***** Downloading micromamba from $MICROMAMBA_DOWNLOAD_URL to micromamba *****\n"
curl -L "$MICROMAMBA_DOWNLOAD_URL" | tar -xvjO bin/micromamba > micromamba
chmod u+x ./micromamba
# test the mamba binary
echo -e "\n***** Micromamba version: *****\n"
./micromamba --version
# create the installer env
if [ ! -e "$INSTALL_ENV_DIR" ]; then
./micromamba create -y --prefix "$INSTALL_ENV_DIR"
fi
echo -e "\n***** Packages to install:$PACKAGES_TO_INSTALL *****\n"
./micromamba install -y --prefix "$INSTALL_ENV_DIR" -c conda-forge "$PACKAGES_TO_INSTALL"
if [ ! -e "$INSTALL_ENV_DIR" ]; then
echo -e "\n----- There was a problem while initializing micromamba. Cannot continue. -----\n"
exit
fi
fi
rm -f micromamba.exe
export PATH="$INSTALL_ENV_DIR/bin:$PATH"
# Download/unpack/clean up InvokeAI release sourceball
_err_msg="\n----- InvokeAI source download failed -----\n"
curl -L $RELEASE_URL/$RELEASE_SOURCEBALL --output InvokeAI.tgz
_err_exit $? _err_msg
_err_msg="\n----- InvokeAI source unpack failed -----\n"
tar -zxf InvokeAI.tgz
_err_exit $? _err_msg
rm -f InvokeAI.tgz
_err_msg="\n----- InvokeAI source copy failed -----\n"
cd InvokeAI-*
cp -r . ..
_err_exit $? _err_msg
cd ..
# cleanup
rm -rf InvokeAI-*/
rm -rf .dev_scripts/ .github/ docker-build/ tests/ requirements.in requirements-mkdocs.txt shell.nix
echo -e "\n***** Unpacked InvokeAI source *****\n"
# Download/unpack/clean up python-build-standalone
_err_msg="\n----- Python download failed -----\n"
curl -L $PYTHON_BUILD_STANDALONE_URL/$PYTHON_BUILD_STANDALONE --output python.tgz
_err_exit $? _err_msg
_err_msg="\n----- Python unpack failed -----\n"
tar -zxf python.tgz
_err_exit $? _err_msg
rm -f python.tgz
echo -e "\n***** Unpacked python-build-standalone *****\n"
# create venv
_err_msg="\n----- problem creating venv -----\n"
if [ "$OS_NAME" == "darwin" ]; then
# patch sysconfig so that extensions can build properly
# adapted from https://github.com/cashapp/hermit-packages/commit/fcba384663892f4d9cfb35e8639ff7a28166ee43
PYTHON_INSTALL_DIR="$(pwd)/python"
SYSCONFIG="$(echo python/lib/python*/_sysconfigdata_*.py)"
TMPFILE="$(mktemp)"
chmod +w "${SYSCONFIG}"
cp "${SYSCONFIG}" "${TMPFILE}"
sed "s,'/install,'${PYTHON_INSTALL_DIR},g" "${TMPFILE}" > "${SYSCONFIG}"
rm -f "${TMPFILE}"
fi
./python/bin/python3 -E -s -m venv .venv
_err_exit $? _err_msg
source .venv/bin/activate
echo -e "\n***** Created Python virtual environment *****\n"
# Print venv's Python version
_err_msg="\n----- problem calling venv's python -----\n"
echo -e "We're running under"
.venv/bin/python3 --version
_err_exit $? _err_msg
_err_msg="\n----- pip update failed -----\n"
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location --upgrade pip
_err_exit $? _err_msg
echo -e "\n***** Updated pip *****\n"
_err_msg="\n----- requirements file copy failed -----\n"
cp binary_installer/py3.10-${OS_NAME}-"${OS_ARCH}"-${CD}-reqs.txt requirements.txt
_err_exit $? _err_msg
_err_msg="\n----- main pip install failed -----\n"
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location -r requirements.txt
_err_exit $? _err_msg
echo -e "\n***** Installed Python dependencies *****\n"
_err_msg="\n----- InvokeAI setup failed -----\n"
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location -e .
_err_exit $? _err_msg
echo -e "\n***** Installed InvokeAI *****\n"
cp binary_installer/invoke.sh.in ./invoke.sh
chmod a+rx ./invoke.sh
echo -e "\n***** Installed invoke launcher script ******\n"
# more cleanup
rm -rf binary_installer/ installer_files/
# preload the models
.venv/bin/python3 scripts/configure_invokeai.py
_err_msg="\n----- model download clone failed -----\n"
_err_exit $? _err_msg
deactivate
echo -e "\n***** Finished downloading models *****\n"
echo "All done! Run the command"
echo " $scriptdir/invoke.sh"
echo "to start InvokeAI."
read -p "Press any key to exit..."
exit

View File

@ -1,36 +0,0 @@
@echo off
PUSHD "%~dp0"
call .venv\Scripts\activate.bat
echo Do you want to generate images using the
echo 1. command-line
echo 2. browser-based UI
echo OR
echo 3. open the developer console
set /p choice="Please enter 1, 2 or 3: "
if /i "%choice%" == "1" (
echo Starting the InvokeAI command-line.
.venv\Scripts\python scripts\invoke.py %*
) else if /i "%choice%" == "2" (
echo Starting the InvokeAI browser-based UI.
.venv\Scripts\python scripts\invoke.py --web %*
) else if /i "%choice%" == "3" (
echo Developer Console
echo Python command is:
where python
echo Python version is:
python --version
echo *************************
echo You are now in the system shell, with the local InvokeAI Python virtual environment activated,
echo so that you can troubleshoot this InvokeAI installation as necessary.
echo *************************
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
call cmd /k
) else (
echo Invalid selection
pause
exit /b
)
deactivate

View File

@ -1,46 +0,0 @@
#!/usr/bin/env sh
set -eu
. .venv/bin/activate
# set required env var for torch on mac MPS
if [ "$(uname -s)" == "Darwin" ]; then
export PYTORCH_ENABLE_MPS_FALLBACK=1
fi
echo "Do you want to generate images using the"
echo "1. command-line"
echo "2. browser-based UI"
echo "OR"
echo "3. open the developer console"
echo "Please enter 1, 2, or 3:"
read choice
case $choice in
1)
printf "\nStarting the InvokeAI command-line..\n";
.venv/bin/python scripts/invoke.py $*;
;;
2)
printf "\nStarting the InvokeAI browser-based UI..\n";
.venv/bin/python scripts/invoke.py --web $*;
;;
3)
printf "\nDeveloper Console:\n";
printf "Python command is:\n\t";
which python;
printf "Python version is:\n\t";
python --version;
echo "*************************"
echo "You are now in your user shell ($SHELL) with the local InvokeAI Python virtual environment activated,";
echo "so that you can troubleshoot this InvokeAI installation as necessary.";
printf "*************************\n"
echo "*** Type \`exit\` to quit this shell and deactivate the Python virtual environment *** ";
/usr/bin/env "$SHELL";
;;
*)
echo "Invalid selection";
exit
;;
esac

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,17 +0,0 @@
InvokeAI
Project homepage: https://github.com/invoke-ai/InvokeAI
Installation on Windows:
NOTE: You might need to enable Windows Long Paths. If you're not sure,
then you almost certainly need to. Simply double-click the 'WinLongPathsEnabled.reg'
file. Note that you will need to have admin privileges in order to
do this.
Please double-click the 'install.bat' file (while keeping it inside the invokeAI folder).
Installation on Linux and Mac:
Please open the terminal, and run './install.sh' (while keeping it inside the invokeAI folder).
After installation, please run the 'invoke.bat' file (on Windows) or 'invoke.sh'
file (on Linux/Mac) to start InvokeAI.

View File

@ -1,33 +0,0 @@
--prefer-binary
--extra-index-url https://download.pytorch.org/whl/torch_stable.html
--extra-index-url https://download.pytorch.org/whl/cu116
--trusted-host https://download.pytorch.org
accelerate~=0.15
albumentations
diffusers[torch]~=0.11
einops
eventlet
flask_cors
flask_socketio
flaskwebgui==1.0.3
getpass_asterisk
imageio-ffmpeg
pyreadline3
realesrgan
send2trash
streamlit
taming-transformers-rom1504
test-tube
torch-fidelity
torch==1.12.1 ; platform_system == 'Darwin'
torch==1.12.0+cu116 ; platform_system == 'Linux' or platform_system == 'Windows'
torchvision==0.13.1 ; platform_system == 'Darwin'
torchvision==0.13.0+cu116 ; platform_system == 'Linux' or platform_system == 'Windows'
transformers
picklescan
https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip
https://github.com/invoke-ai/clipseg/archive/1f754751c85d7d4255fa681f4491ff5711c1c288.zip
https://github.com/invoke-ai/GFPGAN/archive/3f5d2397361199bc4a91c08bb7d80f04d7805615.zip ; platform_system=='Windows'
https://github.com/invoke-ai/GFPGAN/archive/c796277a1cf77954e5fc0b288d7062d162894248.zip ; platform_system=='Linux' or platform_system=='Darwin'
https://github.com/Birch-san/k-diffusion/archive/363386981fee88620709cf8f6f2eea167bd6cd74.zip
https://github.com/invoke-ai/PyPatchMatch/archive/129863937a8ab37f6bbcec327c994c0f932abdbc.zip

View File

@ -19,31 +19,56 @@ An invocation looks like this:
```py
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
type: Literal['upscale'] = 'upscale'
# fmt: off
type: Literal["upscale"] = "upscale"
# Inputs
image: Union[ImageField,None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2,4] = Field(default=2, description = "The upscale level")
image: Union[ImageField, None] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2, 4] = Field(default=2, description="The upscale level")
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["upscaling", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
results = context.services.generate.upscale_and_reconstruct(
image_list = [[image, 0]],
upscale = (self.level, self.strength),
strength = 0.0, # GFPGAN strength
save_original = False,
image_callback = None,
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
strength=0.0, # GFPGAN strength
save_original=False,
image_callback=None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
```
Each portion is important to implement correctly.
@ -95,25 +120,67 @@ Finally, note that for all linking, the `type` of the linked fields must match.
If the `name` also matches, then the field can be **automatically linked** to a
previous invocation by name and matching.
### Config
```py
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["upscaling", "image"],
},
}
```
This is an optional configuration for the invocation. It inherits from
pydantic's model `Config` class, and it used primarily to customize the
autogenerated OpenAPI schema.
The UI relies on the OpenAPI schema in two ways:
- An API client & Typescript types are generated from it. This happens at build
time.
- The node editor parses the schema into a template used by the UI to create the
node editor UI. This parsing happens at runtime.
In this example, a `ui` key has been added to the `schema_extra` dict to provide
some tags for the UI, to facilitate filtering nodes.
See the Schema Generation section below for more information.
### Invoke Function
```py
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
results = context.services.generate.upscale_and_reconstruct(
image_list = [[image, 0]],
upscale = (self.level, self.strength),
strength = 0.0, # GFPGAN strength
save_original = False,
image_callback = None,
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
strength=0.0, # GFPGAN strength
save_original=False,
image_callback=None,
)
# Results are image and seed, unwrap for now
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, results[0][0])
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
```
@ -135,9 +202,16 @@ scenarios. If you need functionality, please provide it as a service in the
```py
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
type: Literal['image'] = 'image'
image: ImageField = Field(default=None, description="The output image")
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
```
Output classes look like an invocation class without the invoke method. Prefer
@ -168,35 +242,36 @@ Here's that `ImageOutput` class, without the needed schema customisation:
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
type: Literal["image"] = "image"
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
```
The generated OpenAPI schema, and all clients/types generated from it, will have
the `type` and `image` properties marked as optional, even though we know they
will always have a value by the time we can interact with them via the API.
Here's the same class, but with the schema customisation added:
The OpenAPI schema that results from this `ImageOutput` will have the `type`,
`image`, `width` and `height` properties marked as optional, even though we know
they will always have a value.
```python
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
type: Literal["image"] = "image"
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
# Add schema customization
class Config:
schema_extra = {
'required': [
'type',
'image',
]
}
schema_extra = {"required": ["type", "image", "width", "height"]}
```
The resultant schema (and any API client or types generated from it) will now
have see `type` as string literal `"image"` and `image` as an `ImageField`
object.
With the customization in place, the schema will now show these properties as
required, obviating the need for extensive null checks in client code.
See this `pydantic` issue for discussion on this solution:
<https://github.com/pydantic/pydantic/discussions/4577>

171
docs/features/LOGGING.md Normal file
View File

@ -0,0 +1,171 @@
---
title: Controlling Logging
---
# :material-image-off: Controlling Logging
## Controlling How InvokeAI Logs Status Messages
InvokeAI logs status messages using a configurable logging system. You
can log to the terminal window, to a designated file on the local
machine, to the syslog facility on a Linux or Mac, or to a properly
configured web server. You can configure several logs at the same
time, and control the level of message logged and the logging format
(to a limited extent).
Three command-line options control logging:
### `--log_handlers <handler1> <handler2> ...`
This option activates one or more log handlers. Options are "console",
"file", "syslog" and "http". To specify more than one, separate them
by spaces:
```bash
invokeai-web --log_handlers console syslog=/dev/log file=C:\Users\fred\invokeai.log
```
The format of these options is described below.
### `--log_format {plain|color|legacy|syslog}`
This controls the format of log messages written to the console. Only
the "console" log handler is currently affected by this setting.
* "plain" provides formatted messages like this:
```bash
[2023-05-24 23:18:2[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message
[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational messages
[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning
[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error
[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error
```
* "color" produces similar output, but the text will be color coded to
indicate the severity of the message.
* "legacy" produces output similar to InvokeAI versions 2.3 and earlier:
```bash
### this is a critical error
*** this is an error
** this is a warning
>> this is an informational messages
| this is a debug message
```
* "syslog" produces messages suitable for syslog entries:
```bash
InvokeAI [2691178] <CRITICAL> this is a critical error
InvokeAI [2691178] <ERROR> this is an error
InvokeAI [2691178] <WARNING> this is a warning
InvokeAI [2691178] <INFO> this is an informational messages
InvokeAI [2691178] <DEBUG> this is a debug message
```
(note that the date, time and hostname will be added by the syslog
system)
### `--log_level {debug|info|warning|error|critical}`
Providing this command-line option will cause only messages at the
specified level or above to be emitted.
## Console logging
When "console" is provided to `--log_handlers`, messages will be
written to the command line window in which InvokeAI was launched. By
default, the color formatter will be used unless overridden by
`--log_format`.
## File logging
When "file" is provided to `--log_handlers`, entries will be written
to the file indicated in the path argument. By default, the "plain"
format will be used:
```bash
invokeai-web --log_handlers file=/var/log/invokeai.log
```
## Syslog logging
When "syslog" is requested, entries will be sent to the syslog
system. There are a variety of ways to control where the log message
is sent:
* Send to the local machine using the `/dev/log` socket:
```
invokeai-web --log_handlers syslog=/dev/log
```
* Send to the local machine using a UDP message:
```
invokeai-web --log_handlers syslog=localhost
```
* Send to the local machine using a UDP message on a nonstandard
port:
```
invokeai-web --log_handlers syslog=localhost:512
```
* Send to a remote machine named "loghost" on the local LAN using
facility LOG_USER and UDP packets:
```
invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM
```
This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM
are defaults.
* Send to a remote machine named "loghost" using the facility LOCAL0
and using a TCP socket:
```
invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM
```
If no arguments are specified (just a bare "syslog"), then the logging
system will look for a UNIX socket named `/dev/log`, and if not found
try to send a UDP message to `localhost`. The Macintosh OS used to
support logging to a socket named `/var/run/syslog`, but this feature
has since been disabled.
## Web logging
If you have access to a web server that is configured to log messages
when a particular URL is requested, you can log using the "http"
method:
```
invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST
```
The optional [,method=] part can be used to specify whether the URL
accepts GET (default) or POST messages.
Currently password authentication and SSL are not supported.
## Using the configuration file
You can set and forget logging options by adding a "Logging" section
to `invokeai.yaml`:
```
InvokeAI:
[... other settings...]
Logging:
log_handlers:
- console
- syslog=/dev/log
log_level: info
log_format: color
```

View File

@ -57,6 +57,9 @@ Personalize models by adding your own style or subjects.
## * [The NSFW Checker](NSFW.md)
Prevent InvokeAI from displaying unwanted racy images.
## * [Controlling Logging](LOGGING.md)
Control how InvokeAI logs status messages.
## * [Miscellaneous](OTHER.md)
Run InvokeAI on Google Colab, generate images with repeating patterns,
batch process a file of prompts, increase the "creativity" of image

View File

@ -67,7 +67,7 @@ title: Home
implementation of Stable Diffusion, the open source text-to-image and
image-to-image generator. It provides a streamlined process with various new
features and options to aid the image generation process. It runs on Windows,
Mac and Linux machines, and runs on GPU cards with as little as 4 GB or RAM.
Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
**Quick links**: [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>]
[<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a

View File

@ -216,7 +216,7 @@ manager, please follow these steps:
9. Run the command-line- or the web- interface:
From within INVOKEAI_ROOT, activate the environment
(with `source .venv/bin/activate` or `.venv\scripts\activate), and then run
(with `source .venv/bin/activate` or `.venv\scripts\activate`), and then run
the script `invokeai`. If the virtual environment you selected is NOT inside
INVOKEAI_ROOT, then you must specify the path to the root directory by adding
`--root_dir \path\to\invokeai` to the commands below:

View File

@ -7,42 +7,42 @@ call .venv\Scripts\activate.bat
set INVOKEAI_ROOT=.
:start
echo Do you want to generate images using the
echo 1. command-line interface
echo 2. browser-based UI
echo 3. run textual inversion training
echo 4. merge models (diffusers type only)
echo 5. download and install models
echo 6. change InvokeAI startup options
echo 7. re-run the configure script to fix a broken install
echo 8. open the developer console
echo 9. update InvokeAI
echo 10. command-line help
echo Q - quit
set /P restore="Please enter 1-10, Q: [2] "
if not defined restore set restore=2
IF /I "%restore%" == "1" (
echo Desired action:
echo 1. Generate images with the browser-based interface
echo 2. Explore InvokeAI nodes using a command-line interface
echo 3. Run textual inversion training
echo 4. Merge models (diffusers type only)
echo 5. Download and install models
echo 6. Change InvokeAI startup options
echo 7. Re-run the configure script to fix a broken install
echo 8. Open the developer console
echo 9. Update InvokeAI
echo 10. Command-line help
echo Q - Quit
set /P choice="Please enter 1-10, Q: [2] "
if not defined choice set choice=2
IF /I "%choice%" == "1" (
echo Starting the InvokeAI browser-based UI..
python .venv\Scripts\invokeai-web.exe %*
) ELSE IF /I "%choice%" == "2" (
echo Starting the InvokeAI command-line..
python .venv\Scripts\invokeai.exe %*
) ELSE IF /I "%restore%" == "2" (
echo Starting the InvokeAI browser-based UI..
python .venv\Scripts\invokeai.exe --web %*
) ELSE IF /I "%restore%" == "3" (
) ELSE IF /I "%choice%" == "3" (
echo Starting textual inversion training..
python .venv\Scripts\invokeai-ti.exe --gui
) ELSE IF /I "%restore%" == "4" (
) ELSE IF /I "%choice%" == "4" (
echo Starting model merging script..
python .venv\Scripts\invokeai-merge.exe --gui
) ELSE IF /I "%restore%" == "5" (
) ELSE IF /I "%choice%" == "5" (
echo Running invokeai-model-install...
python .venv\Scripts\invokeai-model-install.exe
) ELSE IF /I "%restore%" == "6" (
) ELSE IF /I "%choice%" == "6" (
echo Running invokeai-configure...
python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
) ELSE IF /I "%restore%" == "7" (
) ELSE IF /I "%choice%" == "7" (
echo Running invokeai-configure...
python .venv\Scripts\invokeai-configure.exe --yes --default_only
) ELSE IF /I "%restore%" == "8" (
) ELSE IF /I "%choice%" == "8" (
echo Developer Console
echo Python command is:
where python
@ -54,15 +54,15 @@ IF /I "%restore%" == "1" (
echo *************************
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
call cmd /k
) ELSE IF /I "%restore%" == "9" (
) ELSE IF /I "%choice%" == "9" (
echo Running invokeai-update...
python .venv\Scripts\invokeai-update.exe %*
) ELSE IF /I "%restore%" == "10" (
) ELSE IF /I "%choice%" == "10" (
echo Displaying command line help...
python .venv\Scripts\invokeai.exe --help %*
pause
exit /b
) ELSE IF /I "%restore%" == "q" (
) ELSE IF /I "%choice%" == "q" (
echo Goodbye!
goto ending
) ELSE (

View File

@ -1,5 +1,10 @@
#!/bin/bash
# MIT License
# Coauthored by Lincoln Stein, Eugene Brodsky and Joshua Kimsey
# Copyright 2023, The InvokeAI Development Team
####
# This launch script assumes that:
# 1. it is located in the runtime directory,
@ -11,85 +16,168 @@
set -eu
# ensure we're in the correct folder in case user's CWD is somewhere else
# Ensure we're in the correct folder in case user's CWD is somewhere else
scriptdir=$(dirname "$0")
cd "$scriptdir"
. .venv/bin/activate
export INVOKEAI_ROOT="$scriptdir"
PARAMS=$@
# set required env var for torch on mac MPS
# Check to see if dialog is installed (it seems to be fairly standard, but good to check regardless) and if the user has passed the --no-tui argument to disable the dialog TUI
tui=true
if command -v dialog &>/dev/null; then
# This must use $@ to properly loop through the arguments passed by the user
for arg in "$@"; do
if [ "$arg" == "--no-tui" ]; then
tui=false
# Remove the --no-tui argument to avoid errors later on when passing arguments to InvokeAI
PARAMS=$(echo "$PARAMS" | sed 's/--no-tui//')
break
fi
done
else
tui=false
fi
# Set required env var for torch on mac MPS
if [ "$(uname -s)" == "Darwin" ]; then
export PYTORCH_ENABLE_MPS_FALLBACK=1
fi
if [ "$0" != "bash" ]; then
while true
do
echo "Do you want to generate images using the"
echo "1. command-line interface"
echo "2. browser-based UI"
echo "3. run textual inversion training"
echo "4. merge models (diffusers type only)"
echo "5. download and install models"
echo "6. change InvokeAI startup options"
echo "7. re-run the configure script to fix a broken install"
echo "8. open the developer console"
echo "9. update InvokeAI"
echo "10. command-line help"
echo "Q - Quit"
echo ""
read -p "Please enter 1-10, Q: [2] " yn
choice=${yn:='2'}
case $choice in
1)
echo "Starting the InvokeAI command-line..."
invokeai $@
;;
2)
echo "Starting the InvokeAI browser-based UI..."
invokeai --web $@
;;
3)
echo "Starting Textual Inversion:"
invokeai-ti --gui $@
;;
4)
echo "Merging Models:"
invokeai-merge --gui $@
;;
5)
invokeai-model-install --root ${INVOKEAI_ROOT}
;;
6)
invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models
;;
7)
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
;;
8)
echo "Developer Console:"
file_name=$(basename "${BASH_SOURCE[0]}")
bash --init-file "$file_name"
;;
9)
echo "Update:"
invokeai-update
;;
10)
invokeai --help
;;
[qQ])
exit 0
;;
*)
echo "Invalid selection"
exit;;
# Primary function for the case statement to determine user input
do_choice() {
case $1 in
1)
clear
printf "Generate images with a browser-based interface\n"
invokeai-web $PARAMS
;;
2)
clear
printf "Explore InvokeAI nodes using a command-line interface\n"
invokeai $PARAMS
;;
3)
clear
printf "Textual inversion training\n"
invokeai-ti --gui $PARAMS
;;
4)
clear
printf "Merge models (diffusers type only)\n"
invokeai-merge --gui $PARAMS
;;
5)
clear
printf "Download and install models\n"
invokeai-model-install --root ${INVOKEAI_ROOT}
;;
6)
clear
printf "Change InvokeAI startup options\n"
invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models
;;
7)
clear
printf "Re-run the configure script to fix a broken install\n"
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
;;
8)
clear
printf "Open the developer console\n"
file_name=$(basename "${BASH_SOURCE[0]}")
bash --init-file "$file_name"
;;
9)
clear
printf "Update InvokeAI\n"
invokeai-update
;;
10)
clear
printf "Command-line help\n"
invokeai --help
;;
"HELP 1")
clear
printf "Command-line help\n"
invokeai --help
;;
*)
clear
printf "Exiting...\n"
exit
;;
esac
done
clear
}
# Dialog-based TUI for launcing Invoke functions
do_dialog() {
options=(
1 "Generate images with a browser-based interface"
2 "Generate images using a command-line interface"
3 "Textual inversion training"
4 "Merge models (diffusers type only)"
5 "Download and install models"
6 "Change InvokeAI startup options"
7 "Re-run the configure script to fix a broken install"
8 "Open the developer console"
9 "Update InvokeAI")
choice=$(dialog --clear \
--backtitle "\Zb\Zu\Z3InvokeAI" \
--colors \
--title "What would you like to do?" \
--ok-label "Run" \
--cancel-label "Exit" \
--help-button \
--help-label "CLI Help" \
--menu "Select an option:" \
0 0 0 \
"${options[@]}" \
2>&1 >/dev/tty) || clear
do_choice "$choice"
clear
}
# Command-line interface for launching Invoke functions
do_line_input() {
clear
printf " ** For a more attractive experience, please install the 'dialog' utility using your package manager. **\n\n"
printf "What would you like to do?\n"
printf "1: Generate images using the browser-based interface\n"
printf "2: Explore InvokeAI nodes using the command-line interface\n"
printf "3: Run textual inversion training\n"
printf "4: Merge models (diffusers type only)\n"
printf "5: Download and install models\n"
printf "6: Change InvokeAI startup options\n"
printf "7: Re-run the configure script to fix a broken install\n"
printf "8: Open the developer console\n"
printf "9: Update InvokeAI\n"
printf "10: Command-line help\n"
printf "Q: Quit\n\n"
read -p "Please enter 1-10, Q: [1] " yn
choice=${yn:='1'}
do_choice $choice
clear
}
# Main IF statement for launching Invoke with either the TUI or CLI, and for checking if the user is in the developer console
if [ "$0" != "bash" ]; then
while true; do
if $tui; then
# .dialogrc must be located in the same directory as the invoke.sh script
export DIALOGRC="./.dialogrc"
do_dialog
else
do_line_input
fi
done
else # in developer console
python --version
echo "Press ^D to exit"
printf "Press ^D to exit\n"
export PS1="(InvokeAI) \u@\h \w> "
fi

View File

@ -1,22 +1,34 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from logging import Logger
import os
import invokeai.backend.util.logging as logger
from typing import types
from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
)
from invokeai.app.services.board_images import (
BoardImagesService,
BoardImagesServiceDependencies,
)
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_storage import DiskImageStorage
from ..services.image_file_storage import DiskImageFileStorage
from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.metadata import PngMetadataService
from ..services.model_manager_service import ModelManagerService
from .events import FastAPIEventService
@ -36,44 +48,89 @@ def check_internet() -> bool:
return False
logger = InvokeAILogger.getLogger()
class ApiDependencies:
"""Contains and initializes all dependencies for the API"""
invoker: Invoker = None
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
@staticmethod
def initialize(config, event_handler_id: int, logger: Logger = logger):
logger.info(f"Internet connectivity is {config.internet_available}")
events = FastAPIEventService(event_handler_id)
output_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../../outputs")
)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
metadata = PngMetadataService()
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
output_folder = config.output_path
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
db_location = config.db_path
db_location.parent.mkdir(parents=True, exist_ok=True)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
urls = LocalUrlService()
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents")
)
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
boards = BoardService(
services=BoardServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
board_images = BoardImagesService(
services=BoardImagesServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
images = ImageService(
services=ImageServiceDependencies(
board_image_record_storage=board_image_record_storage,
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
)
services = InvocationServices(
model_manager=get_model_manager(config,logger),
model_manager=ModelManagerService(config,logger),
events=events,
latents=latents,
images=images,
metadata=metadata,
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger),
restoration=RestorationServices(config, logger),
configuration=config,
logger=logger,
)

View File

@ -1,40 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import InvokeAIMetadata
class ImageResponseMetadata(BaseModel):
"""An image's metadata. Used only in HTTP responses."""
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
invokeai: Optional[InvokeAIMetadata] = Field(
description="The image's InvokeAI-specific metadata"
)
class ImageResponse(BaseModel):
"""The response type for images"""
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
image_url: str = Field(description="The url of the image")
thumbnail_url: str = Field(description="The url of the image's thumbnail")
metadata: ImageResponseMetadata = Field(description="The image's metadata")
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
class SavedImage(BaseModel):
image_name: str = Field(description="The name of the saved image")
thumbnail_name: str = Field(description="The name of the saved thumbnail")
created: int = Field(description="The created timestamp of the saved image")

View File

@ -0,0 +1,69 @@
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.models.image_record import ImageDTO
from ..dependencies import ApiDependencies
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
@board_images_router.post(
"/",
operation_id="create_board_image",
responses={
201: {"description": "The image was added to a board successfully"},
},
status_code=201,
)
async def create_board_image(
board_id: str = Body(description="The id of the board to add to"),
image_name: str = Body(description="The name of the image to add"),
):
"""Creates a board_image"""
try:
result = ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to add to board")
@board_images_router.delete(
"/",
operation_id="remove_board_image",
responses={
201: {"description": "The image was removed from the board successfully"},
},
status_code=201,
)
async def remove_board_image(
board_id: str = Body(description="The id of the board"),
image_name: str = Body(description="The name of the image to remove"),
):
"""Deletes a board_image"""
try:
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(board_id=board_id, image_name=image_name)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to update board")
@board_images_router.get(
"/{board_id}",
operation_id="list_board_images",
response_model=OffsetPaginatedResults[ImageDTO],
)
async def list_board_images(
board_id: str = Path(description="The id of the board"),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of boards per page"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of images for a board"""
results = ApiDependencies.invoker.services.board_images.get_images_for_board(
board_id,
)
return results

View File

@ -0,0 +1,108 @@
from typing import Optional, Union
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.services.board_record_storage import BoardChanges
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import BoardDTO
from ..dependencies import ApiDependencies
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
@boards_router.post(
"/",
operation_id="create_board",
responses={
201: {"description": "The board was created successfully"},
},
status_code=201,
response_model=BoardDTO,
)
async def create_board(
board_name: str = Query(description="The name of the board to create"),
) -> BoardDTO:
"""Creates a board"""
try:
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to create board")
@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO)
async def get_board(
board_id: str = Path(description="The id of board to get"),
) -> BoardDTO:
"""Gets a board"""
try:
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
return result
except Exception as e:
raise HTTPException(status_code=404, detail="Board not found")
@boards_router.patch(
"/{board_id}",
operation_id="update_board",
responses={
201: {
"description": "The board was updated successfully",
},
},
status_code=201,
response_model=BoardDTO,
)
async def update_board(
board_id: str = Path(description="The id of board to update"),
changes: BoardChanges = Body(description="The changes to apply to the board"),
) -> BoardDTO:
"""Updates a board"""
try:
result = ApiDependencies.invoker.services.boards.update(
board_id=board_id, changes=changes
)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to update board")
@boards_router.delete("/{board_id}", operation_id="delete_board")
async def delete_board(
board_id: str = Path(description="The id of board to delete"),
) -> None:
"""Deletes a board"""
try:
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
@boards_router.get(
"/",
operation_id="list_boards",
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
)
async def list_boards(
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
offset: Optional[int] = Query(default=None, description="The page offset"),
limit: Optional[int] = Query(
default=None, description="The number of boards per page"
),
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
"""Gets a list of boards"""
if all:
return ApiDependencies.invoker.services.boards.get_all()
elif offset is not None and limit is not None:
return ApiDependencies.invoker.services.boards.get_many(
offset,
limit,
)
else:
raise HTTPException(
status_code=400,
detail="Invalid request: Must provide either 'all' or both 'offset' and 'limit'",
)

View File

@ -1,148 +1,241 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import json
import os
from typing import Any
import uuid
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from typing import Optional
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.routing import APIRouter
from fastapi.responses import FileResponse
from PIL import Image
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
from invokeai.app.models.image import (
ImageCategory,
ResourceOrigin,
)
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import (
ImageDTO,
ImageRecordChanges,
ImageUrlsDTO,
)
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
) -> FileResponse:
"""Gets an image"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image(
image_type: ImageType = Path(description="The type of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image and its thumbnail"""
ApiDependencies.invoker.services.images.delete(
image_type=image_type, image_name=image_name
)
@images_router.get(
"/{thumbnail_type}/thumbnails/{thumbnail_name}", operation_id="get_thumbnail"
)
async def get_thumbnail(
thumbnail_type: ImageType = Path(description="The type of thumbnail to get"),
thumbnail_name: str = Path(description="The name of the thumbnail to get"),
) -> FileResponse | Response:
"""Gets a thumbnail"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=thumbnail_type, image_name=thumbnail_name, is_thumbnail=True
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.post(
"/uploads/",
"/",
operation_id="upload_image",
responses={
201: {
"description": "The image was uploaded successfully",
"model": ImageResponse,
},
201: {"description": "The image was uploaded successfully"},
415: {"description": "Image upload failed"},
},
status_code=201,
response_model=ImageDTO,
)
async def upload_image(
file: UploadFile, image_type: ImageType, request: Request, response: Response
) -> ImageResponse:
file: UploadFile,
request: Request,
response: Response,
image_category: ImageCategory = Query(description="The category of the image"),
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
session_id: Optional[str] = Query(
default=None, description="The session ID associated with this upload, if any"
),
) -> ImageDTO:
"""Uploads an image"""
if not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await file.read()
try:
img = Image.open(io.BytesIO(contents))
pil_image = Image.open(io.BytesIO(contents))
except:
# Error opening the image
raise HTTPException(status_code=415, detail="Failed to read image")
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
try:
image_dto = ApiDependencies.invoker.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.EXTERNAL,
image_category=image_category,
session_id=session_id,
is_intermediate=is_intermediate,
)
saved_image = ApiDependencies.invoker.services.images.save(
image_type, filename, img
)
response.status_code = 201
response.headers["Location"] = image_dto.image_url
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
return image_dto
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to create image")
image_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name, True
)
@images_router.delete("/{image_name}", operation_id="delete_image")
async def delete_image(
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image"""
res = ImageResponse(
image_type=image_type,
image_name=saved_image.image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
metadata=ImageResponseMetadata(
created=saved_image.created,
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
try:
ApiDependencies.invoker.services.images.delete(image_name)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
response.status_code = 201
response.headers["Location"] = image_url
return res
@images_router.patch(
"/{image_name}",
operation_id="update_image",
response_model=ImageDTO,
)
async def update_image(
image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image"
),
) -> ImageDTO:
"""Updates an image"""
try:
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
except Exception as e:
raise HTTPException(status_code=400, detail="Failed to update image")
@images_router.get(
"/{image_name}/metadata",
operation_id="get_image_metadata",
response_model=ImageDTO,
)
async def get_image_metadata(
image_name: str = Path(description="The name of image to get"),
) -> ImageDTO:
"""Gets an image's metadata"""
try:
return ApiDependencies.invoker.services.images.get_dto(image_name)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_name}",
operation_id="get_image_full",
response_class=Response,
responses={
200: {
"description": "Return the full-resolution image",
"content": {"image/png": {}},
},
404: {"description": "Image not found"},
},
)
async def get_image_full(
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> FileResponse:
"""Gets a full-resolution image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(image_name)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
return FileResponse(
path,
media_type="image/png",
filename=image_name,
content_disposition_type="inline",
)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_name}/thumbnail",
operation_id="get_image_thumbnail",
response_class=Response,
responses={
200: {
"description": "Return the image thumbnail",
"content": {"image/webp": {}},
},
404: {"description": "Image not found"},
},
)
async def get_image_thumbnail(
image_name: str = Path(description="The name of thumbnail image file to get"),
) -> FileResponse:
"""Gets a thumbnail image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(
image_name, thumbnail=True
)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
return FileResponse(
path, media_type="image/webp", content_disposition_type="inline"
)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_name}/urls",
operation_id="get_image_urls",
response_model=ImageUrlsDTO,
)
async def get_image_urls(
image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL"""
try:
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_name, thumbnail=True
)
return ImageUrlsDTO(
image_name=image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageResponse]}},
operation_id="list_images_with_metadata",
response_model=OffsetPaginatedResults[ImageDTO],
)
async def list_images(
image_type: ImageType = Query(
default=ImageType.RESULT, description="The type of images to get"
async def list_images_with_metadata(
image_origin: Optional[ResourceOrigin] = Query(
default=None, description="The origin of images to list"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
categories: Optional[list[ImageCategory]] = Query(
default=None, description="The categories of image to include"
),
is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images"
),
board_id: Optional[str] = Query(
default=None, description="The board id to filter by"
),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
return result
image_dtos = ApiDependencies.invoker.services.images.get_many(
offset,
limit,
image_origin,
categories,
is_intermediate,
board_id,
)
return image_dtos

View File

@ -1,13 +1,14 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and Kent Keirsey (https://github.com/hipsterusername)
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
import shutil
import os
from typing import Annotated, Any, List, Literal, Optional, Union
from typing import Annotated, Literal, Optional, Union, Dict
from fastapi import Query
from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from pathlib import Path
from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -19,6 +20,15 @@ class VaeRepo(BaseModel):
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
model_name: str = Field(description="The name of the model")
model_type: str = Field(description="The type of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['folder'] = 'folder'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
@ -29,12 +39,8 @@ class CkptModelInfo(ModelInfo):
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['diffusers'] = 'diffusers'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class SafetensorsModelInfo(CkptModelInfo):
format: Literal['safetensors'] = 'safetensors'
class CreateModelRequest(BaseModel):
name: str = Field(description="The name of the model")
@ -47,14 +53,16 @@ class CreateModelResponse(BaseModel):
class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info")
save_location: str = Field(description="The path to save the converted model weights")
class ConvertedModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
models: list[MODEL_CONFIGS]
@models_router.get(
@ -62,9 +70,16 @@ class ModelsList(BaseModel):
operation_id="list_models",
responses={200: {"model": ModelsList }},
)
async def list_models() -> ModelsList:
async def list_models(
base_model: Optional[BaseModelType] = Query(
default=None, description="Base model"
),
model_type: Optional[ModelType] = Query(
default=None, description="The type of model to get"
),
) -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
@ -119,99 +134,10 @@ async def delete_model(model_name: str) -> None:
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
else:
logger.error(f"Model not found")
logger.error("Model not found")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
# TODO: Refactor these support functions below to live somewhere more appropriate
def get_model_info(model_name: str):
model_info = ApiDependencies.invoker.services.model_manager.model_info(
model_name=model_name
)
if not model_info:
raise HTTPException(status_code=404, detail=f"Unable to retrieve model info for '{model_name}'")
return model_info
def ckpt_validate(model_info: dict, model_name: str):
if "weights" not in model_info:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' is not a valid checkpoint model")
def get_paths(model: ConversionRequest, root: Path) -> tuple:
model_info = get_model_info(model.name)
ckpt_path = Path(model_info.weights)
config_path = Path(model_info.config)
if not ckpt_path.is_absolute():
ckpt_path = Path(root, ckpt_path)
if config_path and not config_path.is_absolute():
config_path = Path(root, config_path)
return ckpt_path, config_path
def get_diffusers_path(convert_request: ConversionRequest, model_name: str) -> Path:
if convert_request.save_location == "root":
diffusers_path = Path(global_converted_ckpts_dir(), f"{model_name}_diffusers")
elif convert_request.save_location == "custom" and convert_request.save_location is not None:
diffusers_path = Path(convert_request.save_location, f"{model_name}_diffusers")
else:
raise ValueError("Invalid save_location value")
if diffusers_path.exists():
shutil.rmtree(diffusers_path)
return diffusers_path
@models_router.post(
"/{model_to_convert}",
operation_id="convert_model",
responses={
200: {
"model_response": "Model converted successfully.",
}
},
)
async def convert_model(convert_request: ConversionRequest) -> ConvertedModelResponse:
"""Convert Model"""
opt=Args()
args = opt.parse_args()
# Set the root directory for static files and relative paths
args.root_dir = os.path.expanduser(args.root_dir or "..")
if not os.path.isabs(args.outdir):
args.outdir = os.path.join(args.root_dir, args.outdir)
# normalize the config directory relative to root
if not os.path.isabs(opt.conf):
opt.conf = os.path.normpath(os.path.join(Globals.root, opt.conf))
model_info = get_model_info(convert_request.name)
ckpt_validate(model_info, convert_request.name)
ckpt_path, original_config_file = get_paths(convert_request, Globals.root)
diffusers_path = get_diffusers_path(convert_request, convert_request.name)
ApiDependencies.invoker.services.model_manager.convert_and_import(
ckpt_path,
diffusers_path,
model_name=convert_request.name,
model_description=model_info.description,
vae=None,
original_config_file=original_config_file,
commit_to_conf=opt.conf,
)
model_info = get_model_info(convert_request.name)
convert_response = ConvertedModelResponse(name=f"{convert_request.name}_diffusers", info=model_info)
print(f">> Model Converted: {convert_request.name}")
return convert_response
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
# try:

View File

@ -3,7 +3,7 @@ import asyncio
from inspect import signature
import uvicorn
import invokeai.backend.util.logging as logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
@ -11,13 +11,22 @@ from fastapi.openapi.utils import get_openapi
from fastapi.staticfiles import StaticFiles
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pathlib import Path
from pydantic.schema import schema
#This should come early so that modules can log their initialization properly
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config)
import invokeai.frontend.web as web_dir
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions, models
from .api.routers import sessions, models, images, boards, board_images
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
from .services.config import InvokeAIAppConfig
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
@ -35,10 +44,6 @@ app.add_middleware(
socket_io = SocketIO(app)
# initialize config
# this is a module global
app_config = InvokeAIAppConfig()
# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event():
@ -69,10 +74,13 @@ async def shutdown_event():
app.include_router(sessions.session_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
@ -112,6 +120,22 @@ def custom_openapi():
invoker_schema["output"] = outputs_ref
from invokeai.backend.model_management.models import get_model_config_enums
for model_config_format_enum in set(get_model_config_enums()):
name = model_config_format_enum.__qualname__
if name in openapi_schema["components"]["schemas"]:
# print(f"Config with name {name} already defined")
continue
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
openapi_schema["components"]["schemas"][name] = dict(
title=name,
description="An enumeration.",
type="string",
enum=list(v.value for v in model_config_format_enum),
)
app.openapi_schema = openapi_schema
return app.openapi_schema
@ -119,7 +143,7 @@ def custom_openapi():
app.openapi = custom_openapi
# Override API doc favicons
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static")
@app.get("/docs", include_in_schema=False)
def overridden_swagger():
@ -138,8 +162,13 @@ def overridden_redoc():
redoc_favicon_url="/static/favicon.ico",
)
# Must mount *after* the other routes else it borks em
app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui")
app.mount("/",
StaticFiles(directory=Path(web_dir.__path__[0],"dist"),
html=True
), name="ui"
)
def invoke_api():
# Start our own event loop for eventing usage

View File

@ -6,35 +6,44 @@ import re
import shlex
import sys
import time
from typing import (
Union,
get_type_hints,
)
from typing import Union, get_type_hints
from pydantic import BaseModel, ValidationError
from pydantic.fields import Field
# This should come early so that the logger can pick up its configuration options
from .services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().getLogger(config=config)
import invokeai.backend.util.logging as logger
from invokeai.app.services.metadata import PngMetadataService
from .services.default_graphs import create_system_graphs
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from .services.default_graphs import (default_text_to_image_graph_id,
create_system_graphs)
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, SortedHelpFormatter
from .cli.commands import (BaseCommand, CliContext, ExitCli,
SortedHelpFormatter, add_graph_parsers, add_parsers)
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.default_graphs import default_text_to_image_graph_id
from .services.image_storage import DiskImageStorage
from .services.graph import (Edge, EdgeConnection, GraphExecutionState,
GraphInvocation, LibraryGraph,
are_connection_types_compatible)
from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor
from .services.restoration_services import RestorationServices
from .services.sqlite import SqliteItemStorage
from .services.config import get_invokeai_config
class CliCommand(BaseModel):
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
@ -43,7 +52,6 @@ class CliCommand(BaseModel):
class InvalidArgs(Exception):
pass
def add_invocation_args(command_parser):
# Add linking capability
command_parser.add_argument(
@ -187,11 +195,7 @@ def invoke_all(context: CliContext):
raise SessionError()
def invoke_cli():
# this gets the basic configuration
config = get_invokeai_config()
# get the optional list of invocations to execute on the command line
parser = config.get_parser()
parser.add_argument('commands',nargs='*')
@ -202,36 +206,60 @@ def invoke_cli():
if infile := config.from_file:
sys.stdin = open(infile,"r")
model_manager = get_model_manager(config,logger=logger)
model_manager = ModelManagerService(config,logger)
events = EventServiceBase()
output_folder = config.output_path
metadata = PngMetadataService()
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
if config.use_memory_db:
db_location = ":memory:"
else:
db_location = config.db_path
db_location.parent.mkdir(parents=True,exist_ok=True)
logger.info(f'InvokeAI database location is "{db_location}"')
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
urls = LocalUrlService()
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
images = ImageService(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
services = InvocationServices(
model_manager=model_manager,
events=events,
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
metadata=metadata,
images=images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger=logger),
logger=logger,
configuration=config,
)
system_graphs = create_system_graphs(services.graph_library)
system_graph_names = set([g.name for g in system_graphs])
set_autocompleter(services)
invoker = Invoker(services)
session: GraphExecutionState = invoker.create_execution_state()

View File

@ -1,12 +1,15 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
from pydantic import BaseModel, Field
from ..services.invocation_services import InvocationServices
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
class InvocationContext:
@ -75,6 +78,7 @@ class BaseInvocation(ABC, BaseModel):
#fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
#fmt: on
@ -92,6 +96,7 @@ class UIConfig(TypedDict, total=False):
"image",
"latents",
"model",
"control",
],
]
tags: List[str]

View File

@ -1,9 +1,9 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import Literal, Optional
from typing import Literal
import numpy as np
from pydantic import Field
from pydantic import Field, validator
from invokeai.app.util.misc import SEED_MAX, get_random_seed
@ -22,9 +22,17 @@ class IntCollectionOutput(BaseInvocationOutput):
# Outputs
collection: list[int] = Field(default=[], description="The int collection")
class FloatCollectionOutput(BaseInvocationOutput):
"""A collection of floats"""
type: Literal["float_collection"] = "float_collection"
# Outputs
collection: list[float] = Field(default=[], description="The float collection")
class RangeInvocation(BaseInvocation):
"""Creates a range"""
"""Creates a range of numbers from start to stop with step"""
type: Literal["range"] = "range"
@ -33,12 +41,34 @@ class RangeInvocation(BaseInvocation):
stop: int = Field(default=10, description="The stop of the range")
step: int = Field(default=1, description="The step of the range")
@validator("stop")
def stop_gt_start(cls, v, values):
if "start" in values and v <= values["start"]:
raise ValueError("stop must be greater than start")
return v
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(
collection=list(range(self.start, self.stop, self.step))
)
class RangeOfSizeInvocation(BaseInvocation):
"""Creates a range from start to start + size with step"""
type: Literal["range_of_size"] = "range_of_size"
# Inputs
start: int = Field(default=0, description="The start of the range")
size: int = Field(default=1, description="The number of values")
step: int = Field(default=1, description="The step of the range")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(
collection=list(range(self.start, self.start + self.size, self.step))
)
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""

View File

@ -1,19 +1,22 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from contextlib import ExitStack
import re
from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .model import ClipField
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_management.lora import ModelPatcher
from compel import Compel
from compel.prompt_parser import (
Blend,
CrossAttentionControlSubstitute,
FlattenedPrompt,
Fragment,
Fragment, Conjunction,
)
@ -39,7 +42,7 @@ class CompelInvocation(BaseInvocation):
type: Literal["compel"] = "compel"
prompt: str = Field(default="", description="Prompt")
model: str = Field(default="", description="Model to use")
clip: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
@ -55,87 +58,93 @@ class CompelInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> CompelOutput:
# TODO: load without model
model = choose_model(context.services.model_manager, self.model)
pipeline = model["model"]
tokenizer = pipeline.tokenizer
text_encoder = pipeline.text_encoder
# TODO: global? input?
#use_full_precision = precision == "float32" or precision == "autocast"
#use_full_precision = False
# TODO: redo TI when separate model loding implemented
#textual_inversion_manager = TextualInversionManager(
# tokenizer=tokenizer,
# text_encoder=text_encoder,
# full_precision=use_full_precision,
#)
def load_huggingface_concepts(concepts: list[str]):
pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)
# apply the concepts library to the prompt
prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
self.prompt,
lambda concepts: load_huggingface_concepts(concepts),
pipeline.textual_inversion_manager.get_all_trigger_strings(),
tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(),
)
# lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used.
pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
prompt_str
text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(),
)
with tokenizer_info as orig_tokenizer,\
text_encoder_info as text_encoder,\
ExitStack() as stack:
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=pipeline.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
)
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
# TODO: support legacy blend?
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
stack.enter_context(
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
)
)
)
except Exception:
#print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
conjunction = Compel.parse_prompt_string(prompt_str)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
)
conjunction = Compel.parse_prompt_string(self.prompt)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
# TODO: long prompt support
#if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
# TODO: long prompt support
#if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
cross_attention_control_args=options.get("cross_attention_control", None),
)
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (c, ec))
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.set(conditioning_name, (c, ec))
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
) -> int:
if type(prompt) is Blend:
blend: Blend = prompt
return max(
[
get_max_token_count(tokenizer, c, truncate_if_too_long)
for c in blend.prompts
get_max_token_count(tokenizer, p, truncate_if_too_long)
for p in blend.prompts
]
)
elif type(prompt) is Conjunction:
conjunction: Conjunction = prompt
return sum(
[
get_max_token_count(tokenizer, p, truncate_if_too_long)
for p in conjunction.prompts
]
)
else:
@ -170,6 +179,22 @@ def get_tokens_for_prompt_object(
return tokens
def log_tokenization_for_conjunction(
c: Conjunction, tokenizer, display_label_prefix=None
):
display_label_prefix = display_label_prefix or ""
for i, p in enumerate(c.prompts):
if len(c.prompts)>1:
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
else:
this_display_label_prefix = display_label_prefix
log_tokenization_for_prompt_object(
p,
tokenizer,
display_label_prefix=this_display_label_prefix
)
def log_tokenization_for_prompt_object(
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):

View File

@ -0,0 +1,454 @@
# InvokeAI nodes for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float
import numpy as np
from typing import Literal, Optional, Union, List
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field, validator
from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from controlnet_aux import (
CannyDetector,
HEDdetector,
LineartDetector,
LineartAnimeDetector,
MidasDetector,
MLSDdetector,
NormalBaeDetector,
OpenposeDetector,
PidiNetDetector,
ContentShuffleDetector,
ZoeDetector,
MediapipeFaceDetector,
)
from .image import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [
###########################################
# lllyasviel sd v1.5, ControlNet v1.0 models
##############################################
"lllyasviel/sd-controlnet-canny",
"lllyasviel/sd-controlnet-depth",
"lllyasviel/sd-controlnet-hed",
"lllyasviel/sd-controlnet-seg",
"lllyasviel/sd-controlnet-openpose",
"lllyasviel/sd-controlnet-scribble",
"lllyasviel/sd-controlnet-normal",
"lllyasviel/sd-controlnet-mlsd",
#############################################
# lllyasviel sd v1.5, ControlNet v1.1 models
#############################################
"lllyasviel/control_v11p_sd15_canny",
"lllyasviel/control_v11p_sd15_openpose",
"lllyasviel/control_v11p_sd15_seg",
# "lllyasviel/control_v11p_sd15_depth", # broken
"lllyasviel/control_v11f1p_sd15_depth",
"lllyasviel/control_v11p_sd15_normalbae",
"lllyasviel/control_v11p_sd15_scribble",
"lllyasviel/control_v11p_sd15_mlsd",
"lllyasviel/control_v11p_sd15_softedge",
"lllyasviel/control_v11p_sd15s2_lineart_anime",
"lllyasviel/control_v11p_sd15_lineart",
"lllyasviel/control_v11p_sd15_inpaint",
# "lllyasviel/control_v11u_sd15_tile",
# problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile",
# so for now replace "lllyasviel/control_v11f1e_sd15_tile",
"lllyasviel/control_v11e_sd15_shuffle",
"lllyasviel/control_v11e_sd15_ip2p",
"lllyasviel/control_v11f1e_sd15_tile",
#################################################
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
##################################################
"thibaud/controlnet-sd21-openpose-diffusers",
"thibaud/controlnet-sd21-canny-diffusers",
"thibaud/controlnet-sd21-depth-diffusers",
"thibaud/controlnet-sd21-scribble-diffusers",
"thibaud/controlnet-sd21-hed-diffusers",
"thibaud/controlnet-sd21-zoedepth-diffusers",
"thibaud/controlnet-sd21-color-diffusers",
"thibaud/controlnet-sd21-openposev2-diffusers",
"thibaud/controlnet-sd21-lineart-diffusers",
"thibaud/controlnet-sd21-normalbae-diffusers",
"thibaud/controlnet-sd21-ade20k-diffusers",
##############################################
# ControlNetMediaPipeface, ControlNet v1.1
##############################################
# ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
# diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
# hacked t2l to split to model & subfolder if format is "model,subfolder"
"CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
"CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
]
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
@validator("control_weight")
def abs_le_one(cls, v):
"""validate that all abs(values) are <=1"""
if isinstance(v, list):
for i in v:
if abs(i) > 1:
raise ValueError('all abs(control_weight) must be <= 1')
else:
if abs(v) > 1:
raise ValueError('abs(control_weight) must be <= 1')
return v
class Config:
schema_extra = {
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
"ui": {
"type_hints": {
"control_weight": "float",
# "control_weight": "number",
}
}
}
class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info"""
# fmt: off
type: Literal["control_output"] = "control_output"
control: ControlField = Field(default=None, description="The control info")
# fmt: on
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
# fmt: off
type: Literal["controlnet"] = "controlnet"
# Inputs
image: ImageField = Field(default=None, description="The control image")
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number",
"control_weight": "float",
}
},
}
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(
image=self.image,
control_model=self.control_model,
control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
),
)
# TODO: move image processors to separate file (image_analysis.py
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
"""Base class for invocations that preprocess images for ControlNet"""
# fmt: off
type: Literal["image_processor"] = "image_processor"
# Inputs
image: ImageField = Field(default=None, description="The image to process")
# fmt: on
def run_processor(self, image):
# superclass just passes through image without processing
return image
def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image(self.image.image_name)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
# FIXME: what happened to image metadata?
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
# currently can't see processed image in node UI without a showImage node,
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
image_dto = context.services.images.create(
image=processed_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.CONTROL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate
)
"""Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField(image_name=image_dto.image_name)
return ImageOutput(
image=processed_image_field,
# width=processed_image.width,
width = image_dto.width,
# height=processed_image.height,
height = image_dto.height,
# mode=processed_image.mode,
)
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Canny edge detection for ControlNet"""
# fmt: off
type: Literal["canny_image_processor"] = "canny_image_processor"
# Input
low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
# fmt: on
def run_processor(self, image):
canny_processor = CannyDetector()
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
return processed_image
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies HED edge detection to image"""
# fmt: off
type: Literal["hed_image_processor"] = "hed_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# safe not supported in controlnet_aux v0.0.3
# safe: bool = Field(default=False, description="whether to use safe mode")
scribble: bool = Field(default=False, description="Whether to use scribble mode")
# fmt: on
def run_processor(self, image):
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = hed_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
# safe not supported in controlnet_aux v0.0.3
# safe=self.safe,
scribble=self.scribble,
)
return processed_image
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies line art processing to image"""
# fmt: off
type: Literal["lineart_image_processor"] = "lineart_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
coarse: bool = Field(default=False, description="Whether to use coarse mode")
# fmt: on
def run_processor(self, image):
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
processed_image = lineart_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
coarse=self.coarse)
return processed_image
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies line art anime processing to image"""
# fmt: off
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
def run_processor(self, image):
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
return processed_image
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies Openpose processing to image"""
# fmt: off
type: Literal["openpose_image_processor"] = "openpose_image_processor"
# Inputs
hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
def run_processor(self, image):
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = openpose_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
hand_and_face=self.hand_and_face,
)
return processed_image
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies Midas depth processing to image"""
# fmt: off
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
# Inputs
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
# fmt: on
def run_processor(self, image):
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
processed_image = midas_processor(image,
a=np.pi * self.a_mult,
bg_th=self.bg_th,
# dept_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal=self.depth_and_normal,
)
return processed_image
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies NormalBae processing to image"""
# fmt: off
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
def run_processor(self, image):
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = normalbae_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution)
return processed_image
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies MLSD processing to image"""
# fmt: off
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
# fmt: on
def run_processor(self, image):
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = mlsd_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
thr_v=self.thr_v,
thr_d=self.thr_d)
return processed_image
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies PIDI processing to image"""
# fmt: off
type: Literal["pidi_image_processor"] = "pidi_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
safe: bool = Field(default=False, description="Whether to use safe mode")
scribble: bool = Field(default=False, description="Whether to use scribble mode")
# fmt: on
def run_processor(self, image):
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
processed_image = pidi_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
safe=self.safe,
scribble=self.scribble)
return processed_image
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies content shuffle processing to image"""
# fmt: off
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
h: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
w: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
f: Union[int, None] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
# fmt: on
def run_processor(self, image):
content_shuffle_processor = ContentShuffleDetector()
processed_image = content_shuffle_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
h=self.h,
w=self.w,
f=self.f
)
return processed_image
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies Zoe depth processing to image"""
# fmt: off
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
# fmt: on
def run_processor(self, image):
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
return processed_image
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies mediapipe face processing to image"""
# fmt: off
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
# Inputs
max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
# fmt: on
def run_processor(self, image):
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
return processed_image

View File

@ -7,9 +7,9 @@ import numpy
from PIL import Image, ImageOps
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .image import ImageOutput
class CvInvocationConfig(BaseModel):
@ -26,24 +26,23 @@ class CvInvocationConfig(BaseModel):
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
"""Simple inpaint using opencv."""
#fmt: off
# fmt: off
type: Literal["cv_inpaint"] = "cv_inpaint"
# Inputs
image: ImageField = Field(default=None, description="The image to inpaint")
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
mask = context.services.images.get(self.mask.image_type, self.mask.image_name)
image = context.services.images.get_pil_image(self.image.image_name)
mask = context.services.images.get_pil_image(self.mask.image_name)
# Convert to cv image/mask
# TODO: consider making these utility functions
cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
cv_mask = numpy.array(ImageOps.invert(mask))
cv_mask = numpy.array(ImageOps.invert(mask.convert("L")))
# Inpaint
cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)
@ -52,18 +51,17 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
# TODO: consider making a utility function
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=image_inpainted,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
context.services.images.save(image_type, image_name, image_inpainted, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_inpainted,
)

View File

@ -3,120 +3,77 @@
from functools import partial
from typing import Literal, Optional, Union, get_args
import numpy as np
from torch import Tensor
import torch
from diffusers import ControlNetModel
from pydantic import BaseModel, Field
from invokeai.app.models.image import ColorField, ImageField, ImageType
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
ResourceOrigin)
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.generator import Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.step_callback import stable_diffusion_step_callback
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
from .image import ImageOutput
import re
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from .model import UNetField, VaeField
from .compel import ConditioningField
from contextlib import contextmanager, ExitStack, ContextDecorator
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'
class SDImageInvocation(BaseModel):
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["stable-diffusion", "image"],
"type_hints": {
"model": "model",
},
},
}
DEFAULT_INFILL_METHOD = (
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
)
# Text to image
class TextToImageInvocation(BaseInvocation, SDImageInvocation):
"""Generates an image using text2img."""
from .latent import get_scheduler
type: Literal["txt2img"] = "txt2img"
class OldModelContext(ContextDecorator):
model: StableDiffusionGeneratorPipeline
# Inputs
# TODO: consider making prompt optional to enable providing prompt through a link
# fmt: off
prompt: Optional[str] = Field(description="The prompt to generate an image from")
def __init__(self, model):
self.model = model
def __enter__(self):
return self.model
def __exit__(self, *exc):
return False
class OldModelInfo:
name: str
hash: str
context: OldModelContext
def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
self.name = name
self.hash = hash
self.context = OldModelContext(
model=model,
)
class InpaintInvocation(BaseInvocation):
"""Generates an image using inpaint."""
type: Literal["inpaint"] = "inpaint"
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
# fmt: on
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Txt2Img(model).generate(
prompt=self.prompt,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generate_output = next(outputs)
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(
image_type, image_name, generate_output.image, metadata
)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=generate_output.image,
)
class ImageToImageInvocation(TextToImageInvocation):
"""Generates an image using img2img."""
type: Literal["img2img"] = "img2img"
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
unet: UNetField = Field(default=None, description="UNet model")
vae: VaeField = Field(default=None, description="Vae model")
# Inputs
image: Union[ImageField, None] = Field(description="The input image")
@ -128,92 +85,41 @@ class ImageToImageInvocation(TextToImageInvocation):
description="Whether or not the result should be fit to the aspect ratio of the input image",
)
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
None
if self.image is None
else context.services.images.get(
self.image.image_type, self.image.image_name
)
)
if self.fit:
image = image.resize((self.width, self.height))
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Img2Img(model).generate(
prompt=self.prompt,
init_image=image,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generator_output = next(outputs)
result_image = generator_output.image
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, result_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image,
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
type: Literal["inpaint"] = "inpaint"
# Inputs
mask: Union[ImageField, None] = Field(description="The mask")
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
seam_blur: int = Field(
default=16, ge=0, description="The seam inpaint blur radius (px)"
)
seam_strength: float = Field(
default=0.75, gt=0, le=1, description="The seam inpaint strength"
)
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
seam_steps: int = Field(
default=30, ge=1, description="The number of steps to use for seam inpaint"
)
tile_size: int = Field(
default=32, ge=1, description="The tile infill method size (px)"
)
infill_method: INFILL_METHODS = Field(
default=DEFAULT_INFILL_METHOD,
description="The method used to infill empty regions (px)",
)
inpaint_width: Optional[int] = Field(
default=None,
multiple_of=8,
gt=0,
description="The width of the inpaint region (px)",
)
inpaint_height: Optional[int] = Field(
default=None,
multiple_of=8,
gt=0,
description="The height of the inpaint region (px)",
)
inpaint_fill: Optional[ColorField] = Field(
default=ColorField(r=127, g=127, b=127, a=255),
description="The solid infill method color",
)
inpaint_replace: float = Field(
default=0.0,
ge=0.0,
@ -221,6 +127,14 @@ class InpaintInvocation(ImageToImageInvocation):
description="The amount by which to replace masked areas with latent noise",
)
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["stable-diffusion", "image"],
},
}
def dispatch_progress(
self,
context: InvocationContext,
@ -234,60 +148,101 @@ class InpaintInvocation(ImageToImageInvocation):
source_node_id=source_node_id,
)
def get_conditioning(self, context):
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
return (uc, c, extra_conditioning_info)
@contextmanager
def load_model_old_way(self, context, scheduler):
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
#unet = unet_info.context.model
#vae = vae_info.context.model
with ExitStack() as stack:
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
with vae_info as vae,\
unet_info as unet,\
ModelPatcher.apply_lora_unet(unet, loras):
device = context.services.model_manager.mgr.cache.execution_device
dtype = context.services.model_manager.mgr.cache.precision
pipeline = StableDiffusionGeneratorPipeline(
vae=vae,
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
precision="float16" if dtype == torch.float16 else "float32",
execution_device=device,
)
yield OldModelInfo(
name=self.unet.unet.model_name,
hash="<NO-HASH>",
model=pipeline,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
None
if self.image is None
else context.services.images.get(
self.image.image_type, self.image.image_name
)
else context.services.images.get_pil_image(self.image.image_name)
)
mask = (
None
if self.mask is None
else context.services.images.get(self.mask.image_type, self.mask.image_name)
else context.services.images.get_pil_image(self.mask.image_name)
)
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Inpaint(model).generate(
prompt=self.prompt,
init_image=image,
mask_image=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
conditioning = self.get_conditioning(context)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
with self.load_model_old_way(context, scheduler) as model:
outputs = Inpaint(model).generate(
conditioning=conditioning,
scheduler=scheduler,
init_image=image,
mask_image=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generator_output = next(outputs)
result_image = generator_output.image
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=generator_output.image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, result_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image,
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -1,13 +1,13 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from typing import Literal, Optional
from typing import Literal, Optional, Union
import numpy
from PIL import Image, ImageFilter, ImageOps
from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -31,7 +31,7 @@ class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image"] = "image"
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
@ -41,27 +41,14 @@ class ImageOutput(BaseInvocationOutput):
schema_extra = {"required": ["type", "image", "width", "height"]}
def build_image_output(
image_type: ImageType, image_name: str, image: Image.Image
) -> ImageOutput:
"""Builds an ImageOutput and its ImageField"""
image_field = ImageField(
image_name=image_name,
image_type=image_type,
)
return ImageOutput(
image=image_field,
width=image.width,
height=image.height,
)
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
@ -80,16 +67,17 @@ class LoadImageInvocation(BaseInvocation):
type: Literal["load_image"] = "load_image"
# Inputs
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
image: Union[ImageField, None] = Field(
default=None, description="The image to load"
)
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image_type, self.image_name)
image = context.services.images.get_pil_image(self.image.image_name)
return build_image_output(
image_type=self.image_type,
image_name=self.image_name,
image=image,
return ImageOutput(
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
@ -99,32 +87,32 @@ class ShowImageInvocation(BaseInvocation):
type: Literal["show_image"] = "show_image"
# Inputs
image: ImageField = Field(default=None, description="The image to show")
image: Union[ImageField, None] = Field(
default=None, description="The image to show"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
if image:
image.show()
# TODO: how to handle failure?
return build_image_output(
image_type=self.image.image_type,
image_name=self.image.image_name,
image=image,
return ImageOutput(
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
"""Crops an image to a specified box. The box can be outside of the image."""
# fmt: off
type: Literal["crop"] = "crop"
type: Literal["img_crop"] = "img_crop"
# Inputs
image: ImageField = Field(default=None, description="The image to crop")
image: Union[ImageField, None] = Field(default=None, description="The image to crop")
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
@ -132,58 +120,51 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_crop = Image.new(
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
)
image_crop.paste(image, (-self.x, -self.y))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_crop, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image_dto = context.services.images.create(
image=image_crop,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
"""Pastes an image into another image."""
# fmt: off
type: Literal["paste"] = "paste"
type: Literal["img_paste"] = "img_paste"
# Inputs
base_image: ImageField = Field(default=None, description="The base image")
image: ImageField = Field(default=None, description="The image to paste")
base_image: Union[ImageField, None] = Field(default=None, description="The base image")
image: Union[ImageField, None] = Field(default=None, description="The image to paste")
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get(
self.base_image.image_type, self.base_image.image_name
)
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
base_image = context.services.images.get_pil_image(self.base_image.image_name)
image = context.services.images.get_pil_image(self.image.image_name)
mask = (
None
if self.mask is None
else ImageOps.invert(
context.services.images.get(self.mask.image_type, self.mask.image_name)
context.services.images.get_pil_image(self.mask.image_name)
)
)
# TODO: probably shouldn't invert mask here... should user be required to do it?
@ -199,20 +180,19 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
new_image.paste(base_image, (abs(min_x), abs(min_y)))
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, new_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image_dto = context.services.images.create(
image=new_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -223,48 +203,150 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["tomask"] = "tomask"
# Inputs
image: ImageField = Field(default=None, description="The image to create the mask from")
image: Union[ImageField, None] = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask")
# fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_mask = image.split()[-1]
if self.invert:
image_mask = ImageOps.invert(image_mask)
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=image_mask,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.MASK,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
return MaskOutput(
mask=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
context.services.images.save(image_type, image_name, image_mask, metadata)
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
# fmt: off
type: Literal["img_mul"] = "img_mul"
# Inputs
image1: Union[ImageField, None] = Field(default=None, description="The first image to multiply")
image2: Union[ImageField, None] = Field(default=None, description="The second image to multiply")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image(self.image1.image_name)
image2 = context.services.images.get_pil_image(self.image2.image_name)
multiply_image = ImageChops.multiply(image1, image2)
image_dto = context.services.images.create(
image=multiply_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
class BlurInvocation(BaseInvocation, PILInvocationConfig):
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
"""Gets a channel from an image."""
# fmt: off
type: Literal["img_chan"] = "img_chan"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to get the channel from")
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
channel_image = image.getchannel(self.channel)
image_dto = context.services.images.create(
image=channel_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
"""Converts an image to a different mode."""
# fmt: off
type: Literal["img_conv"] = "img_conv"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to convert")
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
converted_image = image.convert(self.mode)
image_dto = context.services.images.create(
image=converted_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
"""Blurs an image"""
# fmt: off
type: Literal["blur"] = "blur"
type: Literal["img_blur"] = "img_blur"
# Inputs
image: ImageField = Field(default=None, description="The image to blur")
image: Union[ImageField, None] = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
blur = (
ImageFilter.GaussianBlur(self.radius)
@ -273,74 +355,171 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
)
blur_image = image.filter(blur)
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=blur_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, blur_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=blur_image
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
class LerpInvocation(BaseInvocation, PILInvocationConfig):
PIL_RESAMPLING_MODES = Literal[
"nearest",
"box",
"bilinear",
"hamming",
"bicubic",
"lanczos",
]
PIL_RESAMPLING_MAP = {
"nearest": Image.Resampling.NEAREST,
"box": Image.Resampling.BOX,
"bilinear": Image.Resampling.BILINEAR,
"hamming": Image.Resampling.HAMMING,
"bicubic": Image.Resampling.BICUBIC,
"lanczos": Image.Resampling.LANCZOS,
}
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
"""Resizes an image to specific dimensions"""
# fmt: off
type: Literal["img_resize"] = "img_resize"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
resize_image = image.resize(
(self.width, self.height),
resample=resample_mode,
)
image_dto = context.services.images.create(
image=resize_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
"""Scales an image by a factor"""
# fmt: off
type: Literal["img_scale"] = "img_scale"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width * self.scale_factor)
height = int(image.height * self.scale_factor)
resize_image = image.resize(
(width, height),
resample=resample_mode,
)
image_dto = context.services.images.create(
image=resize_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Linear interpolation of all pixels of an image"""
# fmt: off
type: Literal["lerp"] = "lerp"
type: Literal["img_lerp"] = "img_lerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
image_arr = image_arr * (self.max - self.min) + self.max
lerp_image = Image.fromarray(numpy.uint8(image_arr))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=lerp_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, lerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=lerp_image
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Inverse linear interpolation of all pixels of an image"""
# fmt: off
type: Literal["ilerp"] = "ilerp"
type: Literal["img_ilerp"] = "img_ilerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32)
image_arr = (
@ -352,16 +531,17 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=ilerp_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, ilerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=ilerp_image
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -1,17 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import Literal, Optional, Union, get_args
from typing import Literal, Union, get_args
import numpy as np
import math
from PIL import Image, ImageOps
from pydantic import Field
from invokeai.app.invocations.image import ImageOutput, build_image_output
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageField, ImageType
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
InvocationContext,
@ -125,36 +125,35 @@ class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
color: Optional[ColorField] = Field(
image: Union[ImageField, None] = Field(
default=None, description="The image to infill"
)
color: ColorField = Field(
default=ColorField(r=127, g=127, b=127, a=255),
description="The color to use to infill",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image)
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
infilled.paste(image, (0, 0), image.split()[-1])
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=infilled,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -163,7 +162,9 @@ class InfillTileInvocation(BaseInvocation):
type: Literal["infill_tile"] = "infill_tile"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
image: Union[ImageField, None] = Field(
default=None, description="The image to infill"
)
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
seed: int = Field(
ge=0,
@ -173,29 +174,26 @@ class InfillTileInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
infilled = tile_fill_missing(
image.copy(), seed=self.seed, tile_size=self.tile_size
)
infilled.paste(image, (0, 0), image.split()[-1])
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=infilled,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -204,30 +202,29 @@ class InfillPatchMatchInvocation(BaseInvocation):
type: Literal["infill_patchmatch"] = "infill_patchmatch"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
image: Union[ImageField, None] = Field(
default=None, description="The image to infill"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
if PatchMatch.patchmatch_available():
infilled = infill_patchmatch(image.copy())
else:
raise ValueError("PatchMatch is not available on this system")
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=infilled,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -1,34 +1,36 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import random
from typing import Literal, Optional, Union
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import einops
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput, build_image_output
from .compel import ConditioningField
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
from diffusers import DiffusionPipeline
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
image_resized_to_grid_as_tensor)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management.lora import ModelPatcher
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
@ -81,10 +83,17 @@ SAMPLER_NAME_VALUES = Literal[
]
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
def get_scheduler(
context: InvocationContext,
scheduler_info: ModelInfo,
scheduler_name: str,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_config = model.scheduler.config
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
@ -119,7 +128,6 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
return x
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
@ -139,12 +147,17 @@ class NoiseInvocation(BaseInvocation):
},
}
@validator("seed", pre=True)
def modulo_seed(cls, v):
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
return v % SEED_MAX
def invoke(self, context: InvocationContext) -> NoiseOutput:
device = torch.device(choose_torch_device())
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, noise)
context.services.latents.save(name, noise)
return build_noise_output(latents_name=name, latents=noise)
@ -160,20 +173,36 @@ class TextToLatentsInvocation(BaseInvocation):
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
unet: UNetField = Field(default=None, description="UNet submodel")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError('cfg_scale must be greater than 1')
else:
if v < 1:
raise ValueError('cfg_scale must be greater than 1')
return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
"tags": ["latents"],
"type_hints": {
"model": "model"
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number"
}
},
}
@ -189,49 +218,138 @@ class TextToLatentsInvocation(BaseInvocation):
source_node_id=source_node_id,
)
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
model_info = choose_model(model_manager, self.model)
model_name = model_info['model_name']
model_hash = model_info['hash']
model: StableDiffusionGeneratorPipeline = model_info['model']
model.scheduler = get_scheduler(
model=model,
scheduler_name=self.scheduler
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
self.seamless,
self.seamless_axes
)
else:
configure_model_padding(model,
self.seamless,
self.seamless_axes
)
return model
def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
conditioning_data = ConditioningData(
uc,
c,
self.cfg_scale,
extra_conditioning_info,
unconditioned_embeddings=uc,
text_embeddings=c,
guidance_scale=self.cfg_scale,
extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0,#threshold,
warmup=0.2,#warmup,
h_symmetry_time_pct=None,#h_symmetry_time_pct,
v_symmetry_time_pct=None#v_symmetry_time_pct,
),
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
)
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
scheduler,
# for ddim scheduler
eta=0.0, #ddim_eta
# for ancestral and sde schedulers
generator=torch.Generator(device=uc.device).manual_seed(0),
)
return conditioning_data
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
# TODO:
#configure_model_padding(
# unet,
# self.seamless,
# self.seamless_axes,
#)
class FakeVae:
class FakeVaeConfig:
def __init__(self):
self.block_out_channels = [0]
def __init__(self):
self.config = FakeVae.FakeVaeConfig()
return StableDiffusionGeneratorPipeline(
vae=FakeVae(), # TODO: oh...
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
precision="float16" if unet.dtype == torch.float16 else "float32",
)
def prep_control_data(
self,
context: InvocationContext,
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
control_input: List[ControlField],
latents_shape: List[int],
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
# assuming fixed dimensional scaling of 8:1 for image:latents
control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8
if control_input is None:
# print("control input is None")
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
# print("control input is empty list")
control_list = None
elif isinstance(control_input, ControlField):
# print("control input is ControlField")
control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
# print("control input is list[ControlField]")
control_list = control_input
else:
# print("input control is unrecognized:", type(self.control))
control_list = None
if (control_list is None):
control_data = None
# from above handling, any control that is not None should now be of type list[ControlField]
else:
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
control_data = []
control_models = []
for control_info in control_list:
# handle control models
if ("," in control_info.control_model):
control_model_split = control_info.control_model.split(",")
control_name = control_model_split[0]
control_subfolder = control_model_split[1]
print("Using HF model subfolders")
print(" control_name: ", control_name)
print(" control_subfolder: ", control_subfolder)
control_model = ControlNetModel.from_pretrained(control_name,
subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(model.device)
else:
control_model = ControlNetModel.from_pretrained(control_info.control_model,
torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = model.prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
)
control_item = ControlNetData(model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
@ -243,27 +361,46 @@ class TextToLatentsInvocation(BaseInvocation):
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(context, model)
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
with unet_info as unet,\
ExitStack() as stack:
# TODO: Verify the noise is the right size
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
)
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image."""
@ -271,7 +408,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
# Schema customisation
class Config(InvocationConfig):
@ -279,7 +416,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model"
"model": "model",
"control": "control",
"cfg_scale": "number",
}
},
}
@ -295,31 +434,58 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(context, model)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=model.device, dtype=latent.dtype
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
)
timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)
with unet_info as unet,\
ExitStack() as stack:
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=unet.device, dtype=latent.dtype
)
timesteps, _ = pipeline.get_img2img_timesteps(
self.steps,
self.strength,
device=unet.device,
)
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
@ -331,16 +497,14 @@ class LatentsToImageInvocation(BaseInvocation):
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
model: str = Field(default="", description="The model to use")
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
"type_hints": {
"model": "model"
}
},
}
@ -348,30 +512,44 @@ class LatentsToImageInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO: this only really needs the vae
model_info = choose_model(context.services.model_manager, self.model)
model: StableDiffusionGeneratorPipeline = model_info['model']
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
)
with torch.inference_mode():
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
with vae_info as vae:
if self.tiled or context.services.configuration.tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()
# clear memory as vae decode can request a lot
torch.cuda.empty_cache()
context.services.images.save(image_type, image_name, image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=image
)
with torch.inference_mode():
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
torch.cuda.empty_cache()
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
LATENTS_INTERPOLATION_MODE = Literal[
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
@ -404,7 +582,8 @@ class ResizeLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)
@ -434,7 +613,8 @@ class ScaleLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)
@ -445,38 +625,50 @@ class ImageToLatentsInvocation(BaseInvocation):
# Inputs
image: Union[ImageField, None] = Field(description="The image to encode")
model: str = Field(default="", description="The model to use")
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
"type_hints": {"model": "model"},
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image(self.image.image_name)
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
)
# TODO: this only really needs the vae
model_info = choose_model(context.services.model_manager, self.model)
model: StableDiffusionGeneratorPipeline = model_info["model"]
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = model.non_noised_latents_from_image(
image_tensor,
device=model._model_group.device_for(model.unet),
dtype=model.unet.dtype,
)
with vae_info as vae:
if self.tiled:
vae.enable_tiling()
else:
vae.disable_tiling()
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
latents = 0.18215 * latents
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, latents)
# context.services.latents.set(name, latents)
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)

View File

@ -34,6 +34,15 @@ class IntOutput(BaseInvocationOutput):
# fmt: on
class FloatOutput(BaseInvocationOutput):
"""A float output"""
# fmt: off
type: Literal["float_output"] = "float_output"
param: float = Field(default=None, description="The output float")
# fmt: on
class AddInvocation(BaseInvocation, MathInvocationConfig):
"""Adds two numbers"""

View File

@ -0,0 +1,217 @@
from typing import Literal, Optional, Union, List
from pydantic import BaseModel, Field
import copy
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management import BaseModelType, ModelType, SubModelType
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(description="Info to load submodel")
class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model")
class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
#fmt: off
type: Literal["model_loader_output"] = "model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
#fmt: on
class PipelineModelField(BaseModel):
"""Pipeline model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
class PipelineModelLoaderInvocation(BaseInvocation):
"""Loads a pipeline model, outputting its submodels."""
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
model: PipelineModelField = Field(description="The model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"model": "model"
}
},
}
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Pipeline
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
)
"""
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Vae,
),
)
)
class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
#fmt: off
type: Literal["lora_loader_output"] = "lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
#fmt: on
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["lora_loader"] = "lora_loader"
lora_name: str = Field(description="Lora model name")
weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if not context.services.model_manager.model_exists(
model_name=self.lora_name,
model_type=SDModelType.Lora,
):
raise Exception(f"Unkown lora name: {self.lora_name}!")
if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras):
raise Exception(f"Lora \"{self.lora_name}\" already applied to unet")
if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras):
raise Exception(f"Lora \"{self.lora_name}\" already applied to clip")
output = LoraLoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
model_name=self.lora_name,
model_type=SDModelType.Lora,
submodel=None,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
model_name=self.lora_name,
model_type=SDModelType.Lora,
submodel=None,
weight=self.weight,
)
)
return output

View File

@ -0,0 +1,237 @@
import io
from typing import Literal, Optional, Any
# from PIL.Image import Image
import PIL.Image
from matplotlib.ticker import MaxNLocator
from matplotlib.figure import Figure
from pydantic import BaseModel, Field
import numpy as np
import matplotlib.pyplot as plt
from easing_functions import (
LinearInOut,
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
SineEaseInOut, SineEaseIn, SineEaseOut,
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
BackEaseIn, BackEaseInOut, BackEaseOut,
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from ...backend.util.logging import InvokeAILogger
from .collections import FloatCollectionOutput
class FloatLinearRangeInvocation(BaseInvocation):
"""Creates a range"""
type: Literal["float_range"] = "float_range"
# Inputs
start: float = Field(default=5, description="The first value of the range")
stop: float = Field(default=10, description="The last value of the range")
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
param_list = list(np.linspace(self.start, self.stop, self.steps))
return FloatCollectionOutput(
collection=param_list
)
EASING_FUNCTIONS_MAP = {
"Linear": LinearInOut,
"QuadIn": QuadEaseIn,
"QuadOut": QuadEaseOut,
"QuadInOut": QuadEaseInOut,
"CubicIn": CubicEaseIn,
"CubicOut": CubicEaseOut,
"CubicInOut": CubicEaseInOut,
"QuarticIn": QuarticEaseIn,
"QuarticOut": QuarticEaseOut,
"QuarticInOut": QuarticEaseInOut,
"QuinticIn": QuinticEaseIn,
"QuinticOut": QuinticEaseOut,
"QuinticInOut": QuinticEaseInOut,
"SineIn": SineEaseIn,
"SineOut": SineEaseOut,
"SineInOut": SineEaseInOut,
"CircularIn": CircularEaseIn,
"CircularOut": CircularEaseOut,
"CircularInOut": CircularEaseInOut,
"ExponentialIn": ExponentialEaseIn,
"ExponentialOut": ExponentialEaseOut,
"ExponentialInOut": ExponentialEaseInOut,
"ElasticIn": ElasticEaseIn,
"ElasticOut": ElasticEaseOut,
"ElasticInOut": ElasticEaseInOut,
"BackIn": BackEaseIn,
"BackOut": BackEaseOut,
"BackInOut": BackEaseInOut,
"BounceIn": BounceEaseIn,
"BounceOut": BounceEaseOut,
"BounceInOut": BounceEaseInOut,
}
EASING_FUNCTION_KEYS: Any = Literal[
tuple(list(EASING_FUNCTIONS_MAP.keys()))
]
# actually I think for now could just use CollectionOutput (which is list[Any]
class StepParamEasingInvocation(BaseInvocation):
"""Experimental per-step parameter easing for denoising steps"""
type: Literal["step_param_easing"] = "step_param_easing"
# Inputs
# fmt: off
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
num_steps: int = Field(default=20, description="number of denoising steps")
start_value: float = Field(default=0.0, description="easing starting value")
end_value: float = Field(default=1.0, description="easing ending value")
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
# if None, then start_value is used prior to easing start
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
# if None, then end value is used prior to easing end
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
mirror: bool = Field(default=False, description="include mirror of easing function")
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
show_easing_plot: bool = Field(default=False, description="show easing plot")
# fmt: on
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
log_diagnostics = False
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
# start_step = int(np.floor(self.num_steps * self.start_step_percent))
start_step = int(np.round(self.num_steps * self.start_step_percent))
# convert from end_step_percent to nearest step >= (steps * end_step_percent)
# end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
# end_step = int(np.ceil(self.num_steps * self.end_step_percent))
num_easing_steps = end_step - start_step + 1
# num_presteps = max(start_step - 1, 0)
num_presteps = start_step
num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
prelist = list(num_presteps * [self.pre_start_value])
postlist = list(num_poststeps * [self.post_end_value])
if log_diagnostics:
logger = InvokeAILogger.getLogger(name="StepParamEasing")
logger.debug("start_step: " + str(start_step))
logger.debug("end_step: " + str(end_step))
logger.debug("num_easing_steps: " + str(num_easing_steps))
logger.debug("num_presteps: " + str(num_presteps))
logger.debug("num_poststeps: " + str(num_poststeps))
logger.debug("prelist size: " + str(len(prelist)))
logger.debug("postlist size: " + str(len(postlist)))
logger.debug("prelist: " + str(prelist))
logger.debug("postlist: " + str(postlist))
easing_class = EASING_FUNCTIONS_MAP[self.easing]
if log_diagnostics:
logger.debug("easing class: " + str(easing_class))
easing_list = list()
if self.mirror: # "expected" mirroring
# if number of steps is even, squeeze duration down to (number_of_steps)/2
# and create reverse copy of list to append
# if number of steps is odd, squeeze duration down to ceil(number_of_steps/2)
# and create reverse copy of list[1:end-1]
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
easing_function = easing_class(start=self.start_value,
end=self.end_value,
duration=base_easing_duration - 1)
base_easing_vals = list()
for step_index in range(base_easing_duration):
easing_val = easing_function.ease(step_index)
base_easing_vals.append(easing_val)
if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
if even_num_steps:
mirror_easing_vals = list(reversed(base_easing_vals))
else:
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
if log_diagnostics:
logger.debug("base easing vals: " + str(base_easing_vals))
logger.debug("mirror easing vals: " + str(mirror_easing_vals))
easing_list = base_easing_vals + mirror_easing_vals
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
# elif self.alt_mirror: # function mirroring (unintuitive behavior (at least to me))
# # half_ease_duration = round(num_easing_steps - 1 / 2)
# half_ease_duration = round((num_easing_steps - 1) / 2)
# easing_function = easing_class(start=self.start_value,
# end=self.end_value,
# duration=half_ease_duration,
# )
#
# mirror_function = easing_class(start=self.end_value,
# end=self.start_value,
# duration=half_ease_duration,
# )
# for step_index in range(num_easing_steps):
# if step_index <= half_ease_duration:
# step_val = easing_function.ease(step_index)
# else:
# step_val = mirror_function.ease(step_index - half_ease_duration)
# easing_list.append(step_val)
# if log_diagnostics: logger.debug(step_index, step_val)
#
else: # no mirroring (default)
easing_function = easing_class(start=self.start_value,
end=self.end_value,
duration=num_easing_steps - 1)
for step_index in range(num_easing_steps):
step_val = easing_function.ease(step_index)
easing_list.append(step_val)
if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
if log_diagnostics:
logger.debug("prelist size: " + str(len(prelist)))
logger.debug("easing_list size: " + str(len(easing_list)))
logger.debug("postlist size: " + str(len(postlist)))
param_list = prelist + easing_list + postlist
if self.show_easing_plot:
plt.figure()
plt.xlabel("Step")
plt.ylabel("Param Value")
plt.title("Per-Step Values Based On Easing: " + self.easing)
plt.bar(range(len(param_list)), param_list)
# plt.plot(param_list)
ax = plt.gca()
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
buf = io.BytesIO()
plt.savefig(buf, format='png')
buf.seek(0)
im = PIL.Image.open(buf)
im.show()
buf.close()
# output array of size steps, each entry list[i] is param value for step i
return FloatCollectionOutput(
collection=param_list
)

View File

@ -3,7 +3,7 @@
from typing import Literal
from pydantic import Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from .math import IntOutput
from .math import IntOutput, FloatOutput
# Pass-through parameter nodes - used by subgraphs
@ -16,3 +16,13 @@ class ParamIntInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a)
class ParamFloatInvocation(BaseInvocation):
"""A float parameter"""
#fmt: off
type: Literal["param_float"] = "param_float"
param: float = Field(default=0.0, description="The float value")
#fmt: on
def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(param=self.param)

View File

@ -2,8 +2,8 @@ from typing import Literal
from pydantic.fields import Field
from .baseinvocation import BaseInvocationOutput
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt"""
@ -20,3 +20,38 @@ class PromptOutput(BaseInvocationOutput):
'prompt',
]
}
class PromptCollectionOutput(BaseInvocationOutput):
"""Base class for invocations that output a collection of prompts"""
# fmt: off
type: Literal["prompt_collection_output"] = "prompt_collection_output"
prompt_collection: list[str] = Field(description="The output prompt collection")
count: int = Field(description="The size of the prompt collection")
# fmt: on
class Config:
schema_extra = {"required": ["type", "prompt_collection", "count"]}
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
type: Literal["dynamic_prompt"] = "dynamic_prompt"
prompt: str = Field(description="The prompt to parse with dynamicprompts")
max_prompts: int = Field(default=1, description="The number of prompts to generate")
combinatorial: bool = Field(
default=False, description="Whether to use the combinatorial generator"
)
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
if self.combinatorial:
generator = CombinatorialPromptGenerator()
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
else:
generator = RandomPromptGenerator()
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))

View File

@ -2,21 +2,23 @@ from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .image import ImageOutput
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
#fmt: off
# fmt: off
type: Literal["restore_face"] = "restore_face"
# Inputs
image: Union[ImageField, None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
#fmt: on
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
@ -26,9 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=None,
@ -39,18 +39,17 @@ class RestoreFaceInvocation(BaseInvocation):
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@ -4,22 +4,22 @@ from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .image import ImageOutput
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
#fmt: off
# fmt: off
type: Literal["upscale"] = "upscale"
# Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2, 4] = Field(default=2, description="The upscale level")
#fmt: on
# fmt: on
# Schema customisation
class Config(InvocationConfig):
@ -30,9 +30,7 @@ class UpscaleInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
@ -43,18 +41,17 @@ class UpscaleInvocation(BaseInvocation):
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@ -1,14 +0,0 @@
from invokeai.backend.model_management.model_manager import ModelManager
def choose_model(model_manager: ModelManager, model_name: str):
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
logger = model_manager.logger
if model_name and not model_manager.valid_model(model_name):
default_model_name = model_manager.default_model()
logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
model = model_manager.get_model()
else:
model = model_manager.get_model(model_name)
return model

View File

@ -2,31 +2,74 @@ from enum import Enum
from typing import Optional, Tuple
from pydantic import BaseModel, Field
class ImageType(str, Enum):
RESULT = "results"
INTERMEDIATE = "intermediates"
UPLOAD = "uploads"
from invokeai.app.util.metaenum import MetaEnum
def is_image_type(obj):
try:
ImageType(obj)
except ValueError:
return False
return True
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
"""The origin of a resource (eg image).
- INTERNAL: The resource was created by the application.
- EXTERNAL: The resource was not created by the application.
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
"""
INTERNAL = "internal"
"""The resource was created by the application."""
EXTERNAL = "external"
"""The resource was not created by the application.
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
"""
class InvalidOriginException(ValueError):
"""Raised when a provided value is not a valid ResourceOrigin.
Subclasses `ValueError`.
"""
def __init__(self, message="Invalid resource origin."):
super().__init__(message)
class ImageCategory(str, Enum, metaclass=MetaEnum):
"""The category of an image.
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
- MASK: The image is a mask image.
- CONTROL: The image is a ControlNet control image.
- USER: The image is a user-provide image.
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
"""
GENERAL = "general"
"""GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
MASK = "mask"
"""MASK: The image is a mask image."""
CONTROL = "control"
"""CONTROL: The image is a ControlNet control image."""
USER = "user"
"""USER: The image is a user-provide image."""
OTHER = "other"
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
class InvalidImageCategoryException(ValueError):
"""Raised when a provided value is not a valid ImageCategory.
Subclasses `ValueError`.
"""
def __init__(self, message="Invalid image category."):
super().__init__(message)
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_type: ImageType = Field(
default=ImageType.RESULT, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_type", "image_name"]}
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):
@ -37,3 +80,11 @@ class ColorField(BaseModel):
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

View File

@ -0,0 +1,93 @@
from typing import Optional, Union, List
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
class ImageMetadata(BaseModel):
"""
Core generation metadata for an image/tensor generated in InvokeAI.
Also includes any metadata from the image's PNG tEXt chunks.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
of a given node.
Full metadata may be accessed by querying for the session in the `graph_executions` table.
"""
class Config:
extra = Extra.allow
"""
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
won't add any fields that are not already defined, but other a different metadata service
implementation might.
"""
type: Optional[StrictStr] = Field(
default=None,
description="The type of the ancestor node of the image output node.",
)
"""The type of the ancestor node of the image output node."""
positive_conditioning: Optional[StrictStr] = Field(
default=None, description="The positive conditioning."
)
"""The positive conditioning"""
negative_conditioning: Optional[StrictStr] = Field(
default=None, description="The negative conditioning."
)
"""The negative conditioning"""
width: Optional[StrictInt] = Field(
default=None, description="Width of the image/latents in pixels."
)
"""Width of the image/latents in pixels"""
height: Optional[StrictInt] = Field(
default=None, description="Height of the image/latents in pixels."
)
"""Height of the image/latents in pixels"""
seed: Optional[StrictInt] = Field(
default=None, description="The seed used for noise generation."
)
"""The seed used for noise generation"""
# cfg_scale: Optional[StrictFloat] = Field(
# cfg_scale: Union[float, list[float]] = Field(
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
default=None, description="The classifier-free guidance scale."
)
"""The classifier-free guidance scale"""
steps: Optional[StrictInt] = Field(
default=None, description="The number of steps used for inference."
)
"""The number of steps used for inference"""
scheduler: Optional[StrictStr] = Field(
default=None, description="The scheduler used for inference."
)
"""The scheduler used for inference"""
model: Optional[StrictStr] = Field(
default=None, description="The model used for inference."
)
"""The model used for inference"""
strength: Optional[StrictFloat] = Field(
default=None,
description="The strength used for image-to-image/latents-to-latents.",
)
"""The strength used for image-to-image/latents-to-latents."""
latents: Optional[StrictStr] = Field(
default=None, description="The ID of the initial latents."
)
"""The ID of the initial latents"""
vae: Optional[StrictStr] = Field(
default=None, description="The VAE used for decoding."
)
"""The VAE used for decoding"""
unet: Optional[StrictStr] = Field(
default=None, description="The UNet used dor inference."
)
"""The UNet used dor inference"""
clip: Optional[StrictStr] = Field(
default=None, description="The CLIP Encoder used for conditioning."
)
"""The CLIP Encoder used for conditioning"""
extra: Optional[StrictStr] = Field(
default=None,
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
)
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""

View File

@ -0,0 +1,254 @@
from abc import ABC, abstractmethod
import sqlite3
import threading
from typing import Union, cast
from invokeai.app.services.board_record_storage import BoardRecord
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import (
ImageRecord,
deserialize_image_record,
)
class BoardImageRecordStorageBase(ABC):
"""Abstract base class for the one-to-many board-image relationship record storage."""
@abstractmethod
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
pass
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
pass
@abstractmethod
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets images for a board."""
pass
@abstractmethod
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
"""Gets an image's board id, if it has one."""
pass
@abstractmethod
def get_image_count_for_board(
self,
board_id: str,
) -> int:
"""Gets the number of images for a board."""
pass
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `board_images` junction table."""
# Create the `board_images` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS board_images (
board_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between boards and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
)
# Add index for board id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
"""
)
# Add index for board id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
AFTER UPDATE
ON board_images FOR EACH ROW
BEGIN
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id AND image_name = old.image_name;
END;
"""
)
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
VALUES (?, ?)
ON CONFLICT (image_name) DO UPDATE SET board_id = ?;
""",
(board_id, image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM board_images
WHERE board_id = ? AND image_name = ?;
""",
(board_id, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_images_for_board(
self,
board_id: str,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
# TODO: this isn't paginated yet?
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(
items=images, offset=offset, limit=limit, total=count
)
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = self._cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_image_count_for_board(self, board_id: str) -> int:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM board_images WHERE board_id = ?;
""",
(board_id,),
)
count = cast(int, self._cursor.fetchone()[0])
return count
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

View File

@ -0,0 +1,142 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import List, Union
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import (
BoardRecord,
BoardRecordStorageBase,
)
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
from invokeai.app.services.urls import UrlServiceBase
class BoardImagesServiceABC(ABC):
"""High-level service for board-image relationship management."""
@abstractmethod
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
pass
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
pass
@abstractmethod
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets images for a board."""
pass
@abstractmethod
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
"""Gets an image's board id, if it has one."""
pass
class BoardImagesServiceDependencies:
"""Service dependencies for the BoardImagesService."""
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
class BoardImagesService(BoardImagesServiceABC):
_services: BoardImagesServiceDependencies
def __init__(self, services: BoardImagesServiceDependencies):
self._services = services
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.add_image_to_board(board_id, image_name)
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.remove_image_from_board(board_id, image_name)
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageDTO]:
image_records = self._services.board_image_records.get_images_for_board(
board_id
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
board_id,
),
image_records.items,
)
)
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,
offset=image_records.offset,
limit=image_records.limit,
total=image_records.total,
)
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
board_id = self._services.board_image_records.get_board_for_image(image_name)
return board_id
def board_record_to_dto(
board_record: BoardRecord, cover_image_name: str | None, image_count: int
) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
**board_record.dict(exclude={'cover_image_name'}),
cover_image_name=cover_image_name,
image_count=image_count,
)

View File

@ -0,0 +1,329 @@
from abc import ABC, abstractmethod
from typing import Optional, cast
import sqlite3
import threading
from typing import Optional, Union
import uuid
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import (
BoardRecord,
deserialize_board_record,
)
from pydantic import BaseModel, Field, Extra
class BoardChanges(BaseModel, extra=Extra.forbid):
board_name: Optional[str] = Field(description="The board's new name.")
cover_image_name: Optional[str] = Field(
description="The name of the board's new cover image."
)
class BoardRecordNotFoundException(Exception):
"""Raised when an board record is not found."""
def __init__(self, message="Board record not found"):
super().__init__(message)
class BoardRecordSaveException(Exception):
"""Raised when an board record cannot be saved."""
def __init__(self, message="Board record not saved"):
super().__init__(message)
class BoardRecordDeleteException(Exception):
"""Raised when an board record cannot be deleted."""
def __init__(self, message="Board record not deleted"):
super().__init__(message)
class BoardRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the board record store."""
@abstractmethod
def delete(self, board_id: str) -> None:
"""Deletes a board record."""
pass
@abstractmethod
def save(
self,
board_name: str,
) -> BoardRecord:
"""Saves a board record."""
pass
@abstractmethod
def get(
self,
board_id: str,
) -> BoardRecord:
"""Gets a board record."""
pass
@abstractmethod
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
"""Updates a board record."""
pass
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records."""
pass
@abstractmethod
def get_all(
self,
) -> list[BoardRecord]:
"""Gets all board records."""
pass
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `boards` table and `board_images` junction table."""
# Create the `boards` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS boards (
board_id TEXT NOT NULL PRIMARY KEY,
board_name TEXT NOT NULL,
cover_image_name TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
UPDATE boards SET updated_at = current_timestamp
WHERE board_id = old.board_id;
END;
"""
)
def delete(self, board_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
except Exception as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
board_name: str,
) -> BoardRecord:
try:
board_id = str(uuid.uuid4())
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return self.get(board_id)
def get(
self,
board_id: str,
) -> BoardRecord:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
try:
self._lock.acquire()
# Change the name of a board
if changes.board_name is not None:
self._cursor.execute(
f"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
self._cursor.execute(
f"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return self.get(board_id)
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()
# Get all the boards
self._cursor.execute(
"""--sql
SELECT *
FROM boards
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
""",
(limit, offset),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
# Get the total number of boards
self._cursor.execute(
"""--sql
SELECT COUNT(*)
FROM boards
WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](
items=boards, offset=offset, limit=limit, total=count
)
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_all(
self,
) -> list[BoardRecord]:
try:
self._lock.acquire()
# Get all the boards
self._cursor.execute(
"""--sql
SELECT *
FROM boards
ORDER BY created_at DESC
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
return boards
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

View File

@ -0,0 +1,185 @@
from abc import ABC, abstractmethod
from logging import Logger
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_images import board_record_to_dto
from invokeai.app.services.board_record_storage import (
BoardChanges,
BoardRecordStorageBase,
)
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.urls import UrlServiceBase
class BoardServiceABC(ABC):
"""High-level service for board management."""
@abstractmethod
def create(
self,
board_name: str,
) -> BoardDTO:
"""Creates a board."""
pass
@abstractmethod
def get_dto(
self,
board_id: str,
) -> BoardDTO:
"""Gets a board."""
pass
@abstractmethod
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardDTO:
"""Updates a board."""
pass
@abstractmethod
def delete(
self,
board_id: str,
) -> None:
"""Deletes a board."""
pass
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many boards."""
pass
@abstractmethod
def get_all(
self,
) -> list[BoardDTO]:
"""Gets all boards."""
pass
class BoardServiceDependencies:
"""Service dependencies for the BoardService."""
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
class BoardService(BoardServiceABC):
_services: BoardServiceDependencies
def __init__(self, services: BoardServiceDependencies):
self._services = services
def create(
self,
board_name: str,
) -> BoardDTO:
board_record = self._services.board_records.save(board_name)
return board_record_to_dto(board_record, None, 0)
def get_dto(self, board_id: str) -> BoardDTO:
board_record = self._services.board_records.get(board_id)
cover_image = self._services.image_records.get_most_recent_image_for_board(
board_record.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
board_id
)
return board_record_to_dto(board_record, cover_image_name, image_count)
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardDTO:
board_record = self._services.board_records.update(board_id, changes)
cover_image = self._services.image_records.get_most_recent_image_for_board(
board_record.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
board_id
)
return board_record_to_dto(board_record, cover_image_name, image_count)
def delete(self, board_id: str) -> None:
self._services.board_records.delete(board_id)
def get_many(
self, offset: int = 0, limit: int = 10
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self._services.board_records.get_many(offset, limit)
board_dtos = []
for r in board_records.items:
cover_image = self._services.image_records.get_most_recent_image_for_board(
r.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return OffsetPaginatedResults[BoardDTO](
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
)
def get_all(self) -> list[BoardDTO]:
board_records = self._services.board_records.get_all()
board_dtos = []
for r in board_records:
cover_image = self._services.image_records.get_most_recent_image_for_board(
r.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return board_dtos

View File

@ -15,10 +15,7 @@ InvokeAI:
conf_path: configs/models.yaml
legacy_conf_dir: configs/stable-diffusion
outdir: outputs
embedding_dir: embeddings
lora_dir: loras
autoconvert_dir: null
gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth
Models:
model: stable-diffusion-1.5
embeddings: true
@ -51,18 +48,32 @@ in INVOKEAI_ROOT. You can replace supersede this by providing any
OmegaConf dictionary object initialization time:
omegaconf = OmegaConf.load('/tmp/init.yaml')
conf = InvokeAIAppConfig(conf=omegaconf)
conf = InvokeAIAppConfig()
conf.parse_args(conf=omegaconf)
By default, InvokeAIAppConfig will parse the contents of `sys.argv` at
initialization time. You may pass a list of strings in the optional
InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
at initialization time. You may pass a list of strings in the optional
`argv` argument to use instead of the system argv:
conf = InvokeAIAppConfig(arg=['--xformers_enabled'])
conf.parse_args(argv=['--xformers_enabled'])
It is also possible to set a value at initialization time. This value
has highest priority.
It is also possible to set a value at initialization time. However, if
you call parse_args() it may be overwritten.
conf = InvokeAIAppConfig(xformers_enabled=True)
conf.parse_args(argv=['--no-xformers'])
conf.xformers_enabled
# False
To avoid this, use `get_config()` to retrieve the application-wide
configuration object. This will retain any properties set at object
creation time:
conf = InvokeAIAppConfig.get_config(xformers_enabled=True)
conf.parse_args(argv=['--no-xformers'])
conf.xformers_enabled
# True
Any setting can be overwritten by setting an environment variable of
form: "INVOKEAI_<setting>", as in:
@ -76,18 +87,23 @@ Order of precedence (from highest):
4) config file options
5) pydantic defaults
Typical usage:
Typical usage at the top level file:
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.invocations.generate import TextToImageInvocation
# get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig()
conf = InvokeAIAppConfig.get_config()
conf.parse_args()
print(conf.nsfw_checker)
Typical usage in a backend module:
from invokeai.app.services.config import InvokeAIAppConfig
# get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig.get_config()
print(conf.nsfw_checker)
# get the text2image invocation and print its step value
text2image = TextToImageInvocation()
print(text2image.steps)
Computed properties:
@ -103,10 +119,11 @@ a Path object:
lora_path - path to the LoRA directory
In most cases, you will want to create a single InvokeAIAppConfig
object for the entire application. The get_invokeai_config() function
object for the entire application. The InvokeAIAppConfig.get_config() function
does this:
config = get_invokeai_config()
config = InvokeAIAppConfig.get_config()
config.parse_args() # read values from the command line/config file
print(config.root)
# Subclassing
@ -140,24 +157,23 @@ two configs are kept in separate sections of the config file:
legacy_conf_dir: configs/stable-diffusion
outdir: outputs
...
'''
from __future__ import annotations
import argparse
import pydoc
import typing
import os
import sys
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import Any, ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path('invokeai.yaml')
DB_FILE = Path('invokeai.db')
LEGACY_INIT_FILE = Path('invokeai.init')
# This global stores a singleton InvokeAIAppConfig configuration object
global_config = None
class InvokeAISettings(BaseSettings):
'''
Runtime configuration settings in which default values are
@ -168,7 +184,7 @@ class InvokeAISettings(BaseSettings):
def parse_args(self, argv: list=sys.argv[1:]):
parser = self.get_parser()
opt, _ = parser.parse_known_args(argv)
opt = parser.parse_args(argv)
for name in self.__fields__:
if name not in self._excluded():
setattr(self, name, getattr(opt,name))
@ -330,6 +346,9 @@ the command-line client (recommended for experts only), or
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
setting environment variables INVOKEAI_<setting>.
'''
singleton_config: ClassVar[InvokeAIAppConfig] = None
singleton_init: ClassVar[Dict] = None
#fmt: off
type: Literal["InvokeAI"] = "InvokeAI"
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
@ -352,48 +371,64 @@ setting environment variables INVOKEAI_<setting>.
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI textual inversion aembeddings directory', category='Paths')
gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.', category='Paths')
models_dir : Path = Field(default='./models', description='Path to the models directory', category='Paths')
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="debug", description="Emit logging messages at this level or higher", category="Logging")
#fmt: on
def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs):
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
'''
Initialize InvokeAIAppconfig.
Update settings with contents of init file, environment, and
command-line settings.
:param conf: alternate Omegaconf dictionary object
:param argv: aternate sys.argv list
:param **kwargs: attributes to initialize with
:param clobber: ovewrite any initialization parameters passed during initialization
'''
super().__init__(**kwargs)
# Set the runtime root directory. We parse command-line switches here
# in order to pick up the --root_dir option.
self.parse_args(argv)
super().parse_args(argv)
if conf is None:
try:
conf = OmegaConf.load(self.root_dir / INIT_FILE)
except:
pass
InvokeAISettings.initconf = conf
# parse args again in order to pick up settings in configuration file
self.parse_args(argv)
super().parse_args(argv)
# restore initialization values
hints = get_type_hints(self)
for k in kwargs:
setattr(self,k,parse_obj_as(hints[k],kwargs[k]))
if self.singleton_init and not clobber:
hints = get_type_hints(self.__class__)
for k in self.singleton_init:
setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
@classmethod
def get_config(cls,**kwargs)->InvokeAIAppConfig:
'''
This returns a singleton InvokeAIAppConfig configuration object.
'''
if cls.singleton_config is None \
or type(cls.singleton_config)!=cls \
or (kwargs and cls.singleton_init != kwargs):
cls.singleton_config = cls(**kwargs)
cls.singleton_init = kwargs
return cls.singleton_config
@property
def root_path(self)->Path:
'''
@ -414,6 +449,13 @@ setting environment variables INVOKEAI_<setting>.
def _resolve(self,partial_path:Path)->Path:
return (self.root_path / partial_path).resolve()
@property
def init_file_path(self)->Path:
'''
Path to invokeai.yaml
'''
return self._resolve(INIT_FILE)
@property
def output_path(self)->Path:
'''
@ -421,6 +463,13 @@ setting environment variables INVOKEAI_<setting>.
'''
return self._resolve(self.outdir)
@property
def db_path(self)->Path:
'''
Path to the invokeai.db file.
'''
return self._resolve(self.db_dir) / DB_FILE
@property
def model_conf_path(self)->Path:
'''
@ -436,32 +485,11 @@ setting environment variables INVOKEAI_<setting>.
return self._resolve(self.legacy_conf_dir)
@property
def cache_dir(self)->Path:
'''
Path to the global cache directory for HuggingFace hub-managed models
'''
return self.models_dir / "hub"
@property
def models_dir(self)->Path:
def models_path(self)->Path:
'''
Path to the models directory
'''
return self._resolve("models")
@property
def embedding_path(self)->Path:
'''
Path to the textual inversion embeddings directory.
'''
return self._resolve(self.embedding_dir) if self.embedding_dir else None
@property
def lora_path(self)->Path:
'''
Path to the LoRA models directory.
'''
return self._resolve(self.lora_dir) if self.lora_dir else None
return self._resolve(self.models_dir)
@property
def autoconvert_path(self)->Path:
@ -470,13 +498,6 @@ setting environment variables INVOKEAI_<setting>.
'''
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
@property
def gfpgan_model_path(self)->Path:
'''
Path to the GFPGAN model.
'''
return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None
# the following methods support legacy calls leftover from the Globals era
@property
def full_precision(self)->bool:
@ -511,11 +532,8 @@ class PagingArgumentParser(argparse.ArgumentParser):
text = self.format_help()
pydoc.pager(text)
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAISettings:
def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
'''
This returns a singleton InvokeAIAppConfig configuration object.
Legacy function which returns InvokeAIAppConfig.get_config()
'''
global global_config
if global_config is None or type(global_config)!=cls:
global_config = cls(**kwargs)
return global_config
return InvokeAIAppConfig.get_config(**kwargs)

View File

@ -1,9 +1,10 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
from invokeai.app.models.exceptions import CanceledException
class EventServiceBase:
session_event: str = "session_event"
@ -101,3 +102,53 @@ class EventServiceBase:
graph_execution_state_id=graph_execution_state_id,
),
)
def emit_model_load_started (
self,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
) -> None:
"""Emitted when a model is requested"""
self.__emit_session_event(
event_name="model_load_started",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
),
)
def emit_model_load_completed(
self,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
model_info: ModelInfo,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_session_event(
event_name="model_load_completed",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info,
),
)

View File

@ -60,6 +60,33 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
node_input_field = node_inputs.get(field) or None
return node_input_field
from typing import Optional, Union, List, get_args
def is_union_subtype(t1, t2):
t1_args = get_args(t1)
t2_args = get_args(t2)
if not t1_args:
# t1 is a single type
return t1 in t2_args
else:
# t1 is a Union, check that all of its types are in t2_args
return all(arg in t2_args for arg in t1_args)
def is_list_or_contains_list(t):
t_args = get_args(t)
# If the type is a List
if get_origin(t) is list:
return True
# If the type is a Union
elif t_args:
# Check if any of the types in the Union is a List
for arg in t_args:
if get_origin(arg) is list:
return True
return False
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if not from_type:
@ -85,7 +112,8 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if to_type in get_args(from_type):
return True
if not issubclass(from_type, to_type):
# if not issubclass(from_type, to_type):
if not is_union_subtype(from_type, to_type):
return False
else:
return False
@ -363,7 +391,7 @@ class Graph(BaseModel):
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError:
raise InvalidEdgeError("One or both nodes don't exist")
raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
@ -374,41 +402,41 @@ class Graph(BaseModel):
g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
# Validate that the field types are compatible
if not are_connections_compatible(
from_node, edge.source.field, to_node, edge.destination.field
):
raise InvalidEdgeError(f'Fields are incompatible')
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid(
edge.destination.node_id, new_input=edge.source
):
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid(
edge.source.node_id, new_output=edge.destination
):
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid(
edge.destination.node_id, new_input=edge.source
):
raise InvalidEdgeError(f'Collector output type does not match collector input type')
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid(
edge.source.node_id, new_output=edge.destination
):
raise InvalidEdgeError(f'Collector input type does not match collector output type')
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
def has_node(self, node_path: str) -> bool:
@ -694,7 +722,11 @@ class Graph(BaseModel):
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
# Verify that all outputs are lists
if not all((get_origin(f) == list for f in output_fields)):
# if not all((get_origin(f) == list for f in output_fields)):
# return False
# Verify that all outputs are lists
if not all(is_list_or_contains_list(f) for f in output_fields):
return False
# Verify that all outputs match the input type (are a base class or the same class)
@ -713,6 +745,13 @@ class Graph(BaseModel):
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_with_data(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.items()])
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_flat(
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
) -> nx.DiGraph:
@ -818,11 +857,9 @@ class GraphExecutionState(BaseModel):
if next_node is None:
prepared_id = self._prepare()
# TODO: prepare multiple nodes at once?
# while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation):
# prepared_id = self._prepare()
if prepared_id is not None:
# Prepare as many nodes as we can
while prepared_id is not None:
prepared_id = self._prepare()
next_node = self._get_next_node()
# Get values from edges
@ -969,14 +1006,30 @@ class GraphExecutionState(BaseModel):
# Get flattened source graph
g = self.graph.nx_graph_flat()
# Find next unprepared node where all source nodes are executed
# Find next node that:
# - was not already prepared
# - is not an iterate node whose inputs have not been executed
# - does not have an unexecuted iterate ancestor
sorted_nodes = nx.topological_sort(g)
next_node_id = next(
(
n
for n in sorted_nodes
# exclude nodes that have already been prepared
if n not in self.source_prepared_mapping
and all((e[0] in self.executed for e in g.in_edges(n)))
# exclude iterate nodes whose inputs have not been executed
and not (
isinstance(self.graph.get_node(n), IterateInvocation) # `n` is an iterate node...
and not all((e[0] in self.executed for e in g.in_edges(n))) # ...that has unexecuted inputs
)
# exclude nodes who have unexecuted iterate ancestors
and not any(
(
isinstance(self.graph.get_node(a), IterateInvocation) # `a` is an iterate ancestor of `n`...
and a not in self.executed # ...that is not executed
for a in nx.ancestors(g, n) # for all ancestors `a` of node `n`
)
)
),
None,
)
@ -1073,9 +1126,22 @@ class GraphExecutionState(BaseModel):
)
def _get_next_node(self) -> Optional[BaseInvocation]:
"""Gets the deepest node that is ready to be executed"""
g = self.execution_graph.nx_graph()
sorted_nodes = nx.topological_sort(g)
next_node = next((n for n in sorted_nodes if n not in self.executed), None)
# Depth-first search with pre-order traversal is a depth-first topological sort
sorted_nodes = nx.dfs_preorder_nodes(g)
next_node = next(
(
n
for n in sorted_nodes
if n not in self.executed # the node must not already be executed...
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
),
None,
)
if next_node is None:
return None

View File

@ -0,0 +1,186 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict, Optional
from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin
from send2trash import send2trash
from invokeai.app.models.image import ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
# TODO: Should these excpetions subclass existing python exceptions?
class ImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Image file not found"):
super().__init__(message)
class ImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Image file not saved"):
super().__init__(message)
class ImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Image file not deleted"):
super().__init__(message)
class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the internal path to an image or thumbnail."""
pass
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
# 500 internal server error. I don't like having this method on the service.
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates the path given for an image or thumbnail."""
pass
@abstractmethod
def save(
self,
image: PILImageType,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
pass
@abstractmethod
def delete(self, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk"""
__output_folder: Path
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[Path, PILImageType]
__max_cache_size: int
def __init__(self, output_folder: str | Path):
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__thumbnails_folder = self.__output_folder / 'thumbnails'
# Validate required output folders at launch
self.__validate_storage_folders()
def get(self, image_name: str) -> PILImageType:
try:
image_path = self.get_path(image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
image = Image.open(image_path)
self.__set_cache(image_path, image)
return image
except FileNotFoundError as e:
raise ImageFileNotFoundException from e
def save(
self,
image: PILImageType,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
) -> None:
try:
self.__validate_storage_folders()
image_path = self.get_path(image_name)
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("invokeai", metadata.json())
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
except Exception as e:
raise ImageFileSaveException from e
def delete(self, image_name: str) -> None:
try:
image_path = self.get_path(image_name)
if image_path.exists():
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(thumbnail_name, True)
if thumbnail_path.exists():
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
except Exception as e:
raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
path = self.__output_folder / image_name
if thumbnail:
thumbnail_name = get_thumbnail_name(image_name)
path = self.__thumbnails_folder / thumbnail_name
return path
def validate_path(self, path: str | Path) -> bool:
"""Validates the path given for an image or thumbnail."""
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def __validate_storage_folders(self) -> None:
"""Checks if the required output folders exist and create them if they don't"""
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]
for folder in folders:
folder.mkdir(parents=True, exist_ok=True)
def __get_cache(self, image_name: Path) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: Path, image: PILImageType):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:
del self.__cache[cache_id]

View File

@ -0,0 +1,475 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Generic, Optional, TypeVar, cast
import sqlite3
import threading
from typing import Optional, Union
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.models.image import (
ImageCategory,
ResourceOrigin,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
ImageRecordChanges,
deserialize_image_record,
)
T = TypeVar("T", bound=BaseModel)
class OffsetPaginatedResults(GenericModel, Generic[T]):
"""Offset-paginated results"""
# fmt: off
items: list[T] = Field(description="Items")
offset: int = Field(description="Offset from which to retrieve items")
limit: int = Field(description="Limit of items to get")
total: int = Field(description="Total number of items in result")
# fmt: on
# TODO: Should these excpetions subclass existing python exceptions?
class ImageRecordNotFoundException(Exception):
"""Raised when an image record is not found."""
def __init__(self, message="Image record not found"):
super().__init__(message)
class ImageRecordSaveException(Exception):
"""Raised when an image record cannot be saved."""
def __init__(self, message="Image record not saved"):
super().__init__(message)
class ImageRecordDeleteException(Exception):
"""Raised when an image record cannot be deleted."""
def __init__(self, message="Image record not deleted"):
super().__init__(message)
class ImageRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the image record store."""
# TODO: Implement an `update()` method
@abstractmethod
def get(self, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def update(
self,
image_name: str,
changes: ImageRecordChanges,
) -> None:
"""Updates an image record."""
pass
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass
# TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod
def delete(self, image_name: str) -> None:
"""Deletes an image record."""
pass
@abstractmethod
def save(
self,
image_name: str,
image_origin: ResourceOrigin,
image_category: ImageCategory,
width: int,
height: int,
session_id: Optional[str],
node_id: Optional[str],
metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
) -> datetime:
"""Saves an image record."""
pass
@abstractmethod
def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None:
"""Gets the most recent image for a board."""
pass
class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `images` table."""
# Create the `images` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
session_id TEXT,
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
board_id TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
# Create the `images` table indices.
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
)
def get(self, image_name: str) -> Union[ImageRecord, None]:
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT * FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
if not result:
raise ImageRecordNotFoundException
return deserialize_image_record(dict(result))
def update(
self,
image_name: str,
changes: ImageRecordChanges,
) -> None:
try:
self._lock.acquire()
# Change the category of the image
if changes.image_category is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the image's `is_intermediate`` flag
if changes.is_intermediate is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_many(
self,
offset: int = 0,
limit: int = 10,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
images_query = """--sql
SELECT images.*
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = list(map(lambda c: c.value, set(categories)))
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
query_pagination = """--sql
ORDER BY images.created_at DESC LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
images_params.append(limit)
images_params.append(offset)
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
self._cursor.execute(count_query, count_params)
count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(
items=images, offset=offset, limit=limit, total=count
)
def delete(self, image_name: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
image_name: str,
image_origin: ResourceOrigin,
image_category: ImageCategory,
session_id: Optional[str],
width: int,
height: int,
node_id: Optional[str],
metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
) -> datetime:
try:
metadata_json = (
None if metadata is None else metadata.json(exclude_none=True)
)
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata_json,
is_intermediate,
),
)
self._conn.commit()
self._cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
return created_at
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_most_recent_image_for_board(
self, board_id: str
) -> Union[ImageRecord, None]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
ORDER BY images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
finally:
self._lock.release()
if result is None:
return None
return deserialize_image_record(dict(result))

View File

@ -1,274 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import os
from glob import glob
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict, List
from PIL.Image import Image
import PIL.Image as PILImage
from send2trash import send2trash
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
SavedImage,
)
from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import (
InvokeAIMetadata,
MetadataServiceBase,
build_invokeai_metadata_pnginfo,
)
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.misc import get_timestamp
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
class ImageStorageBase(ABC):
"""Responsible for storing and retrieving images."""
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> Image:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
"""Gets a paginated list of images."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
"""Gets the internal path to an image or its thumbnail."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_uri(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
"""Gets the external URI to an image or its thumbnail."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates an image path."""
pass
@abstractmethod
def save(
self,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> SavedImage:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
def create_name(self, context_id: str, node_id: str) -> str:
"""Creates a unique contextual image filename."""
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
class DiskImageStorage(ImageStorageBase):
"""Stores images on disk"""
__output_folder: str
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, Image]
__max_cache_size: int
__metadata_service: MetadataServiceBase
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
self.__output_folder = output_folder
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
self.__metadata_service = metadata_service
Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_type in ImageType:
Path(os.path.join(output_folder, image_type)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
dir_path = os.path.join(self.__output_folder, image_type)
image_paths = glob(f"{dir_path}/*.png")
count = len(image_paths)
sorted_image_paths = sorted(
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
)
page_of_image_paths = sorted_image_paths[
page * per_page : (page + 1) * per_page
]
page_of_images: List[ImageResponse] = []
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
invokeai_metadata = self.__metadata_service.get_metadata(img)
page_of_images.append(
ImageResponse(
image_type=image_type.value,
image_name=filename,
# TODO: DiskImageStorage should not be building URLs...?
image_url=self.get_uri(image_type, filename),
thumbnail_url=self.get_uri(image_type, filename, True),
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
metadata=ImageResponseMetadata(
created=int(os.path.getctime(path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
)
page_count_trunc = int(count / per_page)
page_count_mod = count % per_page
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
return PaginatedResults[ImageResponse](
items=page_of_images,
page=page,
pages=page_count,
per_page=per_page,
total=count,
)
def get(self, image_type: ImageType, image_name: str) -> Image:
image_path = self.get_path(image_type, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
image = PILImage.open(image_path)
self.__set_cache(image_path, image)
return image
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
path = os.path.join(
self.__output_folder, image_type, "thumbnails", basename
)
else:
path = os.path.join(self.__output_folder, image_type, basename)
abspath = os.path.abspath(path)
return abspath
def get_uri(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
thumbnail_basename = get_thumbnail_name(basename)
uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}"
else:
uri = f"api/v1/images/{image_type.value}/{basename}"
return uri
def validate_path(self, path: str) -> bool:
try:
os.stat(path)
return True
except Exception:
return False
def save(
self,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> SavedImage:
image_path = self.get_path(image_type, image_name)
# TODO: Reading the image and then saving it strips the metadata...
if metadata:
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path) # this saved image has an empty info
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
thumbnail_image = make_thumbnail(image)
thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
return SavedImage(
image_name=image_name,
thumbnail_name=thumbnail_name,
created=int(os.path.getctime(image_path)),
)
def delete(self, image_type: ImageType, image_name: str) -> None:
basename = os.path.basename(image_name)
image_path = self.get_path(image_type, basename)
if os.path.exists(image_path):
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
if os.path.exists(thumbnail_path):
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
def __get_cache(self, image_name: str) -> Image | None:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: Image):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:
del self.__cache[cache_id]

View File

@ -0,0 +1,354 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import Optional, TYPE_CHECKING, Union
from PIL.Image import Image as PILImageType
from invokeai.app.models.image import (
ImageCategory,
ResourceOrigin,
InvalidImageCategoryException,
InvalidOriginException,
)
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException,
ImageRecordNotFoundException,
ImageRecordSaveException,
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
ImageDTO,
ImageRecordChanges,
image_record_to_dto,
)
from invokeai.app.services.image_file_storage import (
ImageFileDeleteException,
ImageFileNotFoundException,
ImageFileSaveException,
ImageFileStorageBase,
)
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase
if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState
class ImageServiceABC(ABC):
"""High-level service for image management."""
@abstractmethod
def create(
self,
image: PILImageType,
image_origin: ResourceOrigin,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
is_intermediate: bool = False,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@abstractmethod
def update(
self,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
"""Updates an image."""
pass
@abstractmethod
def get_pil_image(self, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
pass
@abstractmethod
def get_record(self, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def get_dto(self, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's path."""
pass
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates an image's path."""
pass
@abstractmethod
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's or thumbnail's URL."""
pass
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass
@abstractmethod
def delete(self, image_name: str):
"""Deletes an image."""
pass
class ImageServiceDependencies:
"""Service dependencies for the ImageService."""
image_records: ImageRecordStorageBase
image_files: ImageFileStorageBase
board_image_records: BoardImageRecordStorageBase
metadata: MetadataServiceBase
urls: UrlServiceBase
logger: Logger
names: NameServiceBase
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
def __init__(
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
board_image_record_storage: BoardImageRecordStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self.image_records = image_record_storage
self.image_files = image_file_storage
self.board_image_records = board_image_record_storage
self.metadata = metadata
self.urls = url
self.logger = logger
self.names = names
self.graph_execution_manager = graph_execution_manager
class ImageService(ImageServiceABC):
_services: ImageServiceDependencies
def __init__(self, services: ImageServiceDependencies):
self._services = services
def create(
self,
image: PILImageType,
image_origin: ResourceOrigin,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
is_intermediate: bool = False,
) -> ImageDTO:
if image_origin not in ResourceOrigin:
raise InvalidOriginException
if image_category not in ImageCategory:
raise InvalidImageCategoryException
image_name = self._services.names.create_image_name()
metadata = self._get_metadata(session_id, node_id)
(width, height) = image.size
try:
# TODO: Consider using a transaction here to ensure consistency between storage and database
self._services.image_records.save(
# Non-nullable fields
image_name=image_name,
image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
# Meta fields
is_intermediate=is_intermediate,
# Nullable fields
node_id=node_id,
session_id=session_id,
metadata=metadata,
)
self._services.image_files.save(
image_name=image_name,
image=image,
metadata=metadata,
)
image_dto = self.get_dto(image_name)
return image_dto
except ImageRecordSaveException:
self._services.logger.error("Failed to save image record")
raise
except ImageFileSaveException:
self._services.logger.error("Failed to save image file")
raise
except Exception as e:
self._services.logger.error("Problem saving image record and file")
raise e
def update(
self,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
try:
self._services.image_records.update(image_name, changes)
return self.get_dto(image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
except Exception as e:
self._services.logger.error("Problem updating image record")
raise e
def get_pil_image(self, image_name: str) -> PILImageType:
try:
return self._services.image_files.get(image_name)
except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
except Exception as e:
self._services.logger.error("Problem getting image file")
raise e
def get_record(self, image_name: str) -> ImageRecord:
try:
return self._services.image_records.get(image_name)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image record")
raise e
def get_dto(self, image_name: str) -> ImageDTO:
try:
image_record = self._services.image_records.get(image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_name),
self._services.urls.get_image_url(image_name, True),
self._services.board_image_records.get_board_for_image(image_name),
)
return image_dto
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image DTO")
raise e
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
return self._services.image_files.get_path(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def validate_path(self, path: str) -> bool:
try:
return self._services.image_files.validate_path(path)
except Exception as e:
self._services.logger.error("Problem validating image path")
raise e
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
try:
return self._services.urls.get_image_url(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_many(
self,
offset: int = 0,
limit: int = 10,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self._services.image_records.get_many(
offset,
limit,
image_origin,
categories,
is_intermediate,
board_id,
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
self._services.board_image_records.get_board_for_image(
r.image_name
),
),
results.items,
)
)
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,
offset=results.offset,
limit=results.limit,
total=results.total,
)
except Exception as e:
self._services.logger.error("Problem getting paginated image DTOs")
raise e
def delete(self, image_name: str):
try:
self._services.image_files.delete(image_name)
self._services.image_records.delete(image_name)
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise
except ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image file")
raise
except Exception as e:
self._services.logger.error("Problem deleting image record and file")
raise e
def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Union[ImageMetadata, None]:
"""Get the metadata for a node."""
metadata = None
if node_id is not None and session_id is not None:
session = self._services.graph_execution_manager.get(session_id)
metadata = self._services.metadata.create_image_metadata(session, node_id)
return metadata

View File

@ -1,58 +1,67 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import types
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.backend import ModelManager
if TYPE_CHECKING:
from logging import Logger
from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC
from invokeai.backend import ModelManager
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.restoration_services import RestorationServices
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.config import InvokeAISettings
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.invoker import InvocationProcessorABC
from .events import EventServiceBase
from .latent_storage import LatentsStorageBase
from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from .config import InvokeAISettings
class InvocationServices:
"""Services that can be used by invocations"""
events: EventServiceBase
latents: LatentsStorageBase
images: ImageStorageBase
metadata: MetadataServiceBase
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
configuration: InvokeAISettings
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"]
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
events: "EventServiceBase"
latents: "LatentsStorageBase"
queue: "InvocationQueueABC"
model_manager: "ModelManager"
restoration: "RestorationServices"
configuration: "InvokeAISettings"
images: "ImageServiceABC"
boards: "BoardServiceABC"
board_images: "BoardImagesServiceABC"
graph_library: "ItemStorageABC"["LibraryGraph"]
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
processor: "InvocationProcessorABC"
def __init__(
self,
model_manager: ModelManager,
events: EventServiceBase,
logger: types.ModuleType,
latents: LatentsStorageBase,
images: ImageStorageBase,
metadata: MetadataServiceBase,
queue: InvocationQueueABC,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
configuration: InvokeAISettings=None,
self,
model_manager: "ModelManager",
events: "EventServiceBase",
logger: "Logger",
latents: "LatentsStorageBase",
images: "ImageServiceABC",
boards: "BoardServiceABC",
board_images: "BoardImagesServiceABC",
queue: "InvocationQueueABC",
graph_library: "ItemStorageABC"["LibraryGraph"],
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: "RestorationServices",
configuration: "InvokeAISettings",
):
self.model_manager = model_manager
self.events = events
self.logger = logger
self.latents = latents
self.images = images
self.metadata = metadata
self.boards = boards
self.board_images = board_images
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration
self.configuration = configuration
self.boards = boards

View File

@ -22,7 +22,8 @@ class Invoker:
def invoke(
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> str | None:
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
"""Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
# Get the next invocation
invocation = graph_execution_state.next()

View File

@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
@ -16,7 +15,7 @@ class LatentsStorageBase(ABC):
pass
@abstractmethod
def set(self, name: str, data: torch.Tensor) -> None:
def save(self, name: str, data: torch.Tensor) -> None:
pass
@abstractmethod
@ -47,8 +46,8 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
self.__set_cache(name, latent)
return latent
def set(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.set(name, data)
def save(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.save(name, data)
self.__set_cache(name, data)
def delete(self, name: str) -> None:
@ -70,24 +69,26 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching"""
__output_folder: str
__output_folder: str | Path
def __init__(self, output_folder: str):
self.__output_folder = output_folder
Path(output_folder).mkdir(parents=True, exist_ok=True)
def __init__(self, output_folder: str | Path):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True)
def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)
def set(self, name: str, data: torch.Tensor) -> None:
def save(self, name: str, data: torch.Tensor) -> None:
self.__output_folder.mkdir(parents=True, exist_ok=True)
latent_path = self.get_path(name)
torch.save(data, latent_path)
def delete(self, name: str) -> None:
latent_path = self.get_path(name)
os.remove(latent_path)
latent_path.unlink()
def get_path(self, name: str) -> str:
return os.path.join(self.__output_folder, name)
def get_path(self, name: str) -> Path:
return self.__output_folder / name

View File

@ -1,105 +1,142 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, TypedDict
from PIL import Image, PngImagePlugin
from pydantic import BaseModel
from typing import Any, Union
import networkx as nx
from invokeai.app.models.image import ImageType, is_image_type
class MetadataImageField(TypedDict):
"""Pydantic-less ImageField, used for metadata parsing."""
image_type: ImageType
image_name: str
class MetadataLatentsField(TypedDict):
"""Pydantic-less LatentsField, used for metadata parsing."""
latents_name: str
class MetadataColorField(TypedDict):
"""Pydantic-less ColorField, used for metadata parsing"""
r: int
g: int
b: int
a: int
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField
]
class InvokeAIMetadata(TypedDict, total=False):
"""InvokeAI-specific metadata format."""
session_id: Optional[str]
node: Optional[NodeMetadata]
def build_invokeai_metadata_pnginfo(
metadata: InvokeAIMetadata | None,
) -> PngImagePlugin.PngInfo:
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None:
pnginfo.add_text("invokeai", json.dumps(metadata))
return pnginfo
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.graph import Graph, GraphExecutionState
class MetadataServiceBase(ABC):
@abstractmethod
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
pass
"""Handles building metadata for nodes, images, and outputs."""
@abstractmethod
def build_metadata(
self, session_id: str, node: BaseModel
) -> InvokeAIMetadata | None:
"""Builds an InvokeAIMetadata object"""
def create_image_metadata(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""Builds an ImageMetadata object for a node."""
pass
class PngMetadataService(MetadataServiceBase):
"""Handles loading and building metadata for images."""
class CoreMetadataService(MetadataServiceBase):
_ANCESTOR_TYPES = ["t2l", "l2l"]
"""The ancestor types that contain the core metadata"""
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
def _load_metadata(self, image: Image.Image) -> dict | None:
"""Loads a specific info entry from a PIL Image."""
_ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
"""The core metadata parameters in the ancestor types"""
try:
info = image.info.get("invokeai")
_NOISE_FIELDS = ["seed", "width", "height"]
"""The core metadata parameters in the noise node"""
if type(info) is not str:
return None
loaded_metadata = json.loads(info)
if type(loaded_metadata) is not dict:
return None
if len(loaded_metadata.items()) == 0:
return None
return loaded_metadata
except:
return None
def get_metadata(self, image: Image.Image) -> dict | None:
"""Retrieves an image's metadata as a dict"""
loaded_metadata = self._load_metadata(image)
return loaded_metadata
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
def create_image_metadata(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
metadata = self._build_metadata_from_graph(session, node_id)
return metadata
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]:
"""
Finds the id of the nearest ancestor (of a valid type) of a given node.
Parameters:
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
have the same data as the execution graph.
node_id (str): The ID of the node.
Returns:
str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
"""
# Retrieve the node from the graph
node = G.nodes[node_id]
# If the node type is one of the core metadata node types, return its id
if node.get("type") in self._ANCESTOR_TYPES:
return node.get("id")
# Else, look for the ancestor in the predecessor nodes
for predecessor in G.predecessors(node_id):
result = self._find_nearest_ancestor(G, predecessor)
if result:
return result
# If there are no valid ancestors, return None
return None
def _get_additional_metadata(
self, graph: Graph, node_id: str
) -> Union[dict[str, Any], None]:
"""
Returns additional metadata for a given node.
Parameters:
graph (Graph): The execution graph.
node_id (str): The ID of the node.
Returns:
dict[str, Any] | None: A dictionary of additional metadata.
"""
metadata = {}
# Iterate over all edges in the graph
for edge in graph.edges:
dest_node_id = edge.destination.node_id
dest_field = edge.destination.field
source_node_dict = graph.nodes[edge.source.node_id].dict()
# If the destination node ID matches the given node ID, gather necessary metadata
if dest_node_id == node_id:
# Prompt
if dest_field == "positive_conditioning":
metadata["positive_conditioning"] = source_node_dict.get("prompt")
# Negative prompt
if dest_field == "negative_conditioning":
metadata["negative_conditioning"] = source_node_dict.get("prompt")
# Seed, width and height
if dest_field == "noise":
for field in self._NOISE_FIELDS:
metadata[field] = source_node_dict.get(field)
return metadata
def _build_metadata_from_graph(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""
Builds an ImageMetadata object for a node.
Parameters:
session (GraphExecutionState): The session.
node_id (str): The ID of the node.
Returns:
ImageMetadata: The metadata for the node.
"""
# We need to do all the traversal on the execution graph
graph = session.execution_graph
# Find the nearest `t2l`/`l2l` ancestor of the given node
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
# If no ancestor was found, return an empty ImageMetadata object
if ancestor_id is None:
return ImageMetadata()
ancestor_node = graph.get_node(ancestor_id)
# Grab all the core metadata from the ancestor node
ancestor_metadata = {
param: val
for param, val in ancestor_node.dict().items()
if param in self._ANCESTOR_PARAMS
}
# Get this image's prompts and noise parameters
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
# If additional metadata was found, add it to the main metadata
if addl_metadata is not None:
ancestor_metadata.update(addl_metadata)
return ImageMetadata(**ancestor_metadata)

View File

@ -1,104 +0,0 @@
import os
import sys
import torch
from argparse import Namespace
from omegaconf import OmegaConf
from pathlib import Path
from typing import types
import invokeai.version
from .config import InvokeAISettings
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device
# TODO: Replace with an abstract class base ModelManagerBase
def get_model_manager(config: InvokeAISettings, logger: types.ModuleType) -> ModelManager:
model_config = config.model_conf_path
if not model_config.exists():
report_model_error(
config, FileNotFoundError(f"The file {model_config} could not be found."), logger
)
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
logger.info(f'InvokeAI runtime directory is "{config.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers # type: ignore
transformers.logging.set_verbosity_error()
import diffusers
diffusers.logging.set_verbosity_error()
embedding_path = config.embedding_path
# migrate legacy models
ModelManager.migrate_models()
# creating the model manager
try:
device = torch.device(choose_torch_device())
precision = 'float16' if config.precision=='float16' \
else 'float32' if config.precision=='float32' \
else choose_precision(device)
model_manager = ModelManager(
OmegaConf.load(config.model_conf_path),
precision=precision,
device_type=device,
max_loaded_models=config.max_loaded_models,
embedding_path = embedding_path,
logger = logger,
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(config, e, logger)
except (IOError, KeyError) as e:
logger.error(f"{e}. Aborting.")
sys.exit(-1)
# try to autoconvert new models
# autoimport new .ckpt files
if config.autoconvert_path:
model_manager.heuristic_import(
config.autoconvert_path,
)
return model_manager
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
logger.error(
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
)
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
if yes_to_all:
logger.warning(
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
)
else:
response = input(
"Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
)
if response.startswith(("n", "N")):
return
logger.info("invokeai-configure is launching....\n")
# Match arguments that were set on the CLI
# only the arguments accepted by the configuration script are parsed
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
config = ["--config", opt.conf] if opt.conf is not None else []
sys.argv = ["invokeai-configure"]
sys.argv.extend(root_dir)
sys.argv.extend(config.to_dict())
if yes_to_all is not None:
for arg in yes_to_all.split():
sys.argv.append(arg)
from invokeai.frontend.install import invokeai_configure
invokeai_configure()
# TODO: Figure out how to restart
# print('** InvokeAI will now restart')
# sys.argv = previous_args
# main() # would rather do a os.exec(), but doesn't exist?
# sys.exit(0)

View File

@ -0,0 +1,363 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from __future__ import annotations
import torch
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
from dataclasses import dataclass
from invokeai.backend.model_management.model_manager import (
ModelManager,
BaseModelType,
ModelType,
SubModelType,
ModelInfo,
)
from invokeai.app.models.exceptions import CanceledException
from .config import InvokeAIAppConfig
from ...backend.util import choose_precision, choose_torch_device
if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
class ModelManagerServiceBase(ABC):
"""Responsible for managing models on disk and in memory"""
@abstractmethod
def __init__(
self,
config: InvokeAIAppConfig,
logger: types.ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
pass
@abstractmethod
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae)
of a diffusers pipeline."""
pass
@property
@abstractmethod
def logger(self):
pass
@abstractmethod
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
pass
@abstractmethod
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
"""
pass
@abstractmethod
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
pass
@abstractmethod
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
"""
Return a dict of models in the format:
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
"""
pass
@abstractmethod
def add_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False
) -> None:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass
@abstractmethod
def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
"""
pass
@abstractmethod
def commit(self, conf_file: Path = None) -> None:
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
original file/database used to initialize the object.
"""
pass
# simple implementation
class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory"""
def __init__(
self,
config: InvokeAIAppConfig,
logger: types.ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
if config.model_conf_path and config.model_conf_path.exists():
config_file = config.model_conf_path
else:
config_file = config.root_dir / "configs/models.yaml"
if not config_file.exists():
raise IOError(f"The file {config_file} could not be found.")
logger.debug(f'config file={config_file}')
device = torch.device(choose_torch_device())
precision = config.precision
if precision == "auto":
precision = choose_precision(device)
dtype = torch.float32 if precision == 'float32' else torch.float16
# this is transitional backward compatibility
# support for the deprecated `max_loaded_models`
# configuration value. If present, then the
# cache size is set to 2.5 GB times
# the number of max_loaded_models. Otherwise
# use new `max_cache_size` config setting
max_cache_size = config.max_cache_size \
if hasattr(config,'max_cache_size') \
else config.max_loaded_models * 2.5
sequential_offload = config.sequential_guidance
self.mgr = ModelManager(
config=config_file,
device_type=device,
precision=dtype,
max_cache_size=max_cache_size,
sequential_offload=sequential_offload,
logger=logger,
)
logger.info('Model manager service initialized')
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""
Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers mode.
"""
# if we are called from within a node, then we get to emit
# load start and complete events
if node and context:
self._emit_load_event(
node=node,
context=context,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)
model_info = self.mgr.get_model(
model_name,
base_model,
model_type,
submodel,
)
if node and context:
self._emit_load_event(
node=node,
context=context,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info
)
return model_info
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
"""
Given a model name, returns True if it is a valid
identifier.
"""
return self.mgr.model_exists(
model_name,
base_model,
model_type,
)
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
"""
return self.mgr.model_info(model_name, base_model, model_type)
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
return self.mgr.model_names()
def list_models(
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None
) -> list[dict]:
# ) -> dict:
"""
Return a list of models.
"""
return self.mgr.list_models(base_model, model_type)
def add_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
)->None:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
"""
self.mgr.del_model(model_name, base_model, model_type)
def commit(self, conf_file: Optional[Path]=None):
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
original file/database used to initialize the object.
"""
return self.mgr.commit(conf_file)
def _emit_load_event(
self,
node,
context,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
model_info: Optional[ModelInfo] = None,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[node.id]
if model_info:
context.services.events.emit_model_load_completed(
graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info
)
else:
context.services.events.emit_model_load_started(
graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)
@property
def logger(self):
return self.mgr.logger

View File

@ -0,0 +1,62 @@
from typing import Optional, Union
from datetime import datetime
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.util.misc import get_iso_timestamp
class BoardRecord(BaseModel):
"""Deserialized board record."""
board_id: str = Field(description="The unique ID of the board.")
"""The unique ID of the board."""
board_name: str = Field(description="The name of the board.")
"""The name of the board."""
created_at: Union[datetime, str] = Field(
description="The created timestamp of the board."
)
"""The created timestamp of the image."""
updated_at: Union[datetime, str] = Field(
description="The updated timestamp of the board."
)
"""The updated timestamp of the image."""
deleted_at: Union[datetime, str, None] = Field(
description="The deleted timestamp of the board."
)
"""The updated timestamp of the image."""
cover_image_name: Optional[str] = Field(
description="The name of the cover image of the board."
)
"""The name of the cover image of the board."""
class BoardDTO(BoardRecord):
"""Deserialized board record with cover image URL and image count."""
cover_image_name: Optional[str] = Field(
description="The name of the board's cover image."
)
"""The URL of the thumbnail of the most recent image in the board."""
image_count: int = Field(description="The number of images in the board.")
"""The number of images in the board."""
def deserialize_board_record(board_dict: dict) -> BoardRecord:
"""Deserializes a board record."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
board_id = board_dict.get("board_id", "unknown")
board_name = board_dict.get("board_name", "unknown")
cover_image_name = board_dict.get("cover_image_name", "unknown")
created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
return BoardRecord(
board_id=board_id,
board_name=board_name,
cover_image_name=cover_image_name,
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
)

View File

@ -0,0 +1,151 @@
import datetime
from typing import Optional, Union
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp
class ImageRecord(BaseModel):
"""Deserialized image record."""
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The origin of the image."""
image_category: ImageCategory = Field(description="The category of the image.")
"""The category of the image."""
width: int = Field(description="The width of the image in px.")
"""The actual width of the image in px. This may be different from the width in metadata."""
height: int = Field(description="The height of the image in px.")
"""The actual height of the image in px. This may be different from the height in metadata."""
created_at: Union[datetime.datetime, str] = Field(
description="The created timestamp of the image."
)
"""The created timestamp of the image."""
updated_at: Union[datetime.datetime, str] = Field(
description="The updated timestamp of the image."
)
"""The updated timestamp of the image."""
deleted_at: Union[datetime.datetime, str, None] = Field(
description="The deleted timestamp of the image."
)
"""The deleted timestamp of the image."""
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
"""Whether this is an intermediate image."""
session_id: Optional[str] = Field(
default=None,
description="The session ID that generated this image, if it is a generated image.",
)
"""The session ID that generated this image, if it is a generated image."""
node_id: Optional[str] = Field(
default=None,
description="The node ID that generated this image, if it is a generated image.",
)
"""The node ID that generated this image, if it is a generated image."""
metadata: Optional[ImageMetadata] = Field(
default=None,
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
)
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
"""A set of changes to apply to an image record.
Only limited changes are valid:
- `image_category`: change the category of an image
- `session_id`: change the session associated with an image
- `is_intermediate`: change the image's `is_intermediate` flag
"""
image_category: Optional[ImageCategory] = Field(
description="The image's new category."
)
"""The image's new category."""
session_id: Optional[StrictStr] = Field(
default=None,
description="The image's new session ID.",
)
"""The image's new session ID."""
is_intermediate: Optional[StrictBool] = Field(
default=None, description="The image's new `is_intermediate` flag."
)
"""The image's new `is_intermediate` flag."""
class ImageUrlsDTO(BaseModel):
"""The URLs for an image and its thumbnail."""
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_url: str = Field(description="The URL of the image.")
"""The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
"""The URL of the image's thumbnail."""
class ImageDTO(ImageRecord, ImageUrlsDTO):
"""Deserialized image record, enriched for the frontend."""
board_id: Union[str, None] = Field(
description="The id of the board the image belongs to, if one exists."
)
"""The id of the board the image belongs to, if one exists."""
pass
def image_record_to_dto(
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None]
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
**image_record.dict(),
image_url=image_url,
thumbnail_url=thumbnail_url,
board_id=board_id,
)
def deserialize_image_record(image_dict: dict) -> ImageRecord:
"""Deserializes an image record."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
image_name = image_dict.get("image_name", "unknown")
image_origin = ResourceOrigin(
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
)
image_category = ImageCategory(
image_dict.get("image_category", ImageCategory.GENERAL.value)
)
width = image_dict.get("width", 0)
height = image_dict.get("height", 0)
session_id = image_dict.get("session_id", None)
node_id = image_dict.get("node_id", None)
created_at = image_dict.get("created_at", get_iso_timestamp())
updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False)
raw_metadata = image_dict.get("metadata")
if raw_metadata is not None:
metadata = ImageMetadata.parse_raw(raw_metadata)
else:
metadata = None
return ImageRecord(
image_name=image_name,
image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
session_id=session_id,
node_id=node_id,
metadata=metadata,
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
is_intermediate=is_intermediate,
)

View File

@ -0,0 +1,30 @@
from abc import ABC, abstractmethod
from enum import Enum, EnumMeta
import uuid
class ResourceType(str, Enum, metaclass=EnumMeta):
"""Enum for resource types."""
IMAGE = "image"
LATENT = "latent"
class NameServiceBase(ABC):
"""Low-level service responsible for naming resources (images, latents, etc)."""
# TODO: Add customizable naming schemes
@abstractmethod
def create_image_name(self) -> str:
"""Creates a name for an image."""
pass
class SimpleNameService(NameServiceBase):
"""Creates image names from UUIDs."""
# TODO: Add customizable naming schemes
def create_image_name(self) -> str:
uuid_str = str(uuid.uuid4())
filename = f"{uuid_str}.png"
return filename

View File

@ -16,13 +16,14 @@ class RestorationServices:
gfpgan, codeformer, esrgan = None, None, None
if args.restore or args.esrgan:
restoration = Restoration()
if args.restore:
# TODO: redo for new model structure
if False and args.restore:
gfpgan, codeformer = restoration.load_face_restore_models(
args.gfpgan_model_path
)
else:
logger.info("Face restoration disabled")
if args.esrgan:
if False and args.esrgan:
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
else:
logger.info("Upscaling disabled")

View File

@ -26,7 +26,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution

View File

@ -0,0 +1,25 @@
import os
from abc import ABC, abstractmethod
class UrlServiceBase(ABC):
"""Responsible for building URLs for resources."""
@abstractmethod
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the URL for an image or thumbnail."""
pass
class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"):
self._base_url = base_url
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
image_basename = os.path.basename(image_name)
# These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail:
return f"{self._base_url}/images/{image_basename}/thumbnail"
return f"{self._base_url}/images/{image_basename}"

View File

@ -0,0 +1,15 @@
from enum import EnumMeta
class MetaEnum(EnumMeta):
"""Metaclass to support additional features in Enums.
- `in` operator support: `'value' in MyEnum -> bool`
"""
def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
return True

View File

@ -6,6 +6,14 @@ def get_timestamp():
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
def get_iso_timestamp() -> str:
return datetime.datetime.utcnow().isoformat()
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
return datetime.datetime.fromisoformat(iso_timestamp)
SEED_MAX = np.iinfo(np.int32).max

View File

@ -1,5 +1,5 @@
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.models.image import ProgressImage
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator

View File

@ -5,9 +5,11 @@ from .generator import (
InvokeAIGeneratorBasicParams,
InvokeAIGenerator,
InvokeAIGeneratorOutput,
Txt2Img,
Img2Img,
Inpaint
)
from .model_management import ModelManager, SDModelComponent
from .model_management import (
ModelManager, ModelCache, BaseModelType,
ModelType, SubModelType, ModelInfo
)
from .safety_checker import SafetyChecker

View File

@ -5,7 +5,6 @@ from .base import (
InvokeAIGenerator,
InvokeAIGeneratorBasicParams,
InvokeAIGeneratorOutput,
Txt2Img,
Img2Img,
Inpaint,
Generator,

View File

@ -29,7 +29,6 @@ import invokeai.backend.util.logging as logger
from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..stable_diffusion.schedulers import SCHEDULER_MAP
@ -75,17 +74,21 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def __init__(self,
model_info: dict,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
**kwargs,
):
self.model_info=model_info
self.params=params
self.kwargs = kwargs
def generate(self,
prompt: str='',
callback: Optional[Callable]=None,
step_callback: Optional[Callable]=None,
iterations: int=1,
**keyword_args,
)->Iterator[InvokeAIGeneratorOutput]:
def generate(
self,
conditioning: tuple,
scheduler,
callback: Optional[Callable]=None,
step_callback: Optional[Callable]=None,
iterations: int=1,
**keyword_args,
)->Iterator[InvokeAIGeneratorOutput]:
'''
Return an iterator across the indicated number of generations.
Each time the iterator is called it will return an InvokeAIGeneratorOutput
@ -111,51 +114,46 @@ class InvokeAIGenerator(metaclass=ABCMeta):
generator_args.update(keyword_args)
model_info = self.model_info
model_name = model_info['model_name']
model:StableDiffusionGeneratorPipeline = model_info['model']
model_hash = model_info['hash']
scheduler: Scheduler = self.get_scheduler(
model=model,
scheduler_name=generator_args.get('scheduler')
)
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
gen_class = self._generator_class()
generator = gen_class(model, self.params.precision)
if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'),
generator_args.get('with_variations')
)
model_name = model_info.name
model_hash = model_info.hash
with model_info.context as model:
gen_class = self._generator_class()
generator = gen_class(model, self.params.precision, **self.kwargs)
if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'),
generator_args.get('with_variations')
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
else:
configure_model_padding(model,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
else:
configure_model_padding(model,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
for i in iteration_count:
results = generator.generate(prompt,
conditioning=(uc, c, extra_conditioning_info),
step_callback=step_callback,
sampler=scheduler,
**generator_args,
)
output = InvokeAIGeneratorOutput(
image=results[0][0],
seed=results[0][1],
attention_maps_images=results[0][2],
model_hash = model_hash,
params=Namespace(model_name=model_name,**generator_args),
)
if callback:
callback(output)
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
for i in iteration_count:
results = generator.generate(
conditioning=conditioning,
step_callback=step_callback,
sampler=scheduler,
**generator_args,
)
output = InvokeAIGeneratorOutput(
image=results[0][0],
seed=results[0][1],
attention_maps_images=results[0][2],
model_hash = model_hash,
params=Namespace(model_name=model_name,**generator_args),
)
if callback:
callback(output)
yield output
@classmethod
@ -168,20 +166,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
return generator_class(model, self.params.precision)
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_config = model.scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
return scheduler
@classmethod
def _generator_class(cls)->Type[Generator]:
'''
@ -191,13 +175,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
'''
return Generator
# ------------------------------------
class Txt2Img(InvokeAIGenerator):
@classmethod
def _generator_class(cls):
from .txt2img import Txt2Img
return Txt2Img
# ------------------------------------
class Img2Img(InvokeAIGenerator):
def generate(self,
@ -251,36 +228,17 @@ class Inpaint(Img2Img):
from .inpaint import Inpaint
return Inpaint
# ------------------------------------
class Embiggen(Txt2Img):
def generate(
self,
embiggen: list=None,
embiggen_tiles: list = None,
strength: float=0.75,
**kwargs)->Iterator[InvokeAIGeneratorOutput]:
return super().generate(embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
strength=strength,
**kwargs)
@classmethod
def _generator_class(cls):
from .embiggen import Embiggen
return Embiggen
class Generator:
downsampling_factor: int
latent_channels: int
precision: str
model: DiffusionPipeline
def __init__(self, model: DiffusionPipeline, precision: str):
def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
self.model = model
self.precision = precision
self.seed = None
self.latent_channels = model.channels
self.latent_channels = model.unet.config.in_channels
self.downsampling_factor = downsampling # BUG: should come from model or config
self.safety_checker = None
self.perlin = 0.0
@ -291,7 +249,7 @@ class Generator:
self.free_gpu_mem = None
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self, prompt, **kwargs):
def get_make_image(self, **kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
@ -307,7 +265,6 @@ class Generator:
def generate(
self,
prompt,
width,
height,
sampler,
@ -332,7 +289,6 @@ class Generator:
saver.get_stacked_maps_image()
)
make_image = self.get_make_image(
prompt,
sampler=sampler,
init_image=init_image,
width=width,

View File

@ -1,559 +0,0 @@
"""
invokeai.backend.generator.embiggen descends from .generator
and generates with .generator.img2img
"""
import numpy as np
import torch
from PIL import Image
from tqdm import trange
import invokeai.backend.util.logging as logger
from .base import Generator
from .img2img import Img2Img
class Embiggen(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
self.init_latent = None
# Replace generate because Embiggen doesn't need/use most of what it does normallly
def generate(
self,
prompt,
iterations=1,
seed=None,
image_callback=None,
step_callback=None,
**kwargs,
):
make_image = self.get_make_image(prompt, step_callback=step_callback, **kwargs)
results = []
seed = seed if seed else self.new_seed()
# Noise will be generated by the Img2Img generator when called
for _ in trange(iterations, desc="Generating"):
# make_image will call Img2Img which will do the equivalent of get_noise itself
image = make_image()
results.append([image, seed])
if image_callback is not None:
image_callback(image, seed, prompt_in=prompt)
seed = self.new_seed()
return results
@torch.no_grad()
def get_make_image(
self,
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
init_img,
strength,
width,
height,
embiggen,
embiggen_tiles,
step_callback=None,
**kwargs,
):
"""
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
Return value depends on the seed at the time you call it
"""
assert (
not sampler.uses_inpainting_model()
), "--embiggen is not supported by inpainting models"
# Construct embiggen arg array, and sanity check arguments
if embiggen == None: # embiggen can also be called with just embiggen_tiles
embiggen = [1.0] # If not specified, assume no scaling
elif embiggen[0] < 0:
embiggen[0] = 1.0
logger.warning(
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
)
if len(embiggen) < 2:
embiggen.append(0.75)
elif embiggen[1] > 1.0 or embiggen[1] < 0:
embiggen[1] = 0.75
logger.warning(
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
)
if len(embiggen) < 3:
embiggen.append(0.25)
elif embiggen[2] < 0:
embiggen[2] = 0.25
logger.warning(
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
)
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
# and then sort them, because... people.
if embiggen_tiles:
embiggen_tiles = list(map(lambda n: n - 1, embiggen_tiles))
embiggen_tiles.sort()
if strength >= 0.5:
logger.warning(
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
)
# Prep img2img generator, since we wrap over it
gen_img2img = Img2Img(self.model, self.precision)
# Open original init image (not a tensor) to manipulate
initsuperimage = Image.open(init_img)
with Image.open(init_img) as img:
initsuperimage = img.convert("RGB")
# Size of the target super init image in pixels
initsuperwidth, initsuperheight = initsuperimage.size
# Increase by scaling factor if not already resized, using ESRGAN as able
if embiggen[0] != 1.0:
initsuperwidth = round(initsuperwidth * embiggen[0])
initsuperheight = round(initsuperheight * embiggen[0])
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
from ..restoration.realesrgan import ESRGAN
esrgan = ESRGAN()
logger.info(
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
)
if embiggen[0] > 2:
initsuperimage = esrgan.process(
initsuperimage,
embiggen[1], # upscale strength
self.seed,
4, # upscale scale
)
else:
initsuperimage = esrgan.process(
initsuperimage,
embiggen[1], # upscale strength
self.seed,
2, # upscale scale
)
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
# but from personal experiance it doesn't greatly improve anything after 4x
# Resize to target scaling factor resolution
initsuperimage = initsuperimage.resize(
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS
)
# Use width and height as tile widths and height
# Determine buffer size in pixels
if embiggen[2] < 1:
if embiggen[2] < 0:
embiggen[2] = 0
overlap_size_x = round(embiggen[2] * width)
overlap_size_y = round(embiggen[2] * height)
else:
overlap_size_x = round(embiggen[2])
overlap_size_y = round(embiggen[2])
# With overall image width and height known, determine how many tiles we need
def ceildiv(a, b):
return -1 * (-a // b)
# X and Y needs to be determined independantly (we may have savings on one based on the buffer pixel count)
# (initsuperwidth - width) is the area remaining to the right that we need to layers tiles to fill
# (width - overlap_size_x) is how much new we can fill with a single tile
emb_tiles_x = 1
emb_tiles_y = 1
if (initsuperwidth - width) > 0:
emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1
if (initsuperheight - height) > 0:
emb_tiles_y = ceildiv(initsuperheight - height, height - overlap_size_y) + 1
# Sanity
assert (
emb_tiles_x > 1 or emb_tiles_y > 1
), f"ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don't need to Embiggen! Check your arguments."
# Prep alpha layers --------------
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
# agradientL is Left-side transparent
agradientL = (
Image.linear_gradient("L").rotate(90).resize((overlap_size_x, height))
)
# agradientT is Top-side transparent
agradientT = Image.linear_gradient("L").resize((width, overlap_size_y))
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
agradientC = Image.new("L", (256, 256))
for y in range(256):
for x in range(256):
# Find distance to lower right corner (numpy takes arrays)
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
# Clamp values to max 255
if distanceToLR > 255:
distanceToLR = 255
# Place the pixel as invert of distance
agradientC.putpixel((x, y), round(255 - distanceToLR))
# Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges
# Fits for a left-fading gradient on the bottom side and full opacity on the right side.
agradientAsymC = Image.new("L", (256, 256))
for y in range(256):
for x in range(256):
value = round(max(0, x - (255 - y)) * (255 / max(1, y)))
# Clamp values
value = max(0, value)
value = min(255, value)
agradientAsymC.putpixel((x, y), value)
# Create alpha layers default fully white
alphaLayerL = Image.new("L", (width, height), 255)
alphaLayerT = Image.new("L", (width, height), 255)
alphaLayerLTC = Image.new("L", (width, height), 255)
# Paste gradients into alpha layers
alphaLayerL.paste(agradientL, (0, 0))
alphaLayerT.paste(agradientT, (0, 0))
alphaLayerLTC.paste(agradientL, (0, 0))
alphaLayerLTC.paste(agradientT, (0, 0))
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
# make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
# to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
alphaLayerTaC = alphaLayerT.copy()
alphaLayerTaC.paste(
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, 0),
)
alphaLayerLTaC = alphaLayerLTC.copy()
alphaLayerLTaC.paste(
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, 0),
)
if embiggen_tiles:
# Individual unconnected sides
alphaLayerR = Image.new("L", (width, height), 255)
alphaLayerR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerB = Image.new("L", (width, height), 255)
alphaLayerB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerTB = Image.new("L", (width, height), 255)
alphaLayerTB.paste(agradientT, (0, 0))
alphaLayerTB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerLR = Image.new("L", (width, height), 255)
alphaLayerLR.paste(agradientL, (0, 0))
alphaLayerLR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
# Sides and corner Layers
alphaLayerRBC = Image.new("L", (width, height), 255)
alphaLayerRBC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerRBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerRBC.paste(
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, height - overlap_size_y),
)
alphaLayerLBC = Image.new("L", (width, height), 255)
alphaLayerLBC.paste(agradientL, (0, 0))
alphaLayerLBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerLBC.paste(
agradientC.rotate(90).resize((overlap_size_x, overlap_size_y)),
(0, height - overlap_size_y),
)
alphaLayerRTC = Image.new("L", (width, height), 255)
alphaLayerRTC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerRTC.paste(agradientT, (0, 0))
alphaLayerRTC.paste(
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, 0),
)
# All but X layers
alphaLayerABT = Image.new("L", (width, height), 255)
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
alphaLayerABT.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerABT.paste(
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, height - overlap_size_y),
)
alphaLayerABL = Image.new("L", (width, height), 255)
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
alphaLayerABL.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerABL.paste(
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, height - overlap_size_y),
)
alphaLayerABR = Image.new("L", (width, height), 255)
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
alphaLayerABR.paste(agradientT, (0, 0))
alphaLayerABR.paste(
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
)
alphaLayerABB = Image.new("L", (width, height), 255)
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
alphaLayerABB.paste(agradientL, (0, 0))
alphaLayerABB.paste(
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
)
# All-around layer
alphaLayerAA = Image.new("L", (width, height), 255)
alphaLayerAA.paste(alphaLayerABT, (0, 0))
alphaLayerAA.paste(agradientT, (0, 0))
alphaLayerAA.paste(
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
)
alphaLayerAA.paste(
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, 0),
)
# Clean up temporary gradients
del agradientL
del agradientT
del agradientC
def make_image():
# Make main tiles -------------------------------------------------
if embiggen_tiles:
logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
else:
logger.info(
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
)
emb_tile_store = []
# Although we could use the same seed for every tile for determinism, at higher strengths this may
# produce duplicated structures for each tile and make the tiling effect more obvious
# instead track and iterate a local seed we pass to Img2Img
seed = self.seed
seedintlimit = (
np.iinfo(np.uint32).max - 1
) # only retreive this one from numpy
for tile in range(emb_tiles_x * emb_tiles_y):
# Don't iterate on first tile
if tile != 0:
if seed < seedintlimit:
seed += 1
else:
seed = 0
# Determine if this is a re-run and replace
if embiggen_tiles and not tile in embiggen_tiles:
continue
# Get row and column entries
emb_row_i = tile // emb_tiles_x
emb_column_i = tile % emb_tiles_x
# Determine bounds to cut up the init image
# Determine upper-left point
if emb_column_i + 1 == emb_tiles_x:
left = initsuperwidth - width
else:
left = round(emb_column_i * (width - overlap_size_x))
if emb_row_i + 1 == emb_tiles_y:
top = initsuperheight - height
else:
top = round(emb_row_i * (height - overlap_size_y))
right = left + width
bottom = top + height
# Cropped image of above dimension (does not modify the original)
newinitimage = initsuperimage.crop((left, top, right, bottom))
# DEBUG:
# newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
# newinitimage.save(newinitimagepath)
if embiggen_tiles:
logger.debug(
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
)
else:
logger.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
# create a torch tensor from an Image
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
newinitimage = torch.from_numpy(newinitimage)
newinitimage = 2.0 * newinitimage - 1.0
newinitimage = newinitimage.to(self.model.device)
clear_cuda_cache = (
kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
)
tile_results = gen_img2img.generate(
prompt,
iterations=1,
seed=seed,
sampler=sampler,
steps=steps,
cfg_scale=cfg_scale,
conditioning=conditioning,
ddim_eta=ddim_eta,
image_callback=None, # called only after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
width=width,
height=height,
init_image=newinitimage, # notice that init_image is different from init_img
mask_image=None,
strength=strength,
clear_cuda_cache=clear_cuda_cache,
)
emb_tile_store.append(tile_results[0][0])
# DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
# emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
del newinitimage
# Sanity check we have them all
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (
embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)
):
outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight))
if embiggen_tiles:
outputsuperimage.alpha_composite(
initsuperimage.convert("RGBA"), (0, 0)
)
for tile in range(emb_tiles_x * emb_tiles_y):
if embiggen_tiles:
if tile in embiggen_tiles:
intileimage = emb_tile_store.pop(0)
else:
continue
else:
intileimage = emb_tile_store[tile]
intileimage = intileimage.convert("RGBA")
# Get row and column entries
emb_row_i = tile // emb_tiles_x
emb_column_i = tile % emb_tiles_x
if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles:
left = 0
top = 0
else:
# Determine upper-left point
if emb_column_i + 1 == emb_tiles_x:
left = initsuperwidth - width
else:
left = round(emb_column_i * (width - overlap_size_x))
if emb_row_i + 1 == emb_tiles_y:
top = initsuperheight - height
else:
top = round(emb_row_i * (height - overlap_size_y))
# Handle gradients for various conditions
# Handle emb_rerun case
if embiggen_tiles:
# top of image
if emb_row_i == 0:
if emb_column_i == 0:
if (tile + 1) in embiggen_tiles: # Look-ahead right
if (
tile + emb_tiles_x
) not in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerB)
# Otherwise do nothing on this tile
elif (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerR)
else:
intileimage.putalpha(alphaLayerRBC)
elif emb_column_i == emb_tiles_x - 1:
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerL)
else:
intileimage.putalpha(alphaLayerLBC)
else:
if (tile + 1) in embiggen_tiles: # Look-ahead right
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerL)
else:
intileimage.putalpha(alphaLayerLBC)
elif (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerLR)
else:
intileimage.putalpha(alphaLayerABT)
# bottom of image
elif emb_row_i == emb_tiles_y - 1:
if emb_column_i == 0:
if (tile + 1) in embiggen_tiles: # Look-ahead right
intileimage.putalpha(alphaLayerTaC)
else:
intileimage.putalpha(alphaLayerRTC)
elif emb_column_i == emb_tiles_x - 1:
# No tiles to look ahead to
intileimage.putalpha(alphaLayerLTC)
else:
if (tile + 1) in embiggen_tiles: # Look-ahead right
intileimage.putalpha(alphaLayerLTaC)
else:
intileimage.putalpha(alphaLayerABB)
# vertical middle of image
else:
if emb_column_i == 0:
if (tile + 1) in embiggen_tiles: # Look-ahead right
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerTaC)
else:
intileimage.putalpha(alphaLayerTB)
elif (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerRTC)
else:
intileimage.putalpha(alphaLayerABL)
elif emb_column_i == emb_tiles_x - 1:
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerLTC)
else:
intileimage.putalpha(alphaLayerABR)
else:
if (tile + 1) in embiggen_tiles: # Look-ahead right
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerLTaC)
else:
intileimage.putalpha(alphaLayerABR)
elif (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerABB)
else:
intileimage.putalpha(alphaLayerAA)
# Handle normal tiling case (much simpler - since we tile left to right, top to bottom)
else:
if emb_row_i == 0 and emb_column_i >= 1:
intileimage.putalpha(alphaLayerL)
elif emb_row_i >= 1 and emb_column_i == 0:
if (
emb_column_i + 1 == emb_tiles_x
): # If we don't have anything that can be placed to the right
intileimage.putalpha(alphaLayerT)
else:
intileimage.putalpha(alphaLayerTaC)
else:
if (
emb_column_i + 1 == emb_tiles_x
): # If we don't have anything that can be placed to the right
intileimage.putalpha(alphaLayerLTC)
else:
intileimage.putalpha(alphaLayerLTaC)
# Layer tile onto final image
outputsuperimage.alpha_composite(intileimage, (left, top))
else:
logger.error(
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
)
# after internal loops and patching up return Embiggen image
return outputsuperimage
# end of function declaration
return make_image

View File

@ -22,7 +22,6 @@ class Img2Img(Generator):
def get_make_image(
self,
prompt,
sampler,
steps,
cfg_scale,

View File

@ -161,9 +161,7 @@ class Inpaint(Img2Img):
im: Image.Image,
seam_size: int,
seam_blur: int,
prompt,
seed,
sampler,
steps,
cfg_scale,
ddim_eta,
@ -177,8 +175,6 @@ class Inpaint(Img2Img):
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
make_image = self.get_make_image(
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
@ -196,15 +192,13 @@ class Inpaint(Img2Img):
seam_noise = self.get_noise(im.width, im.height)
result = make_image(seam_noise, seed)
result = make_image(seam_noise, seed=None)
return result
@torch.no_grad()
def get_make_image(
self,
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
@ -306,7 +300,6 @@ class Inpaint(Img2Img):
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.scheduler = sampler
# todo: support cross-attention control
uc, c, _ = conditioning
@ -345,9 +338,7 @@ class Inpaint(Img2Img):
result,
seam_size,
seam_blur,
prompt,
seed,
sampler,
seam_steps,
cfg_scale,
ddim_eta,
@ -360,8 +351,6 @@ class Inpaint(Img2Img):
# Restore original settings
self.get_make_image(
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,

View File

@ -1,81 +0,0 @@
"""
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
"""
import PIL.Image
import torch
from ..stable_diffusion import (
ConditioningData,
PostprocessingSettings,
StableDiffusionGeneratorPipeline,
)
from .base import Generator
class Txt2Img(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
@torch.no_grad()
def get_make_image(
self,
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
width,
height,
step_callback=None,
threshold=0.0,
warmup=0.2,
perlin=0.0,
h_symmetry_time_pct=None,
v_symmetry_time_pct=None,
attention_maps_callback=None,
**kwargs,
):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
kwargs are 'width' and 'height'
"""
self.perlin = perlin
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.scheduler = sampler
uc, c, extra_conditioning_info = conditioning
conditioning_data = ConditioningData(
uc,
c,
cfg_scale,
extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=threshold,
warmup=warmup,
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct,
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
pipeline_output = pipeline.image_from_embeddings(
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
noise=x_T,
num_inference_steps=steps,
conditioning_data=conditioning_data,
callback=step_callback,
)
if (
pipeline_output.attention_map_saver is not None
and attention_maps_callback is not None
):
attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
return make_image

View File

@ -1,209 +0,0 @@
"""
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
"""
import math
from typing import Callable, Optional
import torch
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
from ..stable_diffusion import PostprocessingSettings
from .base import Generator
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..stable_diffusion.diffusers_pipeline import ConditioningData
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
import invokeai.backend.util.logging as logger
class Txt2Img2Img(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
self.init_latent = None # for get_noise()
def get_make_image(
self,
prompt: str,
sampler,
steps: int,
cfg_scale: float,
ddim_eta,
conditioning,
width: int,
height: int,
strength: float,
step_callback: Optional[Callable] = None,
threshold=0.0,
warmup=0.2,
perlin=0.0,
h_symmetry_time_pct=None,
v_symmetry_time_pct=None,
attention_maps_callback=None,
**kwargs,
):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
kwargs are 'width' and 'height'
"""
self.perlin = perlin
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.scheduler = sampler
uc, c, extra_conditioning_info = conditioning
conditioning_data = ConditioningData(
uc,
c,
cfg_scale,
extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=threshold,
warmup=0.2,
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct,
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T: torch.Tensor, _: int):
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
latents=torch.zeros_like(x_T),
num_inference_steps=steps,
conditioning_data=conditioning_data,
noise=x_T,
callback=step_callback,
)
# Get our initial generation width and height directly from the latent output so
# the message below is accurate.
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
logger.info(
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
)
# resizing
resized_latents = torch.nn.functional.interpolate(
first_pass_latent_output,
size=(
height // self.downsampling_factor,
width // self.downsampling_factor,
),
mode="bilinear",
)
# Free up memory from the last generation.
clear_cuda_cache = kwargs["clear_cuda_cache"] or None
if clear_cuda_cache is not None:
clear_cuda_cache()
second_pass_noise = self.get_noise_like(
resized_latents, override_perlin=True
)
# Clear symmetry for the second pass
from dataclasses import replace
new_postprocessing_settings = replace(
conditioning_data.postprocessing_settings, h_symmetry_time_pct=None
)
new_postprocessing_settings = replace(
new_postprocessing_settings, v_symmetry_time_pct=None
)
new_conditioning_data = replace(
conditioning_data, postprocessing_settings=new_postprocessing_settings
)
verbosity = get_verbosity()
set_verbosity_error()
pipeline_output = pipeline.img2img_from_latents_and_embeddings(
resized_latents,
num_inference_steps=steps,
conditioning_data=new_conditioning_data,
strength=strength,
noise=second_pass_noise,
callback=step_callback,
)
set_verbosity(verbosity)
if (
pipeline_output.attention_map_saver is not None
and attention_maps_callback is not None
):
attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
# FIXME: do we really need something entirely different for the inpainting model?
# in the case of the inpainting model being loaded, the trick of
# providing an interpolated latent doesn't work, so we transiently
# create a 512x512 PIL image, upscale it, and run the inpainting
# over it in img2img mode. Because the inpaing model is so conservative
# it doesn't change the image (much)
return make_image
def get_noise_like(self, like: torch.Tensor, override_perlin: bool = False):
device = like.device
if device.type == "mps":
x = torch.randn_like(like, device="cpu", dtype=self.torch_dtype()).to(
device
)
else:
x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
if self.perlin > 0.0 and override_perlin == False:
shape = like.shape
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
shape[3], shape[2]
)
return x
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self, width, height, scale=True):
# print(f"Get noise: {width}x{height}")
if scale:
# Scale the input width and height for the initial generation
# Make their area equivalent to the model's resolution area (e.g. 512*512 = 262144),
# while keeping the minimum dimension at least 0.5 * resolution (e.g. 512*0.5 = 256)
aspect = width / height
dimension = self.model.unet.config.sample_size * self.model.vae_scale_factor
min_dimension = math.floor(dimension * 0.5)
model_area = (
dimension * dimension
) # hardcoded for now since all models are trained on square images
if aspect > 1.0:
init_height = max(min_dimension, math.sqrt(model_area / aspect))
init_width = init_height * aspect
else:
init_width = max(min_dimension, math.sqrt(model_area * aspect))
init_height = init_width / aspect
scaled_width, scaled_height = trim_to_multiple_of(
math.floor(init_width), math.floor(init_height)
)
else:
scaled_width = width
scaled_height = height
device = self.model.device
channels = self.latent_channels
if channels == 9:
channels = 4 # we don't really want noise for all the mask channels
shape = (
1,
channels,
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor,
)
if self.use_mps_noise or device.type == "mps":
tensor = torch.empty(size=shape, device="cpu")
tensor = self.get_noise_like(like=tensor).to(device)
else:
tensor = torch.empty(size=shape, device=device)
tensor = self.get_noise_like(like=tensor)
return tensor

View File

@ -6,7 +6,8 @@ be suppressed or deferred
"""
import numpy as np
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
class PatchMatch:
"""
@ -21,7 +22,6 @@ class PatchMatch:
@classmethod
def _load_patch_match(self):
config = get_invokeai_config()
if self.tried_load:
return
if config.try_patchmatch:

View File

@ -33,10 +33,11 @@ from PIL import Image, ImageOps
from transformers import AutoProcessor, CLIPSegForImageSegmentation
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
CLIPSEG_SIZE = 352
config = InvokeAIAppConfig.get_config()
class SegmentedGrayscale(object):
def __init__(self, image: Image, heatmap: torch.Tensor):
@ -83,7 +84,6 @@ class Txt2Mask(object):
def __init__(self, device="cpu", refined=False):
logger.info("Initializing clipseg model for text to mask inference")
config = get_invokeai_config()
# BUG: we are not doing anything with the device option at this time
self.device = device

View File

@ -12,8 +12,8 @@ print("Loading Python libraries...\n",file=sys.stderr)
import argparse
import io
import os
import re
import shutil
import textwrap
import traceback
import warnings
from argparse import Namespace
@ -35,50 +35,44 @@ from transformers import (
CLIPTextModel,
CLIPTokenizer,
)
import invokeai.configs as configs
from invokeai.app.services.config import (
InvokeAIAppConfig,
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import (
CenteredButtonPress,
IntTitleSlider,
set_min_terminal_size,
CyclingForm,
MIN_COLS,
MIN_LINES,
)
from invokeai.backend.config.legacy_arg_parsing import legacy_parser
from invokeai.backend.config.model_install_backend import (
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.install.model_install_backend import (
default_dataset,
download_from_hf,
hf_download_with_resume,
recommended_datasets,
)
from invokeai.app.services.config import (
get_invokeai_config,
InvokeAIAppConfig,
UserSelections,
)
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
# --------------------------globals-----------------------
config = get_invokeai_config()
config = InvokeAIAppConfig.get_config()
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"
# the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path
Datasets = OmegaConf.load(Dataset_path)
# minimum size for the UI
MIN_COLS = 135
MIN_LINES = 45
PRECISION_CHOICES = ['auto','float16','float32','autocast']
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
@ -87,6 +81,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# or renaming it and then running invokeai-configure again.
"""
logger=None
# --------------------------------------------
def postscript(errors: None):
@ -103,7 +98,7 @@ Command-line client:
invokeai
If you installed using an installation script, run:
{config.root}/invoke.{"bat" if sys.platform == "win32" else "sh"}
{config.root_path}/invoke.{"bat" if sys.platform == "win32" else "sh"}
Add the '--help' argument to see all of the command-line switches available for use.
"""
@ -215,16 +210,11 @@ def download_realesrgan():
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
model_dest = os.path.join(
config.root, "models/realesrgan/realesr-general-x4v3.pth"
)
model_dest = config.root_path / "models/realesrgan/realesr-general-x4v3.pth"
wdn_model_dest = config.root_path / "models/realesrgan/realesr-general-wdn-x4v3.pth"
wdn_model_dest = os.path.join(
config.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
)
download_with_progress_bar(model_url, model_dest, "RealESRGAN")
download_with_progress_bar(wdn_model_url, wdn_model_dest, "RealESRGANwdn")
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
def download_gfpgan():
@ -243,8 +233,8 @@ def download_gfpgan():
"./models/gfpgan/weights/parsing_parsenet.pth",
],
):
model_url, model_dest = model[0], os.path.join(config.root, model[1])
download_with_progress_bar(model_url, model_dest, "GFPGAN weights")
model_url, model_dest = model[0], config.root_path / model[1]
download_with_progress_bar(model_url, str(model_dest), "GFPGAN weights")
# ---------------------------------------------
@ -253,8 +243,8 @@ def download_codeformer():
model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)
model_dest = os.path.join(config.root, "models/codeformer/codeformer.pth")
download_with_progress_bar(model_url, model_dest, "CodeFormer")
model_dest = config.root_path / "models/codeformer/codeformer.pth"
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
# ---------------------------------------------
@ -306,7 +296,7 @@ def download_vaes():
if not hf_download_with_resume(
repo_id=repo_id,
model_name=model_name,
model_dir=str(config.root / Model_dir / Weights_dir),
model_dir=str(config.root_path / Model_dir / Weights_dir),
):
raise Exception(f"download of {model_name} failed")
except Exception as e:
@ -321,24 +311,24 @@ def get_root(root: str = None) -> str:
elif os.environ.get("INVOKEAI_ROOT"):
return os.environ.get("INVOKEAI_ROOT")
else:
return config.root
return str(config.root_path)
# -------------------------------------
class editOptsForm(npyscreen.FormMultiPage):
class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
# for responsive resizing - disabled
# FIX_MINIMUM_SIZE_WHEN_CREATED = False
def create(self):
program_opts = self.parentApp.program_opts
old_opts = self.parentApp.invokeai_opts
first_time = not (config.root / 'invokeai.yaml').exists()
first_time = not (config.root_path / 'invokeai.yaml').exists()
access_token = HfFolder.get_token()
window_width, window_height = get_terminal_size()
for i in [
"Configure startup settings. You can come back and change these later.",
"Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.",
"Use cursor arrows to make a checkbox selection, and space to toggle.",
]:
label = """Configure startup settings. You can come back and change these later.
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
Use cursor arrows to make a checkbox selection, and space to toggle.
"""
for i in textwrap.wrap(label,width=window_width-6):
self.add_widget_intelligent(
npyscreen.FixedText,
value=i,
@ -365,7 +355,7 @@ class editOptsForm(npyscreen.FormMultiPage):
self.outdir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name="(<tab> autocompletes, ctrl-N advances):",
value=str(old_opts.outdir) or str(default_output_dir()),
value=str(default_output_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
@ -388,14 +378,13 @@ class editOptsForm(npyscreen.FormMultiPage):
scroll_exit=True,
)
self.nextrely += 1
for i in [
"If you have an account at HuggingFace you may optionally paste your access token here",
'to allow InvokeAI to download restricted styles & subjects from the "Concept Library".',
"See https://huggingface.co/settings/tokens",
]:
label = """If you have an account at HuggingFace you may optionally paste your access token here
to allow InvokeAI to download restricted styles & subjects from the "Concept Library". See https://huggingface.co/settings/tokens.
"""
for line in textwrap.wrap(label,width=window_width-6):
self.add_widget_intelligent(
npyscreen.FixedText,
value=i,
value=line,
editable=False,
color="CONTROL",
)
@ -472,7 +461,7 @@ class editOptsForm(npyscreen.FormMultiPage):
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Directories containing textual inversion and LoRA models (<tab> autocompletes, ctrl-N advances):",
value="Directories containing textual inversion, controlnet and LoRA models (<tab> autocompletes, ctrl-N advances):",
editable=False,
color="CONTROL",
)
@ -498,6 +487,17 @@ class editOptsForm(npyscreen.FormMultiPage):
begin_entry_at=32,
scroll_exit=True,
)
self.controlnet_dir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name=" ControlNets:",
value=str(default_controlnet_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=32,
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
@ -508,11 +508,11 @@ class editOptsForm(npyscreen.FormMultiPage):
scroll_exit=True,
)
self.nextrely -= 1
for i in [
"BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ",
"AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT",
"https://huggingface.co/spaces/CompVis/stable-diffusion-license",
]:
label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT
https://huggingface.co/spaces/CompVis/stable-diffusion-license
"""
for i in textwrap.wrap(label,width=window_width-6):
self.add_widget_intelligent(
npyscreen.FixedText,
value=i,
@ -551,7 +551,7 @@ class editOptsForm(npyscreen.FormMultiPage):
self.editing = False
else:
self.editing = True
def validate_field_values(self, opt: Namespace) -> bool:
bad_fields = []
if not opt.license_acceptance:
@ -587,6 +587,7 @@ class editOptsForm(npyscreen.FormMultiPage):
"always_use_cpu",
"embedding_dir",
"lora_dir",
"controlnet_dir",
]:
setattr(new_opts, attr, getattr(self, attr).value)
@ -614,6 +615,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
"MAIN",
editOptsForm,
name="InvokeAI Startup Options",
cycle_widgets=True,
)
if not (self.program_opts.skip_sd_weights or self.program_opts.default_only):
self.model_select = self.addForm(
@ -621,6 +623,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
addModelsForm,
name="Install Stable Diffusion Models",
multipage=True,
cycle_widgets=True,
)
def new_opts(self):
@ -634,17 +637,14 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
def default_startup_options(init_file: Path) -> Namespace:
opts = InvokeAIAppConfig(argv=[])
outdir = Path(opts.outdir)
if not outdir.is_absolute():
opts.outdir = str(config.root / opts.outdir)
opts = InvokeAIAppConfig.get_config()
if not init_file.exists():
opts.nsfw_checker = True
return opts
def default_user_selections(program_opts: Namespace) -> Namespace:
return Namespace(
starter_models=default_dataset()
def default_user_selections(program_opts: Namespace) -> UserSelections:
return UserSelections(
install_models=default_dataset()
if program_opts.default_only
else recommended_datasets()
if program_opts.yes_to_all
@ -652,26 +652,27 @@ def default_user_selections(program_opts: Namespace) -> Namespace:
purge_deleted_models=False,
scan_directory=None,
autoscan_on_startup=None,
import_model_paths=None,
convert_to_diffusers=None,
)
# -------------------------------------
def initialize_rootdir(root: str, yes_to_all: bool = False):
def initialize_rootdir(root: Path, yes_to_all: bool = False):
print("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
for name in (
"models",
"configs",
"embeddings",
"text-inversion-output",
"text-inversion-training-data",
"models",
"configs",
"embeddings",
"databases",
"loras",
"controlnets",
"text-inversion-output",
"text-inversion-training-data",
):
os.makedirs(os.path.join(root, name), exist_ok=True)
configs_src = Path(configs.__path__[0])
configs_dest = Path(root) / "configs"
configs_dest = root / "configs"
if not os.path.samefile(configs_src, configs_dest):
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
@ -682,8 +683,17 @@ def run_console_ui(
) -> (Namespace, Namespace):
# parse_args() will read from init file if present
invokeai_opts = default_startup_options(initfile)
invokeai_opts.root = program_opts.root
set_min_terminal_size(MIN_COLS, MIN_LINES)
# The third argument is needed in the Windows 11 environment to
# launch a console window running this program.
set_min_terminal_size(MIN_COLS, MIN_LINES,'invokeai-configure')
# the install-models application spawns a subprocess to install
# models, and will crash unless this is set before running.
import torch
torch.multiprocessing.set_start_method("spawn")
editApp = EditOptApplication(program_opts, invokeai_opts)
editApp.run()
if editApp.user_cancelled:
@ -697,27 +707,32 @@ def write_opts(opts: Namespace, init_file: Path):
"""
Update the invokeai.yaml file with values from current settings.
"""
# this will load current settings
config = InvokeAIAppConfig()
new_config = InvokeAIAppConfig.get_config()
new_config.root = config.root
for key,value in opts.__dict__.items():
if hasattr(config,key):
setattr(config,key,value)
if hasattr(new_config,key):
setattr(new_config,key,value)
with open(init_file,'w', encoding='utf-8') as file:
file.write(config.to_yaml())
file.write(new_config.to_yaml())
# -------------------------------------
def default_output_dir() -> Path:
return config.root / "outputs"
return config.root_path / "outputs"
# -------------------------------------
def default_embedding_dir() -> Path:
return config.root / "embeddings"
return config.root_path / "embeddings"
# -------------------------------------
def default_lora_dir() -> Path:
return config.root / "loras"
return config.root_path / "loras"
# -------------------------------------
def default_controlnet_dir() -> Path:
return config.root_path / "controlnets"
# -------------------------------------
def write_default_options(program_opts: Namespace, initfile: Path):
@ -731,7 +746,7 @@ def write_default_options(program_opts: Namespace, initfile: Path):
# yaml format.
def migrate_init_file(legacy_format:Path):
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
new = InvokeAIAppConfig(conf={})
new = InvokeAIAppConfig.get_config()
fields = list(get_type_hints(InvokeAIAppConfig).keys())
for attr in fields:
@ -805,9 +820,13 @@ def main():
)
opt = parser.parse_args()
# setting a global here
global config
config.root = Path(os.path.expanduser(get_root(opt.root) or ""))
invoke_args = []
if opt.root:
invoke_args.extend(['--root',opt.root])
if opt.full_precision:
invoke_args.extend(['--precision','float32'])
config.parse_args(invoke_args)
logger = InvokeAILogger().getLogger(config=config)
errors = set()
@ -815,16 +834,16 @@ def main():
models_to_download = default_user_selections(opt)
# We check for to see if the runtime directory is correctly initialized.
old_init_file = Path(config.root, 'invokeai.init')
new_init_file = Path(config.root, 'invokeai.yaml')
old_init_file = config.root_path / 'invokeai.init'
new_init_file = config.root_path / 'invokeai.yaml'
if old_init_file.exists() and not new_init_file.exists():
print('** Migrating invokeai.init to invokeai.yaml')
migrate_init_file(old_init_file)
config = get_invokeai_config() # reread defaults
# Load new init file into config
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
if not config.model_conf_path.exists():
initialize_rootdir(config.root, opt.yes_to_all)
initialize_rootdir(config.root_path, opt.yes_to_all)
if opt.yes_to_all:
write_default_options(opt, new_init_file)
@ -844,7 +863,7 @@ def main():
if opt.skip_support_models:
print("\n** SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST **")
else:
print("\n** DOWNLOADING SUPPORT MODELS **")
print("\n** CHECKING/UPDATING SUPPORT MODELS **")
download_bert()
download_sd1_clip()
download_sd2_clip()
@ -862,6 +881,8 @@ def main():
process_and_execute(opt, models_to_download)
postscript(errors=errors)
if not opt.yes_to_all:
input('Press any key to continue...')
except KeyboardInterrupt:
print("\nGoodbye! Come back soon.")

View File

@ -9,6 +9,7 @@ SAMPLER_CHOICES = [
"ddpm",
"deis",
"lms",
"lms_k",
"pndm",
"heun",
"heun_k",
@ -18,8 +19,13 @@ SAMPLER_CHOICES = [
"kdpm_2",
"kdpm_2_a",
"dpmpp_2s",
"dpmpp_2s_k",
"dpmpp_2m",
"dpmpp_2m_k",
"dpmpp_2m_sde",
"dpmpp_2m_sde_k",
"dpmpp_sde",
"dpmpp_sde_k",
"unipc",
]

View File

@ -6,28 +6,30 @@ import re
import shutil
import sys
import warnings
from dataclasses import dataclass,field
from pathlib import Path
from tempfile import TemporaryFile
from typing import List
from typing import List, Dict, Callable
import requests
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_url
from huggingface_hub import hf_hub_url, HfFolder
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import get_invokeai_config
from ..model_management import ModelManager
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from invokeai.app.services.config import InvokeAIAppConfig
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore")
# --------------------------globals-----------------------
config = get_invokeai_config()
config = InvokeAIAppConfig.get_config()
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"
@ -37,6 +39,9 @@ Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
# initial models omegaconf
Datasets = None
# logger
logger = InvokeAILogger.getLogger(name='InvokeAI')
Config_preamble = """
# This file describes the alternative machine learning models
# available to InvokeAI script.
@ -47,11 +52,30 @@ Config_preamble = """
# was trained on.
"""
@dataclass
class ModelInstallList:
'''Class for listing models to be installed/removed'''
install_models: List[str] = field(default_factory=list)
remove_models: List[str] = field(default_factory=list)
@dataclass
class UserSelections():
install_models: List[str]= field(default_factory=list)
remove_models: List[str]=field(default_factory=list)
purge_deleted_models: bool=field(default_factory=list)
install_cn_models: List[str] = field(default_factory=list)
remove_cn_models: List[str] = field(default_factory=list)
install_lora_models: List[str] = field(default_factory=list)
remove_lora_models: List[str] = field(default_factory=list)
install_ti_models: List[str] = field(default_factory=list)
remove_ti_models: List[str] = field(default_factory=list)
scan_directory: Path = None
autoscan_on_startup: bool=False
import_model_paths: str=None
def default_config_file():
return config.model_conf_path
def sd_configs():
return config.legacy_conf_path
@ -59,45 +83,67 @@ def initial_models():
global Datasets
if Datasets:
return Datasets
return (Datasets := OmegaConf.load(Dataset_path))
return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
def install_requested_models(
install_initial_models: List[str] = None,
remove_models: List[str] = None,
scan_directory: Path = None,
external_models: List[str] = None,
scan_at_startup: bool = False,
precision: str = "float16",
purge_deleted: bool = False,
config_file_path: Path = None,
diffusers: ModelInstallList = None,
controlnet: ModelInstallList = None,
lora: ModelInstallList = None,
ti: ModelInstallList = None,
cn_model_map: Dict[str,str] = None, # temporary - move to model manager
scan_directory: Path = None,
external_models: List[str] = None,
scan_at_startup: bool = False,
precision: str = "float16",
purge_deleted: bool = False,
config_file_path: Path = None,
model_config_file_callback: Callable[[Path],Path] = None
):
"""
Entry point for installing/deleting starter models, or installing external models.
"""
access_token = HfFolder.get_token()
config_file_path = config_file_path or default_config_file()
if not config_file_path.exists():
open(config_file_path, "w")
# prevent circular import here
from ..model_management import ModelManager
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
if controlnet:
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
model_manager.delete_controlnet_models(controlnet.remove_models)
if remove_models and len(remove_models) > 0:
print("== DELETING UNCHECKED STARTER MODELS ==")
for model in remove_models:
print(f"{model}...")
model_manager.del_model(model, delete_files=purge_deleted)
model_manager.commit(config_file_path)
if lora:
model_manager.install_lora_models(lora.install_models, access_token=access_token)
model_manager.delete_lora_models(lora.remove_models)
if install_initial_models and len(install_initial_models) > 0:
print("== INSTALLING SELECTED STARTER MODELS ==")
successfully_downloaded = download_weight_datasets(
models=install_initial_models,
access_token=None,
precision=precision,
) # FIX: for historical reasons, we don't use model manager here
update_config_file(successfully_downloaded, config_file_path)
if len(successfully_downloaded) < len(install_initial_models):
print("** Some of the model downloads were not successful")
if ti:
model_manager.install_ti_models(ti.install_models, access_token=access_token)
model_manager.delete_ti_models(ti.remove_models)
if diffusers:
# TODO: Replace next three paragraphs with calls into new model manager
if diffusers.remove_models and len(diffusers.remove_models) > 0:
logger.info("Processing requested deletions")
for model in diffusers.remove_models:
logger.info(f"{model}...")
model_manager.del_model(model, delete_files=purge_deleted)
model_manager.commit(config_file_path)
if diffusers.install_models and len(diffusers.install_models) > 0:
logger.info("Installing requested models")
downloaded_paths = download_weight_datasets(
models=diffusers.install_models,
access_token=None,
precision=precision,
)
successful = {x:v for x,v in downloaded_paths.items() if v is not None}
if len(successful) > 0:
update_config_file(successful, config_file_path)
if len(successful) < len(diffusers.install_models):
unsuccessful = [x for x in downloaded_paths if downloaded_paths[x] is None]
logger.warning(f"Some of the model downloads were not successful: {unsuccessful}")
# due to above, we have to reload the model manager because conf file
# was changed behind its back
@ -108,12 +154,14 @@ def install_requested_models(
external_models.append(str(scan_directory))
if len(external_models) > 0:
print("== INSTALLING EXTERNAL MODELS ==")
logger.info("INSTALLING EXTERNAL MODELS")
for path_url_or_repo in external_models:
try:
logger.debug(f'In install_requested_models; callback = {model_config_file_callback}')
model_manager.heuristic_import(
path_url_or_repo,
commit_to_conf=config_file_path,
config_file_callback = model_config_file_callback,
)
except KeyboardInterrupt:
sys.exit(-1)
@ -121,18 +169,22 @@ def install_requested_models(
pass
if scan_at_startup and scan_directory.is_dir():
argument = "--autoconvert"
print('** The global initfile is no longer supported; rewrite to support new yaml format **')
initfile = Path(config.root, 'invokeai.init')
replacement = Path(config.root, f"invokeai.init.new")
directory = str(scan_directory).replace("\\", "/")
with open(initfile, "r") as input:
with open(replacement, "w") as output:
while line := input.readline():
if not line.startswith(argument):
output.writelines([line])
output.writelines([f"{argument} {directory}"])
os.replace(replacement, initfile)
update_autoconvert_dir(scan_directory)
else:
update_autoconvert_dir(None)
def update_autoconvert_dir(autodir: Path):
'''
Update the "autoconvert_dir" option in invokeai.yaml
'''
invokeai_config_path = config.init_file_path
conf = OmegaConf.load(invokeai_config_path)
conf.InvokeAI.Paths.autoconvert_dir = str(autodir) if autodir else None
yaml = OmegaConf.to_yaml(conf)
tmpfile = invokeai_config_path.parent / "new_config.tmp"
with open(tmpfile, "w", encoding="utf-8") as outfile:
outfile.write(yaml)
tmpfile.replace(invokeai_config_path)
# -------------------------------------
@ -144,33 +196,21 @@ def yes_or_no(prompt: str, default_yes=True):
else:
return response[0] in ("y", "Y")
# -------------------------------------
def get_root(root: str = None) -> str:
if root:
return root
elif os.environ.get("INVOKEAI_ROOT"):
return os.environ.get("INVOKEAI_ROOT")
else:
return config.root
# ---------------------------------------------
def recommended_datasets() -> dict:
datasets = dict()
def recommended_datasets() -> List['str']:
datasets = set()
for ds in initial_models().keys():
if initial_models()[ds].get("recommended", False):
datasets[ds] = True
return datasets
datasets.add(ds)
return list(datasets)
# ---------------------------------------------
def default_dataset() -> dict:
datasets = dict()
datasets = set()
for ds in initial_models().keys():
if initial_models()[ds].get("default", False):
datasets[ds] = True
return datasets
datasets.add(ds)
return list(datasets)
# ---------------------------------------------
@ -185,14 +225,14 @@ def all_datasets() -> dict:
# look for legacy model.ckpt in models directory and offer to
# normalize its name
def migrate_models_ckpt():
model_path = os.path.join(config.root, Model_dir, Weights_dir)
model_path = os.path.join(config.root_dir, Model_dir, Weights_dir)
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
return
new_name = initial_models()["stable-diffusion-1.4"]["file"]
print(
logger.warning(
'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
)
print(f"model.ckpt => {new_name}")
logger.warning(f"model.ckpt => {new_name}")
os.replace(
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
)
@ -205,7 +245,7 @@ def download_weight_datasets(
migrate_models_ckpt()
successful = dict()
for mod in models:
print(f"Downloading {mod}:")
logger.info(f"Downloading {mod}:")
successful[mod] = _download_repo_or_file(
initial_models()[mod], access_token, precision=precision
)
@ -226,11 +266,10 @@ def _download_repo_or_file(
)
return path
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
repo_id = mconfig["repo_id"]
filename = mconfig["file"]
cache_dir = os.path.join(config.root, Model_dir, Weights_dir)
cache_dir = os.path.join(config.root_dir, Model_dir, Weights_dir)
return hf_download_with_resume(
repo_id=repo_id,
model_dir=cache_dir,
@ -243,6 +282,9 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
def download_from_hf(
model_class: object, model_name: str, **kwargs
):
logger = InvokeAILogger.getLogger('InvokeAI')
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
path = config.cache_dir
model = model_class.from_pretrained(
model_name,
@ -274,10 +316,10 @@ def _download_diffusion_weights(
**extra_args,
)
except OSError as e:
if str(e).startswith("fp16 is not a valid"):
if 'Revision Not Found' in str(e):
pass
else:
print(f"An unexpected error occurred while downloading the model: {e})")
logger.error(str(e))
if path:
break
return path
@ -285,9 +327,13 @@ def _download_diffusion_weights(
# ---------------------------------------------
def hf_download_with_resume(
repo_id: str, model_dir: str, model_name: str, access_token: str = None
repo_id: str,
model_dir: str,
model_name: str,
model_dest: Path = None,
access_token: str = None,
) -> Path:
model_dest = Path(os.path.join(model_dir, model_name))
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
os.makedirs(model_dir, exist_ok=True)
url = hf_hub_url(repo_id, model_name)
@ -307,20 +353,19 @@ def hf_download_with_resume(
if (
resp.status_code == 416
): # "range not satisfiable", which means nothing to return
print(f"* {model_name}: complete file found. Skipping.")
logger.info(f"{model_name}: complete file found. Skipping.")
return model_dest
elif resp.status_code == 404:
logger.warning("File not found")
return None
elif resp.status_code != 200:
print(f"** An error occurred during downloading {model_name}: {resp.reason}")
logger.warning(f"{model_name}: {resp.reason}")
elif exist_size > 0:
print(f"* {model_name}: partial file found. Resuming...")
logger.info(f"{model_name}: partial file found. Resuming...")
else:
print(f"* {model_name}: Downloading...")
logger.info(f"{model_name}: Downloading...")
try:
if total < 2000:
print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}")
return None
with open(model_dest, open_mode) as file, tqdm(
desc=model_name,
initial=exist_size,
@ -333,7 +378,7 @@ def hf_download_with_resume(
size = file.write(data)
bar.update(size)
except Exception as e:
print(f"An error occurred while downloading {model_name}: {str(e)}")
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
return None
return model_dest
@ -358,8 +403,8 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
try:
backup = None
if os.path.exists(config_file):
print(
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
logger.warning(
f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
)
backup = config_file.with_suffix(".yaml.orig")
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
@ -376,16 +421,16 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
new_config.write(tmp.read())
except Exception as e:
print(f"**Error creating config file {config_file}: {str(e)} **")
logger.error(f"Error creating config file {config_file}: {str(e)}")
if backup is not None:
print("restoring previous config file")
logger.info("restoring previous config file")
## workaround, for WinError 183, see above
if sys.platform == "win32" and config_file.is_file():
config_file.unlink()
backup.rename(config_file)
return
print(f"Successfully created new configuration file {config_file}")
logger.info(f"Successfully created new configuration file {config_file}")
# ---------------------------------------------
@ -419,7 +464,7 @@ def new_config_file_contents(
stanza["height"] = mod["height"]
if "file" in mod:
stanza["weights"] = os.path.relpath(
successfully_downloaded[model], start=config.root
successfully_downloaded[model], start=config.root_dir
)
stanza["config"] = os.path.normpath(
os.path.join(sd_configs(), mod["config"])
@ -452,14 +497,14 @@ def delete_weights(model_name: str, conf_stanza: dict):
if re.match("/VAE/", conf_stanza.get("config")):
return
print(
f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
logger.warning(
f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
)
weights = Path(weights)
if not weights.is_absolute():
weights = Path(config.root) / weights
weights = config.root_dir / weights
try:
weights.unlink()
except OSError as e:
print(str(e))
logger.error(str(e))

View File

@ -1,11 +1,6 @@
"""
Initialization file for invokeai.backend.model_management
"""
from .convert_ckpt_to_diffusers import (
convert_ckpt_to_diffusers,
load_pipeline_from_original_stable_diffusion_ckpt,
)
from .model_manager import ModelManager,SDModelComponent
from .model_manager import ModelManager, ModelInfo
from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType

View File

@ -26,12 +26,15 @@ import torch
from safetensors.torch import load_file
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
from .model_manager import ModelManager, SDLegacyType
from .model_manager import ModelManager
from .model_cache import ModelCache
from .models import SchedulerPredictionType, BaseModelType, ModelVariantType
try:
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
except ImportError:
raise ImportError(
"OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
@ -56,10 +59,6 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
LDMBertConfig,
LDMBertModel,
)
from diffusers.pipelines.paint_by_example import (
PaintByExampleImageEncoder,
PaintByExamplePipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
@ -74,6 +73,8 @@ from transformers import (
from ..stable_diffusion import StableDiffusionGeneratorPipeline
MODEL_ROOT = None
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
@ -612,16 +613,29 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
return new_checkpoint
def convert_ldm_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
vae_key = "first_stage_model."
keys = list(checkpoint.keys())
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
# Extract state dict for VAE. Works both with burnt-in
# VAEs, and with standalone VAEs.
# checkpoint can either be a all-in-one stable diffusion
# model, or an isolated vae .ckpt. This tests for
# a key that will be present in the all-in-one model
# that isn't present in the isolated ckpt.
probe_key = "first_stage_model.encoder.conv_in.weight"
if probe_key in checkpoint:
vae_state_dict = {}
vae_key = "first_stage_model."
keys = list(checkpoint.keys())
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
else:
vae_state_dict = checkpoint
new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict,config)
return new_checkpoint
def convert_ldm_vae_state_dict(vae_state_dict, config):
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
@ -841,10 +855,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
def convert_ldm_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=get_invokeai_config().cache_dir
)
text_model = CLIPTextModel.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14')
keys = list(checkpoint.keys())
text_model_dict = {}
@ -896,82 +907,10 @@ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
def convert_paint_by_example_checkpoint(checkpoint):
cache_dir = get_invokeai_config().cache_dir
config = CLIPVisionConfig.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=cache_dir
)
model = PaintByExampleImageEncoder(config)
keys = list(checkpoint.keys())
text_model_dict = {}
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[
key
]
# load clip vision
model.model.load_state_dict(text_model_dict)
# load mapper
keys_mapper = {
k[len("cond_stage_model.mapper.res") :]: v
for k, v in checkpoint.items()
if k.startswith("cond_stage_model.mapper")
}
MAPPING = {
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
"attn.c_proj": ["attn1.to_out.0"],
"ln_1": ["norm1"],
"ln_2": ["norm3"],
"mlp.c_fc": ["ff.net.0.proj"],
"mlp.c_proj": ["ff.net.2"],
}
mapped_weights = {}
for key, value in keys_mapper.items():
prefix = key[: len("blocks.i")]
suffix = key.split(prefix)[-1].split(".")[-1]
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
mapped_names = MAPPING[name]
num_splits = len(mapped_names)
for i, mapped_name in enumerate(mapped_names):
new_name = ".".join([prefix, mapped_name, suffix])
shape = value.shape[0] // num_splits
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
model.mapper.load_state_dict(mapped_weights)
# load final layer norm
model.final_layer_norm.load_state_dict(
{
"bias": checkpoint["cond_stage_model.final_ln.bias"],
"weight": checkpoint["cond_stage_model.final_ln.weight"],
}
)
# load final proj
model.proj_out.load_state_dict(
{
"bias": checkpoint["proj_out.bias"],
"weight": checkpoint["proj_out.weight"],
}
)
# load uncond vector
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
return model
def convert_open_clip_checkpoint(checkpoint):
cache_dir = get_invokeai_config().cache_dir
text_model = CLIPTextModel.from_pretrained(
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
MODEL_ROOT / 'stable-diffusion-2-clip',
subfolder='text_encoder',
)
keys = list(checkpoint.keys())
@ -1047,22 +986,30 @@ def replace_checkpoint_vae(checkpoint, vae_path:str):
new_key = f'first_stage_model.{vae_key}'
checkpoint[new_key] = state_dict[vae_key]
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int)->AutoencoderKL:
vae_config = create_vae_diffusers_config(
vae_config, image_size=image_size
)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
checkpoint, vae_config
)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
return vae
def load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path: str,
original_config_file: str = None,
num_in_channels: int = None,
scheduler_type: str = "pndm",
pipeline_type: str = None,
image_size: int = None,
prediction_type: str = None,
model_version: BaseModelType,
model_variant: ModelVariantType,
original_config_file: str,
extract_ema: bool = True,
upcast_attn: bool = False,
vae: AutoencoderKL = None,
vae_path: str = None,
precision: torch.dtype = torch.float32,
return_generator_pipeline: bool = False,
scan_needed:bool=True,
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
upcast_attention: bool = False,
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon,
scan_needed: bool = True,
) -> StableDiffusionPipeline:
"""
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
config file.
@ -1074,147 +1021,68 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
:param checkpoint_path: Path to `.ckpt` file.
:param original_config_file: Path to `.yaml` config file corresponding to the original architecture.
If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models.
:param image_size: The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
Base. Use 768 for Stable Diffusion v2.
:param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
v1.X and Stable Diffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2.
:param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
inferred.
:param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
"euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. :param extract_ema: Only relevant for
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder"]`. :param extract_ema: Only relevant for
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
running stable diffusion 2.1.
:param vae: A diffusers VAE to load into the pipeline.
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
"""
config = get_invokeai_config()
config = InvokeAIAppConfig.get_config()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
if Path(checkpoint_path).suffix == '.ckpt':
if scan_needed:
ModelManager.scan_model(checkpoint_path,checkpoint_path)
checkpoint = torch.load(checkpoint_path)
else:
if str(checkpoint_path).endswith(".safetensors"):
checkpoint = load_file(checkpoint_path)
cache_dir = config.cache_dir
pipeline_class = (
StableDiffusionGeneratorPipeline
if return_generator_pipeline
else StableDiffusionPipeline
)
# Sometimes models don't have the global_step item
if "global_step" in checkpoint:
global_step = checkpoint["global_step"]
else:
logger.debug("global_step key not found in model")
global_step = None
if scan_needed:
ModelCache.scan_model(checkpoint_path, checkpoint_path)
checkpoint = torch.load(checkpoint_path)
# sometimes there is a state_dict key and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
upcast_attention = False
if original_config_file is None:
model_type = ModelManager.probe_model_type(checkpoint)
if model_type == SDLegacyType.V2_v:
original_config_file = (
config.legacy_conf_path / "v2-inference-v.yaml"
)
if global_step == 110000:
# v2.1 needs to upcast attention
upcast_attention = True
elif model_type == SDLegacyType.V2_e:
original_config_file = (
config.legacy_conf_path / "v2-inference.yaml"
)
elif model_type == SDLegacyType.V1_INPAINT:
original_config_file = (
config.legacy_conf_path / "v1-inpainting-inference.yaml"
)
elif model_type == SDLegacyType.V1:
original_config_file = (
config.legacy_conf_path / "v1-inference.yaml"
)
else:
raise Exception("Unknown checkpoint type")
original_config = OmegaConf.load(original_config_file)
if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"][
"in_channels"
] = num_in_channels
if (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
):
if prediction_type is None:
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
# as it relies on a brittle global step parameter here
prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
if image_size is None:
# NOTE: For stable diffusion 2 base one has to pass `image_size==512`
# as it relies on a brittle global step parameter here
image_size = 512 if global_step == 875000 else 768
if model_version == BaseModelType.StableDiffusion2 and prediction_type == SchedulerPredictionType.VPrediction:
image_size = 768
else:
if prediction_type is None:
prediction_type = "epsilon"
if image_size is None:
image_size = 512
image_size = 512
#
# convert scheduler
#
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end
scheduler = DDIMScheduler(
scheduler = PNDMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
beta_start=beta_start,
num_train_timesteps=num_train_timesteps,
steps_offset=1,
clip_sample=False,
set_alpha_to_one=False,
prediction_type=prediction_type,
skip_prk_steps=True
)
# make sure scheduler works correctly with DDIM
scheduler.register_to_config(clip_sample=False)
if scheduler_type == "pndm":
config = dict(scheduler.config)
config["skip_prk_steps"] = True
scheduler = PNDMScheduler.from_config(config)
elif scheduler_type == "lms":
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "heun":
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "euler":
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "euler-ancestral":
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == 'unipc':
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == "ddim":
scheduler = scheduler
else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
#
# convert unet
#
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(
original_config, image_size=image_size
)
@ -1227,44 +1095,25 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
unet.load_state_dict(converted_unet_checkpoint)
# If a replacement VAE path was specified, we'll incorporate that into
# the checkpoint model and then convert it
if vae_path:
logger.debug(f"Converting VAE {vae_path}")
replace_checkpoint_vae(checkpoint,vae_path)
# otherwise we use the original VAE, provided that
# an externally loaded diffusers VAE was not passed
elif not vae:
logger.debug("Using checkpoint model's original VAE")
#
# convert vae
#
if vae:
logger.debug("Using replacement diffusers VAE")
else: # convert the original or replacement VAE
vae_config = create_vae_diffusers_config(
original_config, image_size=image_size
)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
checkpoint, vae_config
)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
vae = convert_ldm_vae_to_diffusers(
checkpoint,
original_config,
image_size,
)
# Convert the text model.
model_type = pipeline_type
if model_type is None:
model_type = original_config.model.params.cond_stage_config.target.split(
"."
)[-1]
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-2",
subfolder="tokenizer",
cache_dir=cache_dir,
MODEL_ROOT / 'stable-diffusion-2-clip',
subfolder='tokenizer',
)
pipe = pipeline_class(
pipe = StableDiffusionPipeline(
vae=vae.to(precision),
text_encoder=text_model.to(precision),
tokenizer=tokenizer,
@ -1274,49 +1123,26 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
feature_extractor=None,
requires_safety_checker=False,
)
elif model_type == "PaintByExample":
vision_model = convert_paint_by_example_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=cache_dir
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir
)
pipe = PaintByExamplePipeline(
vae=vae,
image_encoder=vision_model,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=feature_extractor,
)
elif model_type in ["FrozenCLIPEmbedder", "WeightedFrozenCLIPEmbedder"]:
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=cache_dir
)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker",
cache_dir=config.cache_dir,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir
)
pipe = pipeline_class(
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14')
safety_checker = StableDiffusionSafetyChecker.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
pipe = StableDiffusionPipeline(
vae=vae.to(precision),
text_encoder=text_model.to(precision),
tokenizer=tokenizer,
unet=unet.to(precision),
scheduler=scheduler,
safety_checker=None if return_generator_pipeline else safety_checker.to(precision),
safety_checker=safety_checker.to(precision),
feature_extractor=feature_extractor,
)
else:
text_config = create_ldm_bert_config(original_config)
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
tokenizer = BertTokenizerFast.from_pretrained(
"bert-base-uncased", cache_dir=cache_dir
)
tokenizer = BertTokenizerFast.from_pretrained(MODEL_ROOT / "bert-base-uncased")
pipe = LDMTextToImagePipeline(
vqvae=vae,
bert=text_model,
@ -1330,15 +1156,19 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
def convert_ckpt_to_diffusers(
checkpoint_path: Union[str, Path],
dump_path: Union[str, Path],
**kwargs,
checkpoint_path: Union[str, Path],
dump_path: Union[str, Path],
model_root: Union[str, Path],
**kwargs,
):
"""
Takes all the arguments of load_pipeline_from_original_stable_diffusion_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
# setting global here to avoid massive changes late at night
global MODEL_ROOT
MODEL_ROOT = Path(model_root) / 'core/convert'
pipe = load_pipeline_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs)
pipe.save_pretrained(

View File

@ -0,0 +1,678 @@
from __future__ import annotations
import copy
from pathlib import Path
from contextlib import contextmanager
from typing import Optional, Dict, Tuple, Any
import torch
from safetensors.torch import load_file
from torch.utils.hooks import RemovableHandle
from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel
from compel.embeddings_provider import BaseTextualInversionManager
class LoRALayerBase:
#rank: Optional[int]
#alpha: Optional[float]
#bias: Optional[torch.Tensor]
#layer_key: str
#@property
#def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: dict,
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if (
"bias_indices" in values
and "bias_values" in values
and "bias_size" in values
):
self.bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def forward(
self,
module: torch.nn.Module,
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
multiplier: float,
):
if type(module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)
else:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight(module)
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
) * multiplier * scale
def get_weight(self, module: torch.nn.Module):
raise NotImplementedError()
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
#up: torch.Tensor
#mid: Optional[torch.Tensor]
#down: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid = values["lora_mid.weight"]
else:
self.mid = None
self.rank = self.down.shape[0]
def get_weight(self, module: torch.nn.Module):
if self.mid is not None:
up = self.up.reshape(up.shape[0], up.shape[1])
down = self.down.reshape(up.shape[0], up.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
#w1_a: torch.Tensor
#w1_b: torch.Tensor
#w2_a: torch.Tensor
#w2_b: torch.Tensor
#t1: Optional[torch.Tensor] = None
#t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(module_key, rank, alpha, bias)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1 = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2 = values["hada_t2"]
else:
self.t2 = None
self.rank = self.w1_b.shape[0]
def get_weight(self, module: torch.nn.Module):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum(
"i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a
)
rebuild2 = torch.einsum(
"i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a
)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
#w1: Optional[torch.Tensor] = None
#w1_a: Optional[torch.Tensor] = None
#w1_b: Optional[torch.Tensor] = None
#w2: Optional[torch.Tensor] = None
#w2_a: Optional[torch.Tensor] = None
#w2_b: Optional[torch.Tensor] = None
#t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(module_key, rank, alpha, bias)
if "lokr_w1" in values:
self.w1 = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2 = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2 = values["lokr_t2"]
else:
self.t2 = None
if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
else:
self.rank = None # unscaled
def get_weight(self, module: torch.nn.Module):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape?
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoRAModel: #(torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
_device: torch.device
_dtype: torch.dtype
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
device: torch.device,
dtype: torch.dtype,
):
self._name = name
self._device = device or torch.cpu
self._dtype = dtype or torch.float32
self.layers = layers
@property
def name(self):
return self._name
@property
def device(self):
return self._device
@property
def dtype(self):
return self._dtype
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> LoRAModel:
# TODO: try revert if exception?
for key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
self._device = device
self._dtype = dtype
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
device=device,
dtype=dtype,
name=file_path.stem, # TODO:
layers=dict(),
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
else:
# TODO: diff/ia3/... format
print(
f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
)
return
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: dict):
state_dict_groupped = dict()
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
"""
loras = [
(lora_model1, 0.7),
(lora_model2, 0.4),
]
with LoRAHelper.apply_lora_unet(unet, loras):
# unet with applied loras
# unmodified unet
"""
# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
if not lora_key.startswith(prefix):
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
module = model
module_key = ""
key_parts = lora_key[len(prefix):].split('_')
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = module_key.rstrip(".")
return (module_key, module)
@staticmethod
def _lora_forward_hook(
applied_loras: List[Tuple[LoraModel, float]],
layer_name: str,
):
def lora_forward(module, input_h, output):
if len(applied_loras) == 0:
return output
for lora, weight in applied_loras:
layer = lora.layers.get(layer_name, None)
if layer is None:
continue
output += layer.forward(module, input_h, weight)
return output
return lora_forward
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: UNet2DConditionModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@classmethod
@contextmanager
def apply_lora(
cls,
model: torch.nn.Module,
loras: List[Tuple[LoraModel, float]],
prefix: str,
):
hooks = dict()
try:
for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
if module_key not in hooks:
hooks[module_key] = module.register_forward_hook(cls._lora_forward_hook(loras, layer_key))
yield # wait for context manager exit
finally:
for module_key, hook in hooks.items():
hook.remove()
hooks.clear()
@classmethod
@contextmanager
def apply_ti(
cls,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
ti_list: List[Any],
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
init_tokens_count = None
new_tokens_added = None
try:
ti_tokenizer = copy.deepcopy(tokenizer)
ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
def _get_trigger(ti, index):
trigger = ti.name
if index > 0:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
# modify tokenizer
new_tokens_added = 0
for ti in ti_list:
for i in range(ti.embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
# modify text_encoder
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
model_embeddings = text_encoder.get_input_embeddings()
for ti in ti_list:
ti_tokens = []
for i in range(ti.embedding.shape[0]):
embedding = ti.embedding[i]
trigger = _get_trigger(ti, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id:
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
if model_embeddings.weight.data[token_id].shape != embedding.shape:
raise ValueError(
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
)
model_embeddings.weight.data[token_id] = embedding
ti_tokens.append(token_id)
if len(ti_tokens) > 1:
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
yield ti_tokenizer, ti_manager
finally:
if init_tokens_count and new_tokens_added:
text_encoder.resize_token_embeddings(init_tokens_count)
class TextualInversionModel:
name: str
embedding: torch.Tensor # [n, 768]|[n, 1280]
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if not isinstance(file_path, Path):
file_path = Path(file_path)
result = cls() # TODO:
result.name = file_path.stem # TODO:
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
# both v1 and v2 format embeddings
# difference mostly in metadata
if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1:
print(f"Warn: Embedding \"{file_path.name}\" contains multiple tokens, which is not supported. The first token will be used.")
result.embedding = next(iter(state_dict["string_to_param"].values()))
# v3 (easynegative)
elif "emb_params" in state_dict:
result.embedding = state_dict["emb_params"]
# v4(diffusers bin files)
else:
result.embedding = next(iter(state_dict.values()))
if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}")
return result
class TextualInversionManager(BaseTextualInversionManager):
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = dict()
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(
self, token_ids: list[int]
) -> list[int]:
if len(self.pad_tokens) == 0:
return token_ids
if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id")
new_token_ids = []
for token_id in token_ids:
new_token_ids.append(token_id)
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
return new_token_ids

View File

@ -0,0 +1,391 @@
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:
cache = ModelCache(max_models_cached=6)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2)
"""
import gc
import os
import sys
import hashlib
from contextlib import suppress
from pathlib import Path
from typing import Dict, Union, types, Optional, Type, Any
import torch
import logging
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from .lora import LoRAModel, TextualInversionModel
from .models import BaseModelType, ModelType, SubModelType, ModelBase
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
# actual size of a gig
GIG = 1073741824
class ModelLocker(object):
"Forward declaration"
pass
class ModelCache(object):
"Forward declaration"
pass
class _CacheRecord:
size: int
model: Any
cache: ModelCache
_locks: int
def __init__(self, cache, model: Any, size: int):
self.size = size
self.model = model
self.cache = cache
self._locks = 0
def lock(self):
self._locks += 1
def unlock(self):
self._locks -= 1
assert self._locks >= 0
@property
def locked(self):
return self._locks > 0
@property
def loaded(self):
if self.model is not None and hasattr(self.model, "device"):
return self.model.device != self.cache.storage_device
else:
return False
class ModelCache(object):
def __init__(
self,
max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
execution_device: torch.device=torch.device('cuda'),
storage_device: torch.device=torch.device('cpu'),
precision: torch.dtype=torch.float16,
sequential_offload: bool=False,
lazy_offloading: bool=True,
sha_chunksize: int = 16777216,
logger: types.ModuleType = logger
):
'''
:param max_models: Maximum number of models to cache in CPU RAM [4]
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
'''
#max_cache_size = 9999
execution_device = torch.device('cuda')
self.model_infos: Dict[str, ModelBase] = dict()
self.lazy_offloading = lazy_offloading
#self.sequential_offload: bool=sequential_offload
self.precision: torch.dtype=precision
self.max_cache_size: int=max_cache_size
self.execution_device: torch.device=execution_device
self.storage_device: torch.device=storage_device
self.sha_chunksize=sha_chunksize
self.logger = logger
self._cached_models = dict()
self._cache_stack = list()
def get_key(
self,
model_path: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
):
key = f"{model_path}:{base_model}:{model_type}"
if submodel_type:
key += f":{submodel_type}"
return key
#def get_model(
# self,
# repo_id_or_path: Union[str, Path],
# model_type: ModelType = ModelType.Diffusers,
# subfolder: Path = None,
# submodel: ModelType = None,
# revision: str = None,
# attach_model_part: Tuple[ModelType, str] = (None, None),
# gpu_load: bool = True,
#) -> ModelLocker: # ?? what does it return
def _get_model_info(
self,
model_path: str,
model_class: Type[ModelBase],
base_model: BaseModelType,
model_type: ModelType,
):
model_info_key = self.get_key(
model_path=model_path,
base_model=base_model,
model_type=model_type,
submodel_type=None,
)
if model_info_key not in self.model_infos:
self.model_infos[model_info_key] = model_class(
model_path,
base_model,
model_type,
)
return self.model_infos[model_info_key]
# TODO: args
def get_model(
self,
model_path: Union[str, Path],
model_class: Type[ModelBase],
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
gpu_load: bool = True,
) -> Any:
if not isinstance(model_path, Path):
model_path = Path(model_path)
if not os.path.exists(model_path):
raise Exception(f"Model not found: {model_path}")
model_info = self._get_model_info(
model_path=model_path,
model_class=model_class,
base_model=base_model,
model_type=model_type,
)
key = self.get_key(
model_path=model_path,
base_model=base_model,
model_type=model_type,
submodel_type=submodel,
)
# TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(f'Loading model {model_path}, type {base_model}:{model_type}:{submodel}')
# this will remove older cached models until
# there is sufficient room to load the requested model
self._make_cache_room(model_info.get_size(submodel))
# clean memory to make MemoryUsage() more accurate
gc.collect()
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
if mem_used := model_info.get_size(submodel):
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
cache_entry = _CacheRecord(self, model, mem_used)
self._cached_models[key] = cache_entry
with suppress(Exception):
self._cache_stack.remove(key)
self._cache_stack.append(key)
return self.ModelLocker(self, key, cache_entry.model, gpu_load)
class ModelLocker(object):
def __init__(self, cache, key, model, gpu_load):
self.gpu_load = gpu_load
self.cache = cache
self.key = key
self.model = model
self.cache_entry = self.cache._cached_models[self.key]
def __enter__(self) -> Any:
if not hasattr(self.model, 'to'):
return self.model
# NOTE that the model has to have the to() method in order for this
# code to move it into GPU!
if self.gpu_load:
self.cache_entry.lock()
try:
if self.cache.lazy_offloading:
self.cache._offload_unlocked_models()
if self.model.device != self.cache.execution_device:
self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}')
with VRAMUsage() as mem:
self.model.to(self.cache.execution_device) # move into GPU
self.cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
self.cache.logger.debug(f'Locking {self.key} in {self.cache.execution_device}')
self.cache._print_cuda_stats()
except:
self.cache_entry.unlock()
raise
# TODO: not fully understand
# in the event that the caller wants the model in RAM, we
# move it into CPU if it is in GPU and not locked
elif self.cache_entry.loaded and not self.cache_entry.locked:
self.model.to(self.cache.storage_device)
return self.model
def __exit__(self, type, value, traceback):
if not hasattr(self.model, 'to'):
return
self.cache_entry.unlock()
if not self.cache.lazy_offloading:
self.cache._offload_unlocked_models()
self.cache._print_cuda_stats()
# TODO: should it be called untrack_model?
def uncache_model(self, cache_id: str):
with suppress(ValueError):
self._cache_stack.remove(cache_id)
self._cached_models.pop(cache_id, None)
def model_hash(
self,
model_path: Union[str, Path],
) -> str:
'''
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
:param model_path: Path to model file/directory on disk.
'''
return self._local_model_hash(model_path)
def cache_size(self) -> float:
"Return the current size of the cache, in GB"
current_cache_size = sum([m.size for m in self._cached_models.values()])
return current_cache_size / GIG
def _has_cuda(self) -> bool:
return self.execution_device.type == 'cuda'
def _print_cuda_stats(self):
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % self.cache_size()
cached_models = 0
loaded_models = 0
locked_models = 0
for model_info in self._cached_models.values():
cached_models += 1
if model_info.loaded:
loaded_models += 1
if model_info.locked:
locked_models += 1
self.logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}")
def _make_cache_room(self, model_size):
# calculate how much memory this model will require
#multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = sum([m.size for m in self._cached_models.values()])
if current_size + bytes_needed > maximum_size:
self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
pos = 0
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
refs = sys.getrefcount(cache_entry.model)
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")
# 2 refs:
# 1 from cache_entry
# 1 from getrefcount function
if not cache_entry.locked and refs <= 2:
self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
current_size -= cache_entry.size
del self._cache_stack[pos]
del self._cached_models[model_key]
del cache_entry
else:
pos += 1
gc.collect()
torch.cuda.empty_cache()
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
def _offload_unlocked_models(self):
for model_key, cache_entry in self._cached_models.items():
if not cache_entry.locked and cache_entry.loaded:
self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
cache_entry.model.to(self.storage_device)
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
sha = hashlib.sha256()
path = Path(model_path)
hashpath = path / "checksum.sha256"
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
with open(hashpath) as f:
hash = f.read()
return hash
self.logger.debug(f'computing hash of model {path.name}')
for file in list(path.rglob("*.ckpt")) \
+ list(path.rglob("*.safetensors")) \
+ list(path.rglob("*.pth")):
with open(file, "rb") as f:
while chunk := f.read(self.sha_chunksize):
sha.update(chunk)
hash = sha.hexdigest()
with open(hashpath, "w") as f:
f.write(hash)
return hash
class VRAMUsage(object):
def __init__(self):
self.vram = None
self.vram_used = 0
def __enter__(self):
self.vram = torch.cuda.memory_allocated()
return self
def __exit__(self, *args):
self.vram_used = torch.cuda.memory_allocated() - self.vram

View File

@ -0,0 +1,118 @@
"""
Routines for downloading and installing models.
"""
import json
import safetensors
import safetensors.torch
import shutil
import tempfile
import torch
import traceback
from dataclasses import dataclass
from diffusers import ModelMixin
from enum import Enum
from typing import Callable
from pathlib import Path
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from . import ModelManager
from .models import BaseModelType, ModelType, VariantType
from .model_probe import ModelProbe, ModelVariantInfo
from .model_cache import SilenceWarnings
class ModelInstall(object):
'''
This class is able to download and install several different kinds of
InvokeAI models. The helper function, if provided, is called on to distinguish
between v2-base and v2-768 stable diffusion pipelines. This usually involves
asking the user to select the proper type, as there is no way of distinguishing
the two type of v2 file programmatically (as far as I know).
'''
def __init__(self,
config: InvokeAIAppConfig,
model_base_helper: Callable[[Path],BaseModelType]=None,
clobber:bool = False
):
'''
:param config: InvokeAI configuration object
:param model_base_helper: A function call that accepts the Path to a checkpoint model and returns a ModelType enum
:param clobber: If true, models with colliding names will be overwritten
'''
self.config = config
self.clogger = clobber
self.helper = model_base_helper
self.prober = ModelProbe()
def install_checkpoint_file(self, checkpoint: Path)->dict:
'''
Install the checkpoint file at path and return a
configuration entry that can be added to `models.yaml`.
Model checkpoints and VAEs will be converted into
diffusers before installation. Note that the model manager
does not hold entries for anything but diffusers pipelines,
and the configuration file stanzas returned from such models
can be safely ignored.
'''
model_info = self.prober.probe(checkpoint, self.helper)
if not model_info:
raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")
key = ModelManager.create_key(
model_name = checkpoint.stem,
base_model = model_info.base_type,
model_type = model_info.model_type,
)
destination_path = self._dest_path(model_info) / checkpoint
destination_path.parent.mkdir(parents=True, exist_ok=True)
self._check_for_collision(destination_path)
stanza = {
key: dict(
name = checkpoint.stem,
description = f'{model_info.model_type} model {checkpoint.stem}',
base = model_info.base_model.value,
type = model_info.model_type.value,
variant = model_info.variant_type.value,
path = str(destination_path),
)
}
# non-pipeline; no conversion needed, just copy into right place
if model_info.model_type != ModelType.Pipeline:
shutil.copyfile(checkpoint, destination_path)
stanza[key].update({'format': 'checkpoint'})
# pipeline - conversion needed here
else:
destination_path = self._dest_path(model_info) / checkpoint.stem
config_file = self._pipeline_type_to_config_file(model_info.model_type)
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings:
convert_ckpt_to_diffusers(
checkpoint,
destination_path,
extract_ema=True,
original_config_file=config_file,
scan_needed=False,
)
stanza[key].update({'format': 'folder',
'path': destination_path, # no suffix on this
})
return stanza
def _check_for_collision(self, path: Path):
if not path.exists():
return
if self.clobber:
shutil.rmtree(path)
else:
raise ValueError(f"Destination {path} already exists. Won't overwrite unless clobber=True.")
def _staging_directory(self)->tempfile.TemporaryDirectory:
return tempfile.TemporaryDirectory(dir=self.config.root_path)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,417 @@
import json
import traceback
import torch
import safetensors.torch
from dataclasses import dataclass
from enum import Enum
from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel
from pathlib import Path
from typing import Callable, Literal, Union, Dict
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings
@dataclass
class ModelVariantInfo(object):
model_type: ModelType
base_type: BaseModelType
variant_type: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
format: Literal['folder','checkpoint']
image_size: int
class ProbeBase(object):
'''forward declaration'''
pass
class ModelProbe(object):
PROBES = {
'folder': { },
'checkpoint': { },
}
CLASS2TYPE = {
'StableDiffusionPipeline' : ModelType.Pipeline,
'AutoencoderKL' : ModelType.Vae,
'ControlNetModel' : ModelType.ControlNet,
}
@classmethod
def register_probe(cls,
format: Literal['folder','file'],
model_type: ModelType,
probe_class: ProbeBase):
cls.PROBES[format][model_type] = probe_class
@classmethod
def heuristic_probe(cls,
model: Union[Dict, ModelMixin, Path],
prediction_type_helper: Callable[[Path],BaseModelType]=None,
)->ModelVariantInfo:
if isinstance(model,Path):
return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper)
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
else:
raise Exception("model parameter {model} is neither a Path, nor a model")
@classmethod
def probe(cls,
model_path: Path,
model: Union[Dict, ModelMixin] = None,
prediction_type_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo:
'''
Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. If the model is
already loaded into memory, you may provide it as model in order to avoid
opening it a second time. The prediction_type_helper callable is a function that receives
the path to the model and returns the BaseModelType. It is called to distinguish
between V2-Base and V2-768 SD models.
'''
if model_path:
format = 'folder' if model_path.is_dir() else 'checkpoint'
else:
format = 'folder' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
model_info = None
try:
model_type = cls.get_model_type_from_folder(model_path, model) \
if format == 'folder' \
else cls.get_model_type_from_checkpoint(model_path, model)
probe_class = cls.PROBES[format].get(model_type)
if not probe_class:
return None
probe = probe_class(model_path, model, prediction_type_helper)
base_type = probe.get_base_type()
variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type()
model_info = ModelVariantInfo(
model_type = model_type,
base_type = base_type,
variant_type = variant_type,
prediction_type = prediction_type,
upcast_attention = (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction),
format = format,
image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction \
) else 512,
)
except Exception as e:
return None
return model_info
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict)->ModelType:
if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors'):
return None
if model_path.name=='learned_embeds.bin':
return ModelType.TextualInversion
checkpoint = checkpoint or cls._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
return ModelType.Pipeline
if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
return ModelType.Vae
if "string_to_token" in state_dict or "emb_params" in state_dict:
return ModelType.TextualInversion
if any([x.startswith("lora") for x in state_dict.keys()]):
return ModelType.Lora
if any([x.startswith("control_model") for x in state_dict.keys()]):
return ModelType.ControlNet
if any([x.startswith("input_blocks") for x in state_dict.keys()]):
return ModelType.ControlNet
return None # give up
@classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
'''
Get the model type of a hugging-face style folder.
'''
class_name = None
if model:
class_name = model.__class__.__name__
else:
if (folder_path / 'learned_embeds.bin').exists():
return ModelType.TextualInversion
if (folder_path / 'pytorch_lora_weights.bin').exists():
return ModelType.Lora
i = folder_path / 'model_index.json'
c = folder_path / 'config.json'
config_path = i if i.exists() else c if c.exists() else None
if config_path:
with open(config_path,'r') as file:
conf = json.load(file)
class_name = conf['_class_name']
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type
# give up
raise ValueError("Unable to determine model type")
@classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model_path, model_path)
return torch.load(model_path)
else:
return safetensors.torch.load_file(model_path)
@classmethod
def _scan_model(cls, model_name, checkpoint):
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
"""
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise "The model {model_name} is potentially infected by malware. Aborting import."
###################################################3
# Checkpoint probing
###################################################3
class ProbeBase(object):
def get_base_type(self)->BaseModelType:
pass
def get_variant_type(self)->ModelVariantType:
pass
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
pass
class CheckpointProbeBase(ProbeBase):
def __init__(self,
checkpoint_path: Path,
checkpoint: dict,
helper: Callable[[Path],BaseModelType] = None
)->BaseModelType:
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
self.checkpoint_path = checkpoint_path
self.helper = helper
def get_base_type(self)->BaseModelType:
pass
def get_variant_type(self)-> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
if model_type != ModelType.Pipeline:
return ModelVariantType.Normal
state_dict = self.checkpoint.get('state_dict') or self.checkpoint
in_channels = state_dict[
"model.diffusion_model.input_blocks.0.0.weight"
].shape[1]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
else:
raise Exception("Cannot determine variant type")
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get('state_dict') or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
raise Exception("Cannot determine base type")
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
type = self.get_base_type()
if type == BaseModelType.StableDiffusion1:
return SchedulerPredictionType.Epsilon
checkpoint = self.checkpoint
state_dict = self.checkpoint.get('state_dict') or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if 'global_step' in checkpoint:
if checkpoint['global_step'] == 220000:
return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000:
return SchedulerPredictionType.VPrediction
if self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
else:
return None
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
lora_token_vector_length = (
checkpoint[key1].shape[1]
if key1 in checkpoint
else checkpoint[key2].shape[0]
if key2 in checkpoint
else 768
)
if lora_token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif lora_token_vector_length == 1024:
return BaseModelType.StableDiffusion2
else:
return None
class TextualInversionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
if 'string_to_token' in checkpoint:
token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1]
elif 'emb_params' in checkpoint:
token_dim = checkpoint['emb_params'].shape[-1]
else:
token_dim = list(checkpoint.values())[0].shape[0]
if token_dim == 768:
return BaseModelType.StableDiffusion1
elif token_dim == 1024:
return BaseModelType.StableDiffusion2
else:
return None
class ControlNetCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight',
'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight'
):
if key_name not in checkpoint:
continue
if checkpoint[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
elif checkpoint[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
elif self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
raise Exception("Unable to determine base type for {self.checkpoint_path}")
########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
def __init__(self,
folder_path: Path,
model: ModelMixin = None,
helper: Callable=None # not used
):
self.model = model
self.folder_path = folder_path
def get_variant_type(self)->ModelVariantType:
return ModelVariantType.Normal
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
if self.model:
unet_conf = self.model.unet.config
scheduler_conf = self.model.scheduler.config
else:
with open(self.folder_path / 'unet' / 'config.json','r') as file:
unet_conf = json.load(file)
with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
scheduler_conf = json.load(file)
if unet_conf['cross_attention_dim'] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf['cross_attention_dim'] == 1024:
return BaseModelType.StableDiffusion2
else:
raise ValueError(f'Unknown base model for {self.folder_path}')
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
if self.model:
scheduler_conf = self.model.scheduler.config
else:
with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
scheduler_conf = json.load(file)
if scheduler_conf['prediction_type'] == "v_prediction":
return SchedulerPredictionType.VPrediction
elif scheduler_conf['prediction_type'] == 'epsilon':
return SchedulerPredictionType.Epsilon
else:
return None
def get_variant_type(self)->ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the
# "normal" variant type
try:
if self.model:
conf = self.model.unet.config
else:
config_file = self.folder_path / 'unet' / 'config.json'
with open(config_file,'r') as file:
conf = json.load(file)
in_channels = conf['in_channels']
if in_channels == 9:
return ModelVariantType.Inpainting
elif in_channels == 5:
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
except:
pass
return ModelVariantType.Normal
class VaeFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
return BaseModelType.StableDiffusion1
class TextualInversionFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
path = self.folder_path / 'learned_embeds.bin'
if not path.exists():
return None
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
return TextualInversionCheckpointProbe(None,checkpoint=checkpoint).get_base_type()
class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
config_file = self.folder_path / 'config.json'
if not config_file.exists():
raise Exception(f"Cannot determine base type for {self.folder_path}")
with open(config_file,'r') as file:
config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768
return BaseModelType.StableDiffusion1 \
if config['cross_attention_dim']==768 \
else BaseModelType.StableDiffusion2
class LoRAFolderProbe(FolderProbeBase):
# I've never seen one of these in the wild, so this is a noop
pass
############## register probe classes ######
ModelProbe.register_probe('folder', ModelType.Pipeline, PipelineFolderProbe)
ModelProbe.register_probe('folder', ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe('folder', ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe('folder', ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe('folder', ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe('checkpoint', ModelType.Pipeline, PipelineCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.ControlNet, ControlNetCheckpointProbe)

View File

@ -0,0 +1,95 @@
import inspect
from enum import Enum
from pydantic import BaseModel
from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel
from .lora import LoRAModel
from .controlnet import ControlNetModel # TODO:
from .textual_inversion import TextualInversionModel
MODEL_CLASSES = {
BaseModelType.StableDiffusion1: {
ModelType.Pipeline: StableDiffusion1Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusion2: {
ModelType.Pipeline: StableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModelType.Kandinsky2_1: {
# ModelType.Pipeline: Kandinsky2_1Model,
# ModelType.MoVQ: MoVQModel,
# ModelType.Lora: LoRAModel,
# ModelType.ControlNet: ControlNetModel,
# ModelType.TextualInversion: TextualInversionModel,
#},
}
MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list()
class OpenAPIModelInfoBase(BaseModel):
name: str
base_model: BaseModelType
type: ModelType
for base_model, models in MODEL_CLASSES.items():
for model_type, model_class in models.items():
model_configs = set(model_class._get_configs().values())
model_configs.discard(None)
MODEL_CONFIGS.extend(model_configs)
for cfg in model_configs:
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars():
continue
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
__annotations__ = dict(
type=Literal[model_type.value],
),
))
#globals()[openapi_cfg_name] = api_wrapper
vars()[openapi_cfg_name] = api_wrapper
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
def get_model_config_enums():
enums = list()
for model_config in MODEL_CONFIGS:
fields = inspect.get_annotations(model_config)
try:
field = fields["model_format"]
except:
raise Exception("format field not found")
# model_format: None
# model_format: SomeModelFormat
# model_format: Literal[SomeModelFormat.Diffusers]
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
enums.append(field)
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
enums.append(type(field.__args__[0]))
elif field is None:
pass
else:
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
return enums

View File

@ -0,0 +1,415 @@
import os
import sys
import typing
import inspect
from enum import Enum
from abc import ABCMeta, abstractmethod
import torch
import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin
from contextlib import suppress
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
#Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
Pipeline = "pipeline"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
class SubModelType(str, Enum):
UNet = "unet"
TextEncoder = "text_encoder"
Tokenizer = "tokenizer"
Vae = "vae"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
#MoVQ = "movq"
class ModelVariantType(str, Enum):
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class SchedulerPredictionType(str, Enum):
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelError(str, Enum):
NotFound = "not_found"
class ModelConfigBase(BaseModel):
path: str # or Path
description: Optional[str] = Field(None)
model_format: Optional[str] = Field(None)
# do not save to config
error: Optional[ModelError] = Field(None)
class Config:
use_enum_values = True
class EmptyConfigLoader(ConfigMixin):
@classmethod
def load_config(cls, *args, **kwargs):
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
T_co = TypeVar('T_co', covariant=True)
class classproperty(Generic[T_co]):
def __init__(self, fget: Callable[[Any], T_co]) -> None:
self.fget = fget
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
return self.fget(owner)
def __set__(self, instance: Optional[Any], value: Any) -> None:
raise AttributeError('cannot set attribute')
class ModelBase(metaclass=ABCMeta):
#model_path: str
#base_model: BaseModelType
#model_type: ModelType
def __init__(
self,
model_path: str,
base_model: BaseModelType,
model_type: ModelType,
):
self.model_path = model_path
self.base_model = base_model
self.model_type = model_type
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
if all(t is None for t in subtypes):
return None
elif any(t is None for t in subtypes):
raise Exception(f"Unsupported definition: {subtypes}")
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
else:
res_type = sys.modules["diffusers"]
res_type = getattr(res_type, "pipelines")
for subtype in subtypes:
res_type = getattr(res_type, subtype)
return res_type
@classmethod
def _get_configs(cls):
with suppress(Exception):
return cls.__configs
configs = dict()
for name in dir(cls):
if name.startswith("__"):
continue
value = getattr(cls, name)
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
continue
fields = inspect.get_annotations(value)
try:
field = fields["model_format"]
except:
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
for model_format in field:
configs[model_format.value] = value
elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
for model_format in field.__args__:
configs[model_format.value] = value
elif field is None:
configs[None] = value
else:
raise Exception(f"Unsupported format definition in {cls.__qualname__}")
cls.__configs = configs
return cls.__configs
@classmethod
def create_config(cls, **kwargs) -> ModelConfigBase:
if "model_format" not in kwargs:
raise Exception("Field 'model_format' not found in model config")
configs = cls._get_configs()
return configs[kwargs["model_format"]](**kwargs)
@classmethod
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
return cls.create_config(
path=path,
model_format=cls.detect_format(path),
)
@classmethod
@abstractmethod
def detect_format(cls, path: str) -> str:
raise NotImplementedError()
@classproperty
@abstractmethod
def save_to_config(cls) -> bool:
raise NotImplementedError()
@abstractmethod
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
raise NotImplementedError()
@abstractmethod
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> Any:
raise NotImplementedError()
class DiffusersModel(ModelBase):
#child_types: Dict[str, Type]
#child_sizes: Dict[str, int]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
try:
config_data = DiffusionPipeline.load_config(self.model_path)
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
config_data.pop("_ignore_files", None)
# retrieve all folder_names that contain relevant files
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
for child_name in child_components:
child_type = self._hf_definition_to_type(config_data[child_name])
self.child_types[child_name] = child_type
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
# return pipeline in different function to pass more arguments
if child_type is None:
raise Exception("Child model type can't be null on diffusers model")
if child_type not in self.child_types:
return None # TODO: or raise
if torch_dtype == torch.float16:
variants = ["fp16", None]
else:
variants = [None, "fp16"]
# TODO: better error handling(differentiate not found from others)
for variant in variants:
try:
# TODO: set cache_dir to /dev/null to be sure that cache not used?
model = self.child_types[child_type].from_pretrained(
self.model_path,
subfolder=child_type.value,
torch_dtype=torch_dtype,
variant=variant,
local_files_only=True,
)
break
except Exception as e:
#print("====ERR LOAD====")
#print(f"{variant}: {e}")
pass
else:
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
# calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
#def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
def calc_model_size_by_fs(
model_path: str,
subfolder: Optional[str] = None,
variant: Optional[str] = None
):
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
# this can happen when, for example, the safety checker
# is not downloaded.
if not os.path.exists(model_path):
return 0
all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
other_files = set(all_files) - fp16_files - bit8_files
if variant is None:
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.endswith(index_postfix):
continue
try:
with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.endswith(file_format)]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = os.stat(os.path.join(model_path, model_file))
model_size += file_stats.st_size
return model_size
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
def calc_model_size_by_data(model) -> int:
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return _calc_model_by_data(model)
else:
return 0
def _calc_pipeline_by_data(pipeline) -> int:
res = 0
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
return res
def _calc_model_by_data(model) -> int:
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
return mem
def _fast_safetensors_reader(path: str):
checkpoint = dict()
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), 'little')
definition_json = f.read(definition_len)
definition = json.loads(definition_json)
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
raise Exception("Supported only pytorch safetensors files")
definition.pop("__metadata__", None)
for key, info in definition.items():
dtype = {
"I8": torch.int8,
"I16": torch.int16,
"I32": torch.int32,
"I64": torch.int64,
"F16": torch.float16,
"F32": torch.float32,
"F64": torch.float64,
}[info["dtype"]]
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
return checkpoint
def read_checkpoint_meta(path: str):
if path.endswith(".safetensors"):
try:
checkpoint = _fast_safetensors_reader(path)
except:
# TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu")
else:
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
import warnings
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
class SilenceWarnings(object):
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self):
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter('ignore')
def __exit__(self, type, value, traceback):
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter('default')

View File

@ -0,0 +1,92 @@
import os
import torch
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
EmptyConfigLoader,
calc_model_size_by_fs,
calc_model_size_by_data,
classproperty,
)
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
class Config(ModelConfigBase):
model_format: ControlNetModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
super().__init__(model_path, base_model, model_type)
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
model_class_name = config.get("_class_name", None)
if model_class_name not in {"ControlNetModel"}:
raise Exception(f"Invalid ControlNet model! Unknown _class_name: {model_class_name}")
try:
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except:
raise Exception("Invalid ControlNet model!")
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in controlnet model")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in controlnet model")
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return ControlNetModelFormat.Diffusers
else:
return ControlNetModelFormat.Checkpoint
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else:
return model_path

View File

@ -0,0 +1,76 @@
import os
import torch
from enum import Enum
from typing import Optional, Union, Literal
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
classproperty,
)
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
Diffusers = "diffusers"
class LoRAModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
model_format: LoRAModelFormat # TODO:
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in lora")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in lora")
model = LoRAModelRaw.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
)
self.model_size = model.calc_size()
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return LoRAModelFormat.Diffusers
else:
return LoRAModelFormat.LyCORIS
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
# TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported")
else:
return model_path

View File

@ -0,0 +1,321 @@
import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
ModelVariantType,
DiffusersModel,
SchedulerPredictionType,
SilenceWarnings,
read_checkpoint_meta,
classproperty,
)
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.Pipeline
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.Pipeline,
)
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == StableDiffusion1ModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusion1ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
else:
raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 1.* model format")
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return StableDiffusion1ModelFormat.Diffusers
else:
return StableDiffusion1ModelFormat.Checkpoint
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
assert model_path == config.path
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion1,
model_config=config,
output_path=output_path,
) # TODO: args
else:
return model_path
class StableDiffusion2ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion2
assert model_type == ModelType.Pipeline
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion2,
model_type=ModelType.Pipeline,
)
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == StableDiffusion2ModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusion2ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
else:
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 5:
variant = ModelVariantType.Depth
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 2.* model format")
if variant == ModelVariantType.Normal:
prediction_type = SchedulerPredictionType.VPrediction
upcast_attention = True
else:
prediction_type = SchedulerPredictionType.Epsilon
upcast_attention = False
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return StableDiffusion2ModelFormat.Diffusers
else:
return StableDiffusion2ModelFormat.Checkpoint
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
assert model_path == config.path
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2,
model_config=config,
output_path=output_path,
) # TODO: args
else:
return model_path
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ckpt_configs = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
# code further will manually set upcast_attention and v_prediction
ModelVariantType.Normal: "v2-inference.yaml",
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
}
}
try:
# TODO: path
#model_config.config = app_config.config_dir / "stable-diffusion" / ckpt_configs[version][model_config.variant]
#return InvokeAIAppConfig.get_config().legacy_conf_dir / ckpt_configs[version][variant]
return InvokeAIAppConfig.get_config().root_dir / "configs" / "stable-diffusion" / ckpt_configs[version][variant]
except:
return None
# TODO: rework
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
output_path: str,
) -> str:
"""
Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
if model_config.config is None:
model_config.config = _select_ckpt_config(version, model_config.variant)
if model_config.config is None:
raise Exception(f"Model variant {model_config.variant} not supported for {version}")
weights = app_config.root_dir / model_config.path
config_file = app_config.root_dir / model_config.config
output_path = Path(output_path)
if version == BaseModelType.StableDiffusion1:
upcast_attention = False
prediction_type = SchedulerPredictionType.Epsilon
elif version == BaseModelType.StableDiffusion2:
upcast_attention = model_config.upcast_attention
prediction_type = model_config.prediction_type
else:
raise Exception(f"Unknown model provided: {version}")
# return cached version if it exists
if output_path.exists():
return output_path
# TODO: I think that it more correctly to convert with embedded vae
# as if user will delete custom vae he will got not embedded but also custom vae
#vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
output_path,
model_version=version,
model_variant=model_config.variant,
original_config_file=config_file,
extract_ema=True,
upcast_attention=upcast_attention,
prediction_type=prediction_type,
scan_needed=True,
model_root=app_config.models_path,
)
return output_path

Some files were not shown because too many files have changed in this diff Show More