feat: 44 make flash attention configurable #60
base: main
Conversation
* fix: change pre-cmmit autoupdate schedule to monthly
* fix: change the merge strategy for Changelog to Union
* fix: add .envrc to .gitignore
* ci: ignore pre-commit-config and readthedocs for changelog updates
* ci: fix to correct hpc workflow call
* fix: update precommit config
* chore: update pre-commits
* feat: add codeowners file
* chore: update dependencies
* ci: add hpc-config
* docs: changelog
* fix: respond to review comments

Co-authored-by: Jesper Dramsch <jesper.dramsch@ecmwf.int>

* feat: add configurability to dropout in MultiHeadSelfAttention
* test: adjust to dropout_p
* doc: update changelog
* Feature/integrate reusable workflows (#16)
* ci: add public pr label
* ci: add readthedocs update check
* ci: add downstream ci
* ci: add ci-config
* chore(deps): remove unused dependency
* docs: update changelog
* ci: switch to main
* chore: changelog 0.2.1
* Update error messages from invalid sub_graph in model instantiation (#20)
* ci: inherit pypi publish flow (#17)
* docs: add to changelog
* fix: typo in reusable workflow
* fix: another typo
* chore: bump actions/setup-python to v5
* ci: run downstream-ci for changes in src and tests
* docs: update changelog
* Update CHANGELOG.md to KeepChangelog format
* [pre-commit.ci] pre-commit autoupdate (#25):
  - github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0
  - github.com/astral-sh/ruff-pre-commit: v0.4.6 → v0.6.2
  - github.com/tox-dev/pyproject-fmt: 2.1.3 → 2.2.1
* Ci/changelog-release-updater (#26)
* ci: add changelof release updater
* docs: update changelog

Co-authored-by: Rilwan (Akanni) Adewoyin <18564167+Rilwan-Adewoyin@users.noreply.github.com>
Co-authored-by: Helen Theissen <helen.theissen@ecmwf.int>
Co-authored-by: Gert Mertes <gert.mertes@ecmwf.int>
Co-authored-by: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com>
Co-authored-by: Jesper Dramsch <jesper.dramsch@ecmwf.int>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
xfail for MultiHeadSelfAttention
for more information, see https://pre-commit.ci
….com:ecmwf/anemoi-core into feature/44-make-flash-attention-configurable
This looks OK to me (I haven't run the code, but Cathal did) - nice work @theissenhelen.
Brilliant work here! Just a few thoughts and questions; looks great overall.
```
For the flash attention implementation, two additional parameters are available: softcap, use_alibi_slopes

softcap: Softcapping prevents the logits from grwoing excessively large
```
Suggested change:
```diff
- softcap: Softcapping prevents the logits from grwoing excessively large
+ softcap: Softcapping prevents the logits from growing excessively large
```
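For context on what that docstring describes: softcapping squashes the attention logits through tanh. A minimal standalone sketch of the mechanic (my illustration, not the PR's code; flash-attn applies this internally when `softcap > 0`):

```python
import torch

def softcap_logits(logits: torch.Tensor, softcap: float) -> torch.Tensor:
    # tanh bounds its output to (-1, 1), so the capped logits can never
    # exceed +/- softcap in magnitude, keeping the softmax numerically tame.
    return softcap * torch.tanh(logits / softcap)
```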
```
softcap: Softcapping prevents the logits from grwoing excessively large

use_alibi_slopes: Adds bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) to the attention score of
```
Suggested change:
```diff
- use_alibi_slopes: Adds bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) to the attention score of
+ use_alibi_slopes: Adds bias of `(-alibi_slope * |i + seqlen_k - seqlen_q - j|)` to the attention score of
```
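To make that formula concrete, here is a standalone sketch that materialises the ALiBi bias as a tensor (a hypothetical helper for illustration; flash-attn computes this on the fly from `alibi_slopes`):

```python
import torch

def alibi_bias(slopes: torch.Tensor, seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # bias[h, i, j] = -slopes[h] * |i + seqlen_k - seqlen_q - j|
    i = torch.arange(seqlen_q).view(-1, 1)      # query positions
    j = torch.arange(seqlen_k).view(1, -1)      # key positions
    dist = (i + seqlen_k - seqlen_q - j).abs()  # (seqlen_q, seqlen_k)
    return -slopes.view(-1, 1, 1) * dist        # (num_heads, seqlen_q, seqlen_k)
```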
```python
    Please change model.processor.attention_implementation in the config to one of: {attn_funcs.keys()}"
LOGGER.info(f"Using {self.attention_implementation}")
```
Should this reference the config or just say that the arg was wrong?
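For what it's worth, one phrasing that names both the offending argument and the config path to fix it (a sketch with a hypothetical helper name, not the PR's code):

```python
def resolve_attention(attention_implementation: str, attn_funcs: dict) -> type:
    # Name the bad value and point at the config key that controls it.
    if attention_implementation not in attn_funcs:
        raise ValueError(
            f"Unknown attention_implementation '{attention_implementation}'. "
            f"Set model.processor.attention_implementation to one of: {list(attn_funcs)}"
        )
    return attn_funcs[attention_implementation]
```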
```python
softcap : float, optional
    Anything > 0 activates softcapping attention, by default None
use_alibi_slopes : bool, optional
    Adds bias
```
Given that these are only used for flash_attention, if kwargs are needed for the other attention types we may end up with a large number of unused kwargs. Could it make sense to add `attention_kwargs: dict[str, Any]` and use that?
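A rough sketch of that suggestion (all names hypothetical, not the PR's interface):

```python
from typing import Any

import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attention_implementation: str = "flash_attention",
        attention_kwargs: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        # Implementation-specific options (e.g. softcap, use_alibi_slopes for
        # flash_attention) stay bundled here instead of becoming top-level
        # kwargs that the other implementations silently ignore.
        self.attention_kwargs = attention_kwargs or {}
```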
```python
if self.attention_implementation == "flex_attention":
    assert (
        self.embed_dim / self.num_heads >= 16
    ), f"Embedding dimension ({self.embed_dim}) devided by number of heads ({self.num_heads}) must be >= 16."
```
minor typo, 'devided' -> 'divided'
```python
# initalise the attn func here
self.attention = attn_funcs[self.attention_implementation]()

if self.attention_implementation == "flex_attention":
```
I believe this `if` should go before `self.attention = attn_funcs[self.attention_implementation]()`.
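That is, rearranging the snippet above so the constraint is checked before the attn func is instantiated:

```python
if self.attention_implementation == "flex_attention":
    assert (
        self.embed_dim / self.num_heads >= 16
    ), f"Embedding dimension ({self.embed_dim}) divided by number of heads ({self.num_heads}) must be >= 16."

# initialise the attn func only after the check has passed
self.attention = attn_funcs[self.attention_implementation]()
```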
```diff
@@ -14,6 +14,7 @@ processor:
   num_heads: 16 # GraphTransformer or Transformer only
   window_size: 512
   dropout_p: 0.0 # GraphTransformer
+  attention_implementation: flash_attention # flash_attention, flex_attention, scaled_dot_product_attention
```
Are `use_alibi_slopes` and `softcap` meant to also be passed via the config? Since the default is flash_attention, I was wondering whether it would be better to also include the defaults for those two parameters.
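Something like this, for instance (a hypothetical extension of the hunk above; the two defaults are my guesses from the docstrings, not values taken from the PR):

```yaml
processor:
  attention_implementation: flash_attention # flash_attention, flex_attention, scaled_dot_product_attention
  softcap: null # > 0 activates softcapping (flash_attention only)
  use_alibi_slopes: false # add the ALiBi bias to attention scores (flash_attention only)
```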
```diff
@@ -0,0 +1,7 @@
+[pytest]
```
Out of curiosity, what is this file doing?
```python
NotImplementedError("Error. Causal not yet implemented in FlexAttn in Anemoi.")

# This assumes seq_len never changes
# across iterations and stages
```
what's this comment about? 🤔
```python
)
self.attention = functools.partial(
    flex_attention, block_mask=self.block_mask
)  # Cache the block mask (recomended in attn blog post)
```
Could be nice to add a reference to that blog?
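Presumably it is the PyTorch FlexAttention blog post (https://pytorch.org/blog/flexattention/), which recommends building the block mask once and reusing it, since mask construction is the expensive part. A standalone sketch of that caching pattern (window size and sequence lengths made up for illustration):

```python
import functools

from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def sliding_window(b, h, q_idx, kv_idx):
    # Keep only key positions within a fixed window of the query position.
    return (q_idx - kv_idx).abs() <= 512


# Build the mask once, then reuse it across forward passes.
block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=1024, KV_LEN=1024, device="cpu")
attention = functools.partial(flex_attention, block_mask=block_mask)
```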
```python
super().__init__()
import flash_attn

if version.parse(flash_attn.__version__) < version.parse("2.6.0"):
```
Flash attention can't yet be installed through the pyproject, right? Do we have some checks ahead of this to confirm that the user has installed flash_attn before trying to use it?
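If not, a guard along these lines could fail early with a clearer message (a sketch, not the PR's code):

```python
try:
    import flash_attn
except ImportError as err:
    raise ImportError(
        "attention_implementation='flash_attention' requires the flash-attn package, "
        "which cannot be declared in pyproject; see "
        "https://github.com/Dao-AILab/flash-attention for installation instructions."
    ) from err
```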
Current setup:
Now:
* For ALiBi: adds a function to compute the slopes according to the number of attention heads (see the sketch below)
Todo:
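For reference, the standard ALiBi slope schedule from Press et al. looks like this; a minimal sketch assuming the number of heads is a power of two (the paper interpolates otherwise), illustrative rather than necessarily the function this PR adds:

```python
import torch


def get_alibi_slopes(num_heads: int) -> torch.Tensor:
    # Geometric sequence ratio, ratio^2, ..., ratio^n with ratio = 2^(-8/n),
    # so the last head gets slope 2^-8 and earlier heads get larger slopes.
    ratio = 2 ** (-8 / num_heads)
    return torch.tensor([ratio ** (i + 1) for i in range(num_heads)])
```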