feat: 44 make flash attention configurable #60

Open

wants to merge 85 commits into base: main

Conversation

theissenhelen
Collaborator

@theissenhelen theissenhelen commented Jan 6, 2025

Current setup:

  • If flash-attn is available in the environment, the MultiHeadSelfAttention module automatically imports the corresponding attention function. At inference time, however, we do not have that information.

Now:

  • Flex Attention is now available.
  • The user specifies in the model config whether flash-attn, Flex Attention, or scaled dot-product attention should be used.
  • Adds configurable parameters (softcap, ALiBi) for flash attention.
    For ALiBi, a function computes the slopes from the number of attention heads (see the sketch after this list).
  • Scaled dot-product attention now supports a sliding window, making it numerically equivalent to flash/flex attention.
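For context on the ALiBi item above, the slopes are usually a geometric sequence keyed to the number of heads, following the ALiBi paper. A minimal sketch of such a helper (names are illustrative, not necessarily the exact function added in this PR):

import math

import torch

def get_alibi_slopes(num_heads: int) -> torch.Tensor:
    """Return one ALiBi slope per attention head, per the ALiBi paper."""

    def slopes_power_of_2(n: int) -> list[float]:
        start = 2.0 ** (-8.0 / n)  # first slope; each subsequent slope is multiplied by start
        return [start**i for i in range(1, n + 1)]

    if math.log2(num_heads).is_integer():
        slopes = slopes_power_of_2(num_heads)
    else:
        # Fall back to the nearest lower power of two, then fill the remaining
        # heads with every other slope from the next power of two.
        closest = 2 ** math.floor(math.log2(num_heads))
        slopes = slopes_power_of_2(closest)
        slopes += slopes_power_of_2(2 * closest)[0::2][: num_heads - closest]

    return torch.tensor(slopes, dtype=torch.float32)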

Todo:

  • test various attention options
  • adjust test case coverage

theissenhelen and others added 30 commits September 27, 2024 08:42
* fix: change pre-cmmit autoupdate schedule to monthly

* fix: change the merge strategy for Changelog to Union

* fix: add .envrc to .gitignore

* ci: ignore pre-commit-config and readthedocs for changelog updates

* ci: fix to correct hpc workflow call

* fix: update precommit config

* chore: update pre-commits

* feat: add codeowners file

* chore: update dependencies

* ci: add hpc-config

* docs: changelog

* fix: respond to review comments

---------

Co-authored-by: Jesper Dramsch <jesper.dramsch@ecmwf.int>
* feat: add configurability to dropout in MultiHeadSelfAttention

Co-authored-by: Rilwan (Akanni) Adewoyin <18564167+Rilwan-Adewoyin@users.noreply.github.com>

* test: adjust to dropout_p

* doc: update changelog

* Feature/integrate reusable workflows (#16)

* ci: add public pr label

* ci: add readthedocs update check

* ci: add downstream ci

* ci: add ci-config

* chore(deps): remove unused dependency

* docs: update changelog

* ci: switch to main

* chore: changelog 0.2.1

* Update error messages from invalid sub_graph in model instantiation (#20)

* ci: inherit pypi publish flow (#17)

* ci: inherit pypi publish flow

Co-authored-by: Helen Theissen <helen.theissen@ecmwf.int>

* docs: add to changelog

* fix: typo in reusable workflow

* fix: another typo

* chore: bump actions/setup-python to v5

* ci: run downstream-ci for changes in src and tests

* docs: update changelog

---------

Co-authored-by: Helen Theissen <helen.theissen@ecmwf.int>

* Update CHANGELOG.md to KeepChangelog format

* [pre-commit.ci] pre-commit autoupdate (#25)

updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](psf/black-pre-commit-mirror@24.4.2...24.8.0)
- [github.com/astral-sh/ruff-pre-commit: v0.4.6 → v0.6.2](astral-sh/ruff-pre-commit@v0.4.6...v0.6.2)
- [github.com/tox-dev/pyproject-fmt: 2.1.3 → 2.2.1](tox-dev/pyproject-fmt@2.1.3...2.2.1)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Ci/changelog-release-updater (#26)

* ci: add changelof release updater

* docs: update changelog

---------

Co-authored-by: Rilwan (Akanni) Adewoyin <18564167+Rilwan-Adewoyin@users.noreply.github.com>
Co-authored-by: Gert Mertes <gert.mertes@ecmwf.int>
Co-authored-by: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com>
Co-authored-by: Jesper Dramsch <jesper.dramsch@ecmwf.int>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
xfail for MultiHeadSelfAttention
@FussyDuck

FussyDuck commented Jan 6, 2025

CLA assistant check
All committers have signed the CLA.

@theissenhelen theissenhelen marked this pull request as ready for review January 17, 2025 10:59
@theissenhelen theissenhelen changed the title Feature/44 make flash attention configurable feat: 44 make flash attention configurable Jan 17, 2025
mishooax
mishooax previously approved these changes Jan 20, 2025
Member

@mishooax mishooax left a comment


This looks OK to me (I haven't run the code, but Cathal did) - nice work @theissenhelen.

Member

@HCookie HCookie left a comment


Brilliant work exists here,
Just a few thoughts and questions,
Looks great overall.


For the flash attention implementation, two additional parameters are available: softcap, use_alibi_slopes

softcap: Softcapping prevents the logits from grwoing excessively large
Member


Suggested change
softcap: Softcapping prevents the logits from grwoing excessively large
softcap: Softcapping prevents the logits from growing excessively large
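For readers unfamiliar with the option: softcapping squashes the attention logits through a scaled tanh so they stay bounded. A rough standalone sketch of the idea (flash-attn applies an equivalent transform inside its fused kernel, so this is illustration only):

import torch

def softcap_logits(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Smoothly caps the logits to the interval (-softcap, softcap).
    return softcap * torch.tanh(scores / softcap)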


softcap: Softcapping prevents the logits from grwoing excessively large

use_alibi_slopes: Adds bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) to the attention score of
Member


Suggested change
use_alibi_slopes: Adds bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) to the attention score of
use_alibi_slopes: Adds bias of `(-alibi_slope * |i + seqlen_k - seqlen_q - j|)` to the attention score of
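Writing the quoted formula out as a reference implementation makes the shapes and sign explicit; a naive, unfused sketch:

import torch

def alibi_bias(slopes: torch.Tensor, seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # bias[h, i, j] = -slopes[h] * |i + seqlen_k - seqlen_q - j|
    i = torch.arange(seqlen_q).view(-1, 1)
    j = torch.arange(seqlen_k).view(1, -1)
    distance = (i + seqlen_k - seqlen_q - j).abs().to(slopes.dtype)
    return -slopes.view(-1, 1, 1) * distance  # shape (num_heads, seqlen_q, seqlen_k)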

Comment on lines +120 to +121
Please change model.processor.attention_implementation in the config to one of: {attn_funcs.keys()}"
LOGGER.info(f"Using {self.attention_implementation}")
Member


Should this reference the config or just say that the arg was wrong?

Comment on lines +77 to +80
softcap : float, optional
Anything > 0 activates softcapping attention, by default None
use_alibi_slopes : bool, optional
Adds bias
Member


Given that these are only used for flash_attention, if kwargs are needed for the other attention types we may end up with a large number of unused kwargs. Could it make sense to add attention_kwargs: dict[str, Any] and use that?
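One possible shape of that suggestion, with illustrative names rather than code from this PR: backend-specific options travel in a single dict and are unpacked only by the backend that understands them.

from typing import Any

class MultiHeadSelfAttentionExample:
    # Illustrative constructor only: bundle flash-attention-specific options
    # (softcap, use_alibi_slopes, ...) into one attention_kwargs dict.
    def __init__(self, attention_implementation: str, attention_kwargs: dict[str, Any] | None = None):
        self.attention_implementation = attention_implementation
        self.attention_kwargs = attention_kwargs or {}

# e.g. attention_kwargs={"softcap": 50.0, "use_alibi_slopes": True} for flash attention,
# and an empty dict for scaled dot product attention.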

if self.attention_implementation == "flex_attention":
assert (
self.embed_dim / self.num_heads >= 16
), f"Embedding dimension ({self.embed_dim}) devided by number of heads ({self.num_heads}) must be >= 16."
Contributor


minor typo, 'devided' -> 'divided'

# initalise the attn func here
self.attention = attn_funcs[self.attention_implementation]()

if self.attention_implementation == "flex_attention":
Contributor


I believe this if should go before self.attention = attn_funcs[self.attention_implementation]()
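A sketch of the reordering being suggested, using the names from the quoted fragment: validate the flex-attention constraint first, then instantiate from the registry.

# Validate backend-specific constraints before constructing the attention function.
if self.attention_implementation == "flex_attention":
    assert (
        self.embed_dim / self.num_heads >= 16
    ), f"Embedding dimension ({self.embed_dim}) divided by number of heads ({self.num_heads}) must be >= 16."

# Initialise the attention function only after validation has passed.
self.attention = attn_funcs[self.attention_implementation]()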

@@ -14,6 +14,7 @@ processor:
num_heads: 16 # GraphTransformer or Transformer only
window_size: 512
dropout_p: 0.0 # GraphTransformer
attention_implementation: flash_attention # flash_attention, flex_attention, scaled_dot_product_attention

Contributor


Are use_alibi_slopes and softcap meant to also be passed via the config? Since the default is flash_attention, I was wondering whether it would be better to also include the defaults for those two parameters.

@@ -0,0 +1,7 @@
[pytest]
Contributor


Out of curiosity, what is this file doing?

NotImplementedError("Error. Causal not yet implemented in FlexAttn in Anemoi.")

# This assumes seq_len never changes
# across iterations and stages
Contributor


what's this comment about? 🤔

)
self.attention = functools.partial(
flex_attention, block_mask=self.block_mask
) # Cache the block mask (recomended in attn blog post)
Contributor


It could be nice to add a reference to that blog post?
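For reference, the caching pattern in question follows the recommendation in the PyTorch FlexAttention blog post. A sketch assuming PyTorch 2.5+ with torch.nn.attention.flex_attention, a fixed sequence length, and a symmetric sliding window:

import functools

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

WINDOW_SIZE = 512
SEQ_LEN = 4096

def sliding_window(b, h, q_idx, kv_idx):
    # Keep only key positions within WINDOW_SIZE of the query position.
    return (q_idx - kv_idx).abs() <= WINDOW_SIZE

# Building the block mask is comparatively expensive, so it is created once
# (valid as long as seq_len never changes) and bound into a partial.
block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN)
attention = functools.partial(flex_attention, block_mask=block_mask)

q = k = v = torch.randn(1, 16, SEQ_LEN, 64, device="cuda")
out = attention(q, k, v)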

super().__init__()
import flash_attn

if version.parse(flash_attn.__version__) < version.parse("2.6.0"):
Contributor


Flash attention can't yet be installed through the pyproject, right? Do we have some checks ahead of this to confirm that the user has installed flash_attn before trying to use it?
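One common pattern for optional dependencies like this is to guard the import and fail with an actionable message before the version check; a sketch of such a guard (not necessarily what the PR does):

from packaging import version

def load_flash_attn():
    try:
        import flash_attn
    except ImportError as err:
        raise ImportError(
            "attention_implementation='flash_attention' requires the optional flash-attn "
            "package; install it or select another implementation in the model config."
        ) from err

    if version.parse(flash_attn.__version__) < version.parse("2.6.0"):
        raise RuntimeError("flash-attn >= 2.6.0 is required by this module.")

    return flash_attn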

Projects
Status: Under Review

7 participants