Commit Graph

94 Commits

Author SHA1 Message Date
Jonathan
0d7205ff79
Handle mixed-dtype mismatches in autocast linear and conv wrappers (#9006)
* Handle CustomConv2d bias dtype mismatches

* Fix mixed-dtype autocast regressions

* Format custom_conv2d with ruff
2026-04-20 20:31:35 +00:00
CypherNaugh_0x
9deb545cc1
External models (Gemini Nano Banana & OpenAI GPT Image) (#8633) (#8884)
* feat: initial external model support

* feat: support reference images for external models

* fix: sorting lint error

* chore: hide Reidentify button for external models

* review: enable auto-install/remove fro external models

* feat: show external mode name during install

* review: model descriptions

* review: implemented review comments

* review: added optional seed control for external models

* chore: fix linter warning

* review: save api keys to a seperate file

* docs: updated external model docs

* chore: fix linter errors

* fix: sync configured external starter models on startup

* feat(ui): add provider-specific external generation nodes

* feat: expose external panel schemas in model configs

* feat(ui): drive external panels from panel schema

* docs: sync app config docstring order

* feat: add gemini 3.1 flash image preview starter model

* feat: update gemini image model limits

* fix: resolve TypeScript errors and move external provider config to api_keys.yaml

Add 'external', 'external_image_generator', and 'external_api' to Zod
enum schemas (zBaseModelType, zModelType, zModelFormat) to match the
generated OpenAPI types. Remove redundant union workarounds from
component prop types and Record definitions.

Fix type errors in ModelEdit (react-hook-form Control invariance),
parsing.tsx (model identifier narrowing), buildExternalGraph (edge
typing), and ModelSettings import/export buttons.

Move external_gemini_base_url and external_openai_base_url into
api_keys.yaml alongside the API keys so all external provider config
lives in one dedicated file, separate from invokeai.yaml.

* feat: add resolution presets and imageConfig support for Gemini 3 models

Add combined resolution preset selector for external models that maps
aspect ratio + image size to fixed dimensions. Gemini 3 Pro and 3.1 Flash
now send imageConfig (aspectRatio + imageSize) via generationConfig instead
of text-based aspect ratio hints used by Gemini 2.5 Flash.

Backend: ExternalResolutionPreset model, resolution_presets capability field,
image_size on ExternalGenerationRequest, and Gemini provider imageConfig logic.

Frontend: ExternalSettingsAccordion with combo resolution select, dimension
slider disabling for fixed-size models, and panel schema constraint wiring
for Steps/Guidance/Seed controls.

* Remove unused external model fields and add provider-specific parameters

- Remove negative_prompt, steps, guidance, reference_image_weights,
  reference_image_modes from external model nodes (unused by any provider)
- Remove supports_negative_prompt, supports_steps, supports_guidance
  from ExternalModelCapabilities
- Add provider_options dict to ExternalGenerationRequest for
  provider-specific parameters
- Add OpenAI-specific fields: quality, background, input_fidelity
- Add Gemini-specific fields: temperature, thinking_level
- Add new OpenAI starter models: GPT Image 1.5, GPT Image 1 Mini,
  DALL-E 3, DALL-E 2
- Fix OpenAI provider to use output_format (GPT Image) vs
  response_format (DALL-E) and send model ID in requests
- Add fixed aspect ratio sizes for OpenAI models (bucketing)
- Add ExternalProviderRateLimitError with retry logic for 429 responses
- Add provider-specific UI components in ExternalSettingsAccordion
- Simplify ParamSteps/ParamGuidance by removing dead external overrides
- Update all backend and frontend tests

* Chore Ruff check & format

* Chore typegen

* feat: full canvas workflow integration for external models

- Add missing aspect ratios (4:5, 5:4, 8:1, 4:1, 1:4, 1:8) to type
  system for external model support
- Sync canvas bbox when external model resolution preset is selected
- Use params preset dimensions in buildExternalGraph to prevent
  "unsupported aspect ratio" errors
- Lock all bbox controls (resize handles, aspect ratio select,
  width/height sliders, swap/optimal buttons) for external models
  with fixed dimension presets
- Disable denoise strength slider for external models (not applicable)
- Sync bbox aspect ratio changes back to paramsSlice for external models
- Initialize bbox dimensions when switching to an external model

* Chore typegen Linux seperator

* feat: full canvas workflow integration for external models
- Update buildExternalGraph test to include dimensions in mock params

* Merge remote-tracking branch 'upstream/main' into external-models

* Chore pnpm fix

* add missing parameter

* docs: add External Models guide with Gemini and OpenAI provider pages

* fix(external-models): address PR review feedback

- Gemini recall: write temperature, thinking_level, image_size to image metadata;
  wire external graph as metadata receiver; add recall handlers.
- Canvas: gate regional guidance, inpaint mask, and control layer for external models.
- Canvas: throw a clear error on outpainting for external models (was falling back to
  inpaint and hitting an API-side mask/image size mismatch).
- Workflow editor: add ui_model_provider_id filter so OpenAI and Gemini nodes only
  list their own provider's models.
- Workflow editor: silently drop seed when the selected model does not support it
  instead of raising a capability error.
- Remove the legacy external_image_generation invocation and the graph-builder
  fallback; providers must register a dedicated node.
- Regenerate schema.ts.
- remove Gemini debug dumps to outputs/external_debug

* fix(external-models): resolve TSC errors in metadata parsing and external graph

- Export imageSizeChanged from paramsSlice (required by the new ImageSize
  recall handler).
- Emit the external graph's metadata model entry via zModelIdentifierField
  since ExternalApiModelConfig is not part of the AnyModelConfig union.

* chore: prettier format ModelIdentifierFieldInputComponent

* fix: remove unsupported thinkingConfig from Gemini image models and restrict GPT Image models to txt2img

* chore typegen

* chore(docs): regenerate settings.json for external provider fields

* fix(external): fix mask handling and mode support for external providers

- Remove img2img and inpaint modes from Gemini models (Gemini has no
  bitmap mask or dedicated edit API; image editing works via reference
  images in the UI)
- Fix DALL-E 2 inpainting: convert grayscale mask to RGBA with alpha
  channel transparency (OpenAI expects transparent=edit area) and
  convert init image to RGBA when mask is present

* fix(external): update mode support and UI for external providers

- Remove DALL-E 2 from starter models (deprecated, shutdown May 12 2026)
- Enable img2img for GPT Image 1/1.5/1-mini (supports edits endpoint)
- Set Gemini models to txt2img only (no mask/edit API; editing via
  ref images)
- Hide mode/init_image/mask_image fields on Gemini node (not usable)
- Hide mask_image field on OpenAI node (no model supports inpaint)

* Chore typegen

* fix(external): improve OpenAI node UX and disable cache by default

- Hide OpenAI node's mode and init_image fields: OpenAI's API has no
  img2img/inpaint distinction (the edits endpoint is invoked
  automatically when reference images are provided). init_image is
  functionally identical to a reference image and was misleading users.
- Default use_cache to False for external image generation nodes:
  external API calls are non-deterministic and incur usage costs.
  Cache hits returned stale image references that did not produce new
  gallery entries on repeat invokes.

* fix(external): duplicate cached images on cache hit instead of skipping

External image generation nodes use the standard invocation cache, but
returning the cached output (with stale image_name references) on cache
hits resulted in no new gallery entries — the Invoke button would spin
indefinitely on repeat invokes with identical parameters.

Override invoke_internal so that on cache hit, the cached images are
loaded and re-saved as new gallery entries. The expensive API call is
still skipped (cost saving), but the user sees a new image as expected.

* Chore typegen + ruff

* CHore ruff format

* fix(external): restore OpenAI advanced settings on Remix recall

Remix recall iterates through ImageMetadataHandlers but only Gemini's
temperature handler was wired up — OpenAI's quality, background, and
input_fidelity were stored in image metadata but never parsed back into
the params slice. Add the three missing handlers so Remix restores
these settings as expected.

---------

Co-authored-by: Alexander Eichhorn <alex@eichhorn.dev>
Co-authored-by: Alexander Eichhorn <alex@code-with.us>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-04-20 17:13:26 +00:00
Lincoln Stein
b42274a57e
Feat[model support]: Qwen Image — full pipeline with edit, generate LoRA, GGUF, quantization, and UI (#9000) 2026-04-12 14:39:13 +02:00
Jonathan
ee600973ed
Broaden text encoder partial-load recovery (#9034) 2026-04-09 20:09:40 -04:00
Jonathan
dc5007fe95
Fix/model cache Qwen/CogView4 cancel repair (#8959)
* Repair partially loaded Qwen models after cancel to avoid device mismatches

* ruff

* Repair CogView4 text encoder after canceled partial loads

* Avoid MPS CI crash in repair regression test

* Fix MPS device assertion in repair test
2026-03-15 10:04:15 -04:00
copilot-swe-agent[bot]
b7afd9b5b3 Fix test failures caused by MagicMock TypeError
Configure mock logger to return a valid log level for getEffectiveLevel()
to prevent TypeError when comparing with logging.DEBUG constant.

The issue was that ModelCache._log_cache_state() checks
self._logger.getEffectiveLevel() > logging.DEBUG, and when the logger
is a MagicMock without configuration, getEffectiveLevel() returns another
MagicMock, causing a TypeError when compared with an int.

Fixes all 4 test failures in test_model_cache_timeout.py

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2025-12-24 05:42:45 +00:00
Lincoln Stein
9d1de81fe2 (style) correct ruff formatting error 2025-12-24 00:19:25 -05:00
copilot-swe-agent[bot]
8d76b4e4d4 Fix ruff whitespace errors and improve timeout logging
- Remove all trailing whitespace (W293 errors)
- Add debug logging when timeout fires but activity detected
- Add debug logging when timeout fires but cache is empty
- Only log "Clearing model cache" message when actually clearing
- Prevents misleading timeout messages during active generation

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2025-12-24 04:05:57 +00:00
copilot-swe-agent[bot]
c3217d8a08 Address code review feedback
- Remove unused variable in test
- Add clarifying comment for daemon thread setting
- Add detailed comment explaining cache clearing with 1000 GB value
- Improve code documentation

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2025-12-24 00:27:39 +00:00
copilot-swe-agent[bot]
75a14e2a4b Add unit tests for model cache timeout functionality
- Created test_model_cache_timeout.py with comprehensive tests
- Tests timeout clearing behavior
- Tests activity resetting timeout
- Tests no-timeout default behavior
- Tests shutdown canceling timers

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2025-12-24 00:24:31 +00:00
psychedelicious
454d05bbde
refactor: model manager v3 (#8607)
* feat(mm): add UnknownModelConfig

* refactor(ui): move model categorisation-ish logic to central location, simplify model manager models list

* refactor(ui)refactor(ui): more cleanup of model categories

* refactor(ui): remove unused excludeSubmodels

I can't remember what this was for and don't see any reference to it.
Maybe it's just remnants from a previous implementation?

* feat(nodes): add unknown as model base

* chore(ui): typegen

* feat(ui): add unknown model base support in ui

* feat(ui): allow changing model type in MM, fix up base and variant selects

* feat(mm): omit model description instead of making it "base type filename model"

* feat(app): add setting to allow unknown models

* feat(ui): allow changing model format in MM

* feat(app): add the installed model config to install complete events

* chore(ui): typegen

* feat(ui): toast warning when installed model is unidentified

* docs: update config docstrings

* chore(ui): typegen

* tests(mm): fix test for MM, leave the UnknownModelConfig class in the list of configs

* tidy(ui): prefer types from zod schemas for model attrs

* chore(ui): lint

* fix(ui): wrong translation string

* feat(mm): normalized model storage

Store models in a flat directory structure. Each model is in a dir named
its unique key (a UUID). Inside that dir is either the model file or the
model dir.

* feat(mm): add migration to flat model storage

* fix(mm): normalized multi-file/diffusers model installation no worky

now worky

* refactor: port MM probes to new api

- Add concept of match certainty to new probe
- Port CLIP Embed models to new API
- Fiddle with stuff

* feat(mm): port TIs to new API

* tidy(mm): remove unused probes

* feat(mm): port spandrel to new API

* fix(mm): parsing for spandrel

* fix(mm): loader for clip embed

* fix(mm): tis use existing weight_files method

* feat(mm): port vae to new API

* fix(mm): vae class inheritance and config_path

* tidy(mm): patcher types and import paths

* feat(mm): better errors when invalid model config found in db

* feat(mm): port t5 to new API

* feat(mm): make config_path optional

* refactor(mm): simplify model classification process

Previously, we had a multi-phase strategy to identify models from their
files on disk:
1. Run each model config classes' `matches()` method on the files. It
checks if the model could possibly be an identified as the candidate
model type. This was intended to be a quick check. Break on the first
match.
2. If we have a match, run the config class's `parse()` method. It
derive some additional model config attrs from the model files. This was
intended to encapsulate heavier operations that may require loading the
model into memory.
3. Derive the common model config attrs, like name, description,
calculate the hash, etc. Some of these are also heavier operations.

This strategy has some issues:
- It is not clear how the pieces fit together. There is some
back-and-forth between different methods and the config base class. It
is hard to trace the flow of logic until you fully wrap your head around
the system and therefore difficult to add a model architecture to the
probe.
- The assumption that we could do quick, lightweight checks before
heavier checks is incorrect. We often _must_ load the model state dict
in the `matches()` method. So there is no practical perf benefit to
splitting up the responsibility of `matches()` and `parse()`.
- Sometimes we need to do the same checks in `matches()` and `parse()`.
In these cases, splitting the logic is has a negative perf impact
because we are doing the same work twice.
- As we introduce the concept of an "unknown" model config (i.e. a model
that we cannot identify, but still record in the db; see #8582), we will
_always_ run _all_ the checks for every model. Therefore we need not try
to defer heavier checks or resource-intensive ops like hashing. We are
going to do them anyways.
- There are situations where a model may match multiple configs. One
known case are SD pipeline models with merged LoRAs. In the old probe
API, we relied on the implicit order of checks to know that if a model
matched for pipeline _and_ LoRA, we prefer the pipeline match. But, in
the new API, we do not have this implicit ordering of checks. To resolve
this in a resilient way, we need to get all matches up front, then use
tie-breaker logic to figure out which should win (or add "differential
diagnosis" logic to the matchers).
- Field overrides weren't handled well by this strategy. They were only
applied at the very end, if a model matched successfully. This means we
cannot tell the system "Hey, this model is type X with base Y. Trust me
bro.". We cannot override the match logic. As we move towards letting
users correct mis-identified models (see #8582), this is a requirement.

We can simplify the process significantly and better support "unknown"
models.

Firstly, model config classes now have a single `from_model_on_disk()`
method that attempts to construct an instance of the class from the
model files. This replaces the `matches()` and `parse()` methods.

If we fail to create the config instance, a special exception is raised
that indicates why we think the files cannot be identified as the given
model config class.

Next, the flow for model identification is a bit simpler:
- Derive all the common fields up-front (name, desc, hash, etc).
- Merge in overrides.
- Call `from_model_on_disk()` for every config class, passing in the
fields. Overrides are handled in this method.
- Record the results for each config class and choose the best one.

The identification logic is a bit more verbose, with the special
exceptions and handling of overrides, but it is very clear what is
happening.

The one downside I can think of for this strategy is we do need to check
every model type, instead of stopping at the first match. It's a bit
less efficient. In practice, however, this isn't a hot code path, and
the improved clarity is worth far more than perf optimizations that the
end user will likely never notice.

* refactor(mm): remove unused methods in config.py

* refactor(mm): add model config parsing utils

* fix(mm): abstractmethod bork

* tidy(mm): clarify that model id utils are private

* fix(mm): fall back to UnknownModelConfig correctly

* feat(mm): port CLIPVisionDiffusersConfig to new api

* feat(mm): port SigLIPDiffusersConfig to new api

* feat(mm): make match helpers more succint

* feat(mm): port flux redux to new api

* feat(mm): port ip adapter to new api

* tidy(mm): skip optimistic override handling for now

* refactor(mm): continue iterating on config

* feat(mm): port flux "control lora" and t2i adapter to new api

* tidy(ui): use Extract to get model config types

* fix(mm): t2i base determination

* feat(mm): port cnet to new api

* refactor(mm): add config validation utils, make it all consistent and clean

* feat(mm): wip port of main models to new api

* feat(mm): wip port of main models to new api

* feat(mm): wip port of main models to new api

* docs(mm): add todos

* tidy(mm): removed unused model merge class

* feat(mm): wip port main models to new api

* tidy(mm): clean up model heuristic utils

* tidy(mm): clean up ModelOnDisk caching

* tidy(mm): flux lora format util

* refactor(mm): make config classes narrow

Simpler logic to identify, less complexity to add new model, fewer
useless attrs that do not relate to the model arch, etc

* refactor(mm): diffusers loras

w

* feat(mm): consistent naming for all model config classes

* fix(mm): tag generation & scattered probe fixes

* tidy(mm): consistent class names

* refactor(mm): split configs into separate files

* docs(mm): add comments for identification utils

* chore(ui): typegen

* refactor(mm): remove legacy probe, new configs dir structure, update imports

* fix(mm): inverted condition

* docs(mm): update docsstrings in factory.py

* docs(mm): document flux variant attr

* feat(mm): add helper method for legacy configs

* feat(mm): satisfy type checker in flux denoise

* docs(mm): remove extraneous comment

* fix(mm): ensure unknown model configs get unknown attrs

* fix(mm): t5 identification

* fix(mm): sdxl ip adapter identification

* feat(mm): more flexible config matching utils

* fix(mm): clip vision identification

* feat(mm): add sanity checks before probing paths

* docs(mm): add reminder for self for field migrations

* feat(mm): clearer naming for main config class hierarchy

* feat(mm): fix clip vision starter model bases, add ref to actual models

* feat(mm): add model config schema migration logic

* fix(mm): duplicate import

* refactor(mm): split big migration into 3

Split the big migration that did all of these things into 3:

- Migration 22: Remove unique contraint on base/name/type in models
table
- Migration 23: Migrate configs to v6.8.0 schemas
- Migration 24: Normalize file storage

* fix(mm): pop base/type/format when creating unknown model config

* fix(db): migration 22 insert only real cols

* fix(db): migration 23 fall back to unknown model when config change fails

* feat(db): run migrations 23 and 24

* fix(mm): false negative on flux lora

* fix(mm): vae checkpoint probe checking for dir instead of file

* fix(mm): ModelOnDisk skips dirs when looking for weights

Previously a path w/ any of the known weights suffixes would be seen as
a weights file, even if it was a directory. We now check to ensure the
candidate path is actually a file before adding it to the list of
weights.

* feat(mm): add method to get main model defaults from a base

* feat(mm): do not log when multiple non-unknown model matches

* refactor(mm): continued iteration on model identifcation

* tests(mm): refactor model identification tests

Overhaul of model identification (probing) tests. Previously we didn't
test the correctness of probing except in a few narrow cases - now we
do.

See tests/model_identification/README.md for a detailed overview of the
new test setup. It includes instructions for adding a new test case. In
brief:

- Download the model you want to add as a test case
- Run a script against it to generate the test model files
- Fill in the expected model type/format/base/etc in the generated test
metadata JSON file

Included test cases:
- All starter models
- A handful of other models that I had installed
- Models present in the previous test cases as smoke tests, now also
tested for correctness

* fix(mm): omit type/format/base when creating unknown config instance

* feat(mm): use ValueError for model id sanity checks

* feat(mm): add flag for updating models to allow class changes

* tests(mm): fix remaining MM tests

* feat: allow users to edit models freely

* feat(ui): add warning for model settings edit

* tests(mm): flux state dict tests

* tidy: remove unused file

* fix(mm): lora state dict loading in model id

* feat(ui): use translation string for model edit warning

* docs(db): update version numbers in migration comments

* chore: bump version to v6.9.0a1

* docs: update model id readme

* tests(mm): attempt to fix windows model id tests

* fix(mm): issue with deleting single file models

* feat(mm): just delete the dir w/ rmtree when deleting model

* tests(mm): windows CI issue

* fix(ui): typegen schema sync

* fix(mm): fixes for migration 23

- Handle CLIP Embed and Main SD models missing variant field
- Handle errors when calling the discriminator function, previously only
handled ValidationError but it could be a ValueError or something else
- Better logging for config migration

* chore: bump version to v6.9.0a2

* chore: bump version to v6.9.0a3
2025-10-15 10:18:53 +11:00
psychedelicious
a8a07598c8 chore: ruff 2025-08-18 21:14:00 +10:00
psychedelicious
23206e22e8 tests: skip excessively flaky MPS-specific tests in CI 2025-08-18 21:14:00 +10:00
Kevin Turner
52a8ad1c18 chore: rename model.size to model.file_size
to disambiguate from RAM size or pixel size
2025-04-10 09:53:03 +10:00
Kevin Turner
98260a8efc test: add size field to test model configs 2025-04-10 09:53:03 +10:00
Billy
182580ff69 Imports 2025-03-26 12:55:10 +11:00
Billy
8e9d5c1187 Ruff formatting 2025-03-26 12:30:31 +11:00
Billy
99aac5870e Remove star imports 2025-03-26 12:27:00 +11:00
Ryan Dick
5357d6e08e Rename ConcatenatedLoRALayer to MergedLayerPatch. And other minor cleanup. 2025-01-28 14:51:35 +00:00
Ryan Dick
28514ba59a Update ConcatenatedLoRALayer to work with all sub-layer types. 2025-01-28 14:51:35 +00:00
Ryan Dick
e2f05d0800 Add unit tests for LoKR patch layers. The new tests trigger a bug when LoKR layers are applied to BnB-quantized layers (also impacts several other LoRA variant types). 2025-01-22 09:20:40 +11:00
Ryan Dick
36a3869af0 Add keep_ram_copy_of_weights config option. 2025-01-16 15:35:25 +00:00
Ryan Dick
c76d08d1fd Add keep_ram_copy option to CachedModelOnlyFullLoad. 2025-01-16 15:08:23 +00:00
Ryan Dick
04087c38ce Add keep_ram_copy option to CachedModelWithPartialLoad. 2025-01-16 14:51:44 +00:00
Ryan Dick
974b4671b1 Deprecate the ram and vram configs to make the migration to dynamic
memory limits smoother for users who had previously overriden these
values.
2025-01-07 16:45:29 +00:00
Ryan Dick
d7ab464176 Offload the current model when locking if it is already partially loaded and we have insufficient VRAM. 2025-01-07 02:53:44 +00:00
Ryan Dick
5eafe1ec7a Fix ModelCache execution device selection in unit tests. 2025-01-07 01:20:15 +00:00
Ryan Dick
a167632f09 Calculate model cache size limits dynamically based on the available RAM / VRAM. 2025-01-07 01:14:20 +00:00
Ryan Dick
402dd840a1 Add seed to flaky unit test. 2025-01-07 00:31:00 +00:00
Ryan Dick
d0bfa019be Add 'enable_partial_loading' config flag. 2025-01-07 00:31:00 +00:00
Ryan Dick
535e45cedf First pass at adding partial loading support to the ModelCache. 2025-01-07 00:30:58 +00:00
Ryan Dick
9a0a226ce1 Fix bitsandbytes imports in unit tests on MacOS. 2024-12-30 10:41:48 -05:00
Ryan Dick
52fc5a64d4 Add a unit test for a LoRA patch applied to a quantized linear layer with weights streamed from CPU to GPU. 2024-12-29 17:14:55 +00:00
Ryan Dick
a8bef59699 First pass at making custom layer patches work with weights streamed from the CPU to the GPU. 2024-12-29 17:01:37 +00:00
Ryan Dick
6d49ee839c Switch the LayerPatcher to use 'custom modules' to manage layer patching. 2024-12-29 01:18:30 +00:00
Ryan Dick
918f541af8 Add unit test for a SetParameterLayer patch applied to a CustomFluxRMSNorm layer. 2024-12-28 20:44:48 +00:00
Ryan Dick
93e76b61d6 Add CustomFluxRMSNorm layer. 2024-12-28 20:33:38 +00:00
Ryan Dick
f2981979f9 Get custom layer patches working with all quantized linear layer types. 2024-12-27 22:00:22 +00:00
Ryan Dick
ef970a1cdc Add support for FluxControlLoRALayer in CustomLinear layers and add a unit test for it. 2024-12-27 21:00:47 +00:00
Ryan Dick
5ee7405f97 Add more unit tests for custom module LoRA patching: multiple LoRAs and ConcatenatedLoRALayers. 2024-12-27 19:47:21 +00:00
Ryan Dick
e24e386a27 Add support for patches to CustomModuleMixin and add a single unit test (more to come). 2024-12-27 18:57:13 +00:00
Ryan Dick
b06d61e3c0 Improve custom layer wrap/unwrap logic. 2024-12-27 16:29:48 +00:00
Ryan Dick
7d6ab0ceb2 Add a CustomModuleMixin class with a flag for enabling/disabling autocasting (since it incurs some runtime speed overhead.) 2024-12-26 20:08:30 +00:00
Ryan Dick
9692a36dd6 Use a fixture to parameterize tests in test_all_custom_modules.py so that a fresh instance of the layer under test is initialized for each test. 2024-12-26 19:41:25 +00:00
Ryan Dick
b0b699a01f Add unit test to test that isinstance(...) behaves as expected with custom module types. 2024-12-26 18:45:56 +00:00
Ryan Dick
a8b2c4c3d2 Add inference tests for all custom module types (i.e. to test autocasting from cpu to device). 2024-12-26 18:33:46 +00:00
Ryan Dick
03944191db Split test_autocast_modules.py into separate test files to mirror the source file structure. 2024-12-24 22:29:11 +00:00
Ryan Dick
987c9ae076 Move custom autocast modules to separate files in a custom_modules/ directory. 2024-12-24 22:21:31 +00:00
Ryan Dick
0fc538734b Skip flaky test when running on Github Actions, and further reduce peak unit test memory. 2024-12-24 14:32:11 +00:00
Ryan Dick
7214d4969b Workaround a weird quirk of QuantState.to() and add a unit test to exercise it. 2024-12-24 14:32:11 +00:00