v3.0.4 (#95)
Co-authored-by: Zeming Lin <zeming@evolutionaryscale.ai>
ebetica and Zeming Lin authored Aug 30, 2024
1 parent eadc104 commit 1d66d81
Showing 24 changed files with 736 additions and 83 deletions.
2 changes: 1 addition & 1 deletion esm/__init__.py
@@ -1 +1 @@
__version__ = "3.0.3"
__version__ = "3.0.4"
3 changes: 1 addition & 2 deletions esm/tokenization/__init__.py
@@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import Protocol

from esm.utils.constants.esm3 import VQVAE_SPECIAL_TOKENS
from esm.utils.constants.models import ESM3_OPEN_SMALL

from .function_tokenizer import InterProQuantizedTokenizer
@@ -36,7 +35,7 @@ def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
if model == ESM3_OPEN_SMALL:
return TokenizerCollection(
sequence=EsmSequenceTokenizer(),
structure=StructureTokenizer(vq_vae_special_tokens=VQVAE_SPECIAL_TOKENS),
structure=StructureTokenizer(),
secondary_structure=SecondaryStructureTokenizer(kind="ss8"),
sasa=SASADiscretizingTokenizer(),
function=InterProQuantizedTokenizer(),
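With this change, get_model_tokenizers no longer threads VQVAE_SPECIAL_TOKENS into the structure tokenizer; StructureTokenizer now derives its special token ids itself (see esm/tokenization/structure_tokenizer.py below). A minimal usage sketch, assuming get_model_tokenizers is importable from esm.tokenization as this file's location suggests:

from esm.tokenization import get_model_tokenizers

tokenizers = get_model_tokenizers()  # defaults to ESM3_OPEN_SMALL
# The structure tokenizer now computes its special ids from the VQ-VAE codebook size.
print(tokenizers.structure.pad_token_id, tokenizers.structure.chain_break_token_id)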
16 changes: 16 additions & 0 deletions esm/tokenization/function_tokenizer.py
@@ -348,6 +348,22 @@ def pad_token(self) -> str:
def pad_token_id(self) -> int:
return self.vocab_to_index[self.pad_token]

@property
def chain_break_token(self) -> str:
return "<pad>"

@property
def chain_break_token_id(self) -> int:
return self.vocab_to_index[self.chain_break_token]

@property
def all_token_ids(self):
return list(range(len(self.vocab)))

@property
def special_token_ids(self):
return [self.vocab_to_index[token] for token in self.special_tokens]


def _texts_to_keywords(texts: list[str]) -> list[str]:
"""Breaks InterPro/GO free-text description set into bag-of-n-grams for n={1,2}.
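The same four additions (chain_break_token, chain_break_token_id, all_token_ids, special_token_ids) are also made to the residue, SASA, and secondary-structure tokenizers below. A short sketch of how downstream code can use them; the import path is taken from this commit's file layout and should be treated as an assumption:

from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer

tok = InterProQuantizedTokenizer()
# Ids a sampler may emit: the full vocabulary minus the special tokens.
sampleable_ids = set(tok.all_token_ids) - set(tok.special_token_ids)
# For this track the chain-break token aliases <pad>.
assert tok.chain_break_token == "<pad>"
assert tok.chain_break_token_id == tok.pad_token_id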
16 changes: 16 additions & 0 deletions esm/tokenization/residue_tokenizer.py
@@ -222,3 +222,19 @@ def pad_token(self) -> str:
@property
def pad_token_id(self) -> int:
return self.vocab_to_index[self.pad_token]

@property
def chain_break_token(self) -> str:
return "<pad>"

@property
def chain_break_token_id(self) -> int:
return self.vocab_to_index[self.chain_break_token]

@property
def all_token_ids(self):
return list(range(len(self.vocab)))

@property
def special_token_ids(self):
return [self.vocab_to_index[token] for token in self.special_tokens]
16 changes: 16 additions & 0 deletions esm/tokenization/sasa_tokenizer.py
@@ -127,3 +127,19 @@ def pad_token(self) -> str:
@property
def pad_token_id(self) -> int:
return self.vocab_to_index[self.pad_token]

@property
def chain_break_token(self) -> str:
return "<pad>"

@property
def chain_break_token_id(self) -> int:
return self.vocab_to_index[self.chain_break_token]

@property
def all_token_ids(self):
return list(range(len(self.vocab)))

@property
def special_token_ids(self):
return [self.vocab_to_index[token] for token in self.special_tokens]
29 changes: 26 additions & 3 deletions esm/tokenization/sequence_tokenizer.py
@@ -21,7 +21,7 @@ def __init__(
pad_token="<pad>",
mask_token="<mask>",
eos_token="<eos>",
chainbreak_token="|",
chain_break_token="|",
**kwargs,
):
all_tokens = C.SEQUENCE_VOCAB
@@ -30,8 +30,15 @@
# a character-level tokenizer is the same as BPE with no token merges
bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
tokenizer = Tokenizer(bpe)
special_tokens = [cls_token, pad_token, mask_token, eos_token, chainbreak_token]
additional_special_tokens = [chainbreak_token]
special_tokens = [
cls_token,
pad_token,
mask_token,
eos_token,
chain_break_token,
]
self.cb_token = chain_break_token
additional_special_tokens = [chain_break_token]

tokenizer.add_special_tokens(
special_tokens,
@@ -66,3 +73,19 @@ def bos_token(self):
@property
def bos_token_id(self):
return self.cls_token_id

@property
def chain_break_token(self):
return self.cb_token

@property
def chain_break_token_id(self):
return self.convert_tokens_to_ids(self.chain_break_token)

@property
def all_token_ids(self):
return list(range(self.vocab_size))

@property
def special_token_ids(self):
return self.all_special_ids
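Note the keyword rename from chainbreak_token to chain_break_token in EsmSequenceTokenizer.__init__, a breaking change for callers that passed the argument by name. A minimal sketch, assuming the import path esm.tokenization.sequence_tokenizer from this commit's layout:

from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer

tok = EsmSequenceTokenizer(chain_break_token="|")  # formerly chainbreak_token="|"
assert tok.chain_break_token == "|"
assert tok.chain_break_token_id in tok.special_token_ids
assert set(tok.special_token_ids) <= set(tok.all_token_ids)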
16 changes: 16 additions & 0 deletions esm/tokenization/ss_tokenizer.py
@@ -107,3 +107,19 @@ def pad_token(self) -> str:
@property
def pad_token_id(self) -> int:
return self.vocab_to_index[self.pad_token]

@property
def chain_break_token(self) -> str:
return "<pad>"

@property
def chain_break_token_id(self) -> int:
return self.vocab_to_index[self.chain_break_token]

@property
def all_token_ids(self):
return list(range(len(self.vocab)))

@property
def special_token_ids(self):
return [self.vocab_to_index[token] for token in self.special_tokens]
26 changes: 23 additions & 3 deletions esm/tokenization/structure_tokenizer.py
@@ -1,12 +1,19 @@
from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import esm3 as C


class StructureTokenizer(EsmTokenizerBase):
"""A convenince class for accessing special token ids of
the StructureTokenEncoder and StructureTokenDecoder."""

def __init__(self, vq_vae_special_tokens: dict[str, int]):
self.vq_vae_special_tokens = vq_vae_special_tokens
def __init__(self, cookbook_size: int = C.VQVAE_CODEBOOK_SIZE):
self.vq_vae_special_tokens = {
"MASK": cookbook_size,
"EOS": cookbook_size + 1,
"BOS": cookbook_size + 2,
"PAD": cookbook_size + 3,
"CHAINBREAK": cookbook_size + 4,
}

def mask_token(self) -> str:
raise NotImplementedError(
@@ -44,10 +51,23 @@ def pad_token(self) -> str:
def pad_token_id(self) -> int:
return self.vq_vae_special_tokens["PAD"]

def chain_break_token(self) -> str:
raise NotImplementedError(
"Structure tokens are defined on 3D coordinates, not strings."
)

@property
def chainbreak_token_id(self) -> int:
def chain_break_token_id(self) -> int:
return self.vq_vae_special_tokens["CHAINBREAK"]

@property
def all_token_ids(self):
return list(range(C.VQVAE_CODEBOOK_SIZE + len(self.vq_vae_special_tokens)))

@property
def special_token_ids(self):
return self.vq_vae_special_tokens.values()

def encode(self, *args, **kwargs):
raise NotImplementedError(
"The StructureTokenizer class is provided as a convenience for "
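StructureTokenizer now computes its special token ids from the VQ-VAE codebook size rather than accepting a vq_vae_special_tokens dict, which is why the call site in esm/tokenization/__init__.py above drops the argument. A small sketch of the resulting id layout, using only properties shown in this diff:

from esm.tokenization.structure_tokenizer import StructureTokenizer

tok = StructureTokenizer()  # previously required vq_vae_special_tokens=...
# MASK, EOS, BOS, PAD, CHAINBREAK sit directly after the codebook entries.
assert tok.chain_break_token_id == tok.pad_token_id + 1
assert len(tok.all_token_ids) == tok.chain_break_token_id + 1
assert max(tok.special_token_ids) == tok.chain_break_token_id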
16 changes: 16 additions & 0 deletions esm/tokenization/tokenizer_base.py
@@ -40,3 +40,19 @@ def pad_token(self) -> str:
@property
def pad_token_id(self) -> int:
...

@property
def chain_break_token(self) -> str:
...

@property
def chain_break_token_id(self) -> int:
...

@property
def all_token_ids(self):
...

@property
def special_token_ids(self):
...
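The base interface now also requires the chain-break and vocabulary-enumeration accessors. A toy class, purely illustrative and not part of the package, showing the shape of the newly added surface:

class ToyTokenizer:
    # Illustration only: the new accessors expected of every ESM tokenizer.
    vocab = ["A", "B", "C", "<pad>", "<mask>", "|"]
    special_tokens = ["<pad>", "<mask>", "|"]

    def __init__(self):
        self.vocab_to_index = {tok: i for i, tok in enumerate(self.vocab)}

    @property
    def chain_break_token(self) -> str:
        return "|"

    @property
    def chain_break_token_id(self) -> int:
        return self.vocab_to_index[self.chain_break_token]

    @property
    def all_token_ids(self) -> list[int]:
        return list(range(len(self.vocab)))

    @property
    def special_token_ids(self) -> list[int]:
        return [self.vocab_to_index[tok] for tok in self.special_tokens]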
6 changes: 5 additions & 1 deletion esm/utils/decoding.py
@@ -40,7 +40,7 @@ def decode_protein_tensor(
input: ESMProteinTensor,
tokenizers: TokenizerCollectionProtocol,
structure_token_decoder: StructureTokenDecoder,
function_token_decoder: FunctionTokenDecoder,
function_token_decoder: FunctionTokenDecoder | None = None,
) -> ESMProtein:
input = attr.evolve(input) # Make a copy

@@ -90,6 +90,10 @@
if input.sasa is not None:
sasa = decode_sasa(input.sasa, tokenizers.sasa)
if input.function is not None:
if function_token_decoder is None:
raise ValueError(
"Cannot decode function annotations without a function token decoder"
)
function_track_annotations = decode_function_annotations(
input.function,
function_token_decoder=function_token_decoder,
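function_token_decoder becomes optional and is only required when the input actually carries function tokens; the new guard turns a confusing failure into an explicit error. A standalone sketch of the calling pattern (names here are placeholders, not the esm API):

def maybe_decode_function(function_tokens, function_token_decoder=None):
    # Mirrors the new guard in decode_protein_tensor: function annotations
    # need a decoder, all other tracks do not.
    if function_tokens is None:
        return None
    if function_token_decoder is None:
        raise ValueError(
            "Cannot decode function annotations without a function token decoder"
        )
    return function_token_decoder(function_tokens)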
31 changes: 20 additions & 11 deletions esm/utils/encoding.py
@@ -174,11 +174,17 @@ def get_default_sequence_tokens(
sequence_length: int,
sequence_tokenizer: EsmSequenceTokenizer,
) -> torch.Tensor:
return tokenize_sequence(
get_default_sequence(sequence_length),
sequence_tokenizer,
add_special_tokens=True,
assert sequence_tokenizer.mask_token_id is not None
assert sequence_tokenizer.bos_token_id is not None
assert sequence_tokenizer.eos_token_id is not None

sequence_tokens = torch.full(
(sequence_length + 2,), sequence_tokenizer.mask_token_id
)
sequence_tokens[0] = sequence_tokenizer.bos_token_id
sequence_tokens[-1] = sequence_tokenizer.eos_token_id

return sequence_tokens


def get_default_structure_tokens(
@@ -200,19 +206,22 @@ def get_default_secondary_structure_tokens(
def get_default_secondary_structure_tokens(
sequence_length: int, secondary_structure_tokenizer: SecondaryStructureTokenizer
) -> torch.Tensor:
return tokenize_secondary_structure(
get_default_secondary_structure(sequence_length),
secondary_structure_tokenizer,
add_special_tokens=True,
ss8_tokens = torch.full(
(sequence_length + 2,), secondary_structure_tokenizer.mask_token_id
)
ss8_tokens[0] = secondary_structure_tokenizer.bos_token_id
ss8_tokens[-1] = secondary_structure_tokenizer.eos_token_id

return ss8_tokens


def get_default_sasa_tokens(
sequence_length: int, sasa_tokenizer: SASADiscretizingTokenizer
) -> torch.Tensor:
return tokenize_sasa(
get_default_sasa(sequence_length), sasa_tokenizer, add_special_tokens=True
)
sasa_tokens = torch.full((sequence_length + 2,), sasa_tokenizer.mask_token_id)
sasa_tokens[0] = sasa_tokenizer.bos_token_id
sasa_tokens[-1] = sasa_tokenizer.eos_token_id
return sasa_tokens


def get_default_function_tokens(
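The default-token helpers now build their tensors directly: a track of length L becomes L + 2 positions filled with the mask id, with BOS at index 0 and EOS at the end, instead of building a default string and tokenizing it. A standalone worked example of the pattern (the ids below are made up for illustration):

import torch

def default_track_tokens(sequence_length: int, mask_id: int, bos_id: int, eos_id: int) -> torch.Tensor:
    tokens = torch.full((sequence_length + 2,), mask_id)  # +2 for BOS/EOS
    tokens[0] = bos_id
    tokens[-1] = eos_id
    return tokens

# length 4 with mask=32, bos=0, eos=2 -> tensor([ 0, 32, 32, 32, 32,  2])
print(default_track_tokens(4, mask_id=32, bos_id=0, eos_id=2))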
27 changes: 26 additions & 1 deletion esm/utils/generation.py
@@ -495,6 +495,7 @@ def _sample_per_prompt(
sampling_config: SamplingConfig,
tokenizers: TokenizerCollectionProtocol,
decode_sasa_tokens: bool = True,
mask_logits_of_invalid_ids: bool = True,
) -> ForwardAndSampleOutput:
assert logits_output.logits is not None

@@ -513,11 +514,19 @@ def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
if config is None:
tokens_dir[track] = maybe_clone(getattr(protein, track))
continue
tokenizer = getattr(tokenizers, track)
valid_ids = (
set(tokenizer.all_token_ids)
- set(tokenizer.special_token_ids)
- set(config.invalid_ids)
)
sampling_metadata = _sample_track(
logits=getattr(logits_output.logits, track),
tokens=getattr(protein, track),
sampling_track_config=config,
mask_idx=getattr(tokenizers, track).mask_token_id,
valid_ids=list(valid_ids),
mask_logits_of_invalid_ids=mask_logits_of_invalid_ids,
)
tokens_dir[track] = sampling_metadata.pop("sampled_tokens") # (L,)
track_sampling_metadata_dir[track] = sampling_metadata
@@ -536,12 +545,19 @@ def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
assert logits_output.logits.sasa is not None
assert protein.sasa is not None

valid_ids = (
set(tokenizers.sasa.all_token_ids)
- set(tokenizer.special_token_ids)
- set(config.invalid_ids)
)
sasa_logits = logits_output.logits.sasa
sasa_value = sample_sasa_logits(
sasa_logits,
protein.sasa,
sampling_track_config=config,
mask_idx=tokenizers.sasa.mask_token_id,
valid_ids=list(valid_ids),
mask_logits_of_invalid_ids=mask_logits_of_invalid_ids,
)
tokens_dir["sasa"] = sasa_value

Expand All @@ -558,6 +574,9 @@ def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
getattr(protein, "residue_annotations")
)
else:
if config.invalid_ids is not None and len(config.invalid_ids) > 0:
warn("For function sampling, invalid_ids sampling config is not supported.")

sampling_metadata = _sample_function_track(
tokenizers.function,
tokens=getattr(protein, "function"),
@@ -621,6 +640,8 @@ def _sample_track(
tokens: torch.Tensor,
sampling_track_config: SamplingTrackConfig,
mask_idx: int,
valid_ids: list[int],
mask_logits_of_invalid_ids: bool = True,
) -> dict[str, torch.Tensor]:
"""Works with inputs that have batch dimension."""
# Sample in all positions
@@ -629,7 +650,11 @@
# since the logits may be computed with a longer padded batch, while tokens
# are the original input sequence.
sampled_tokens = sample_logits(
logits, temperature=temperature, top_p=sampling_track_config.top_p
logits,
temperature=temperature,
valid_ids=valid_ids,
top_p=sampling_track_config.top_p,
mask_logits_of_invalid_ids=mask_logits_of_invalid_ids,
)
log_probs = logits.log_softmax(-1)
sampling_mask = get_sampling_mask(tokens, sampling_track_config, mask_idx)
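The substantive change in generation.py is that sampling can now be restricted to a whitelist: valid_ids is the track's full vocabulary minus its special tokens minus any invalid_ids from the sampling config, and with mask_logits_of_invalid_ids=True everything outside that set is excluded before sampling. A standalone sketch of the masking idea; this is an illustration, not the signature of esm's sample_logits:

import torch

def mask_invalid_ids(logits: torch.Tensor, valid_ids: list[int]) -> torch.Tensor:
    # Push the logits of every id outside valid_ids to -inf so they can never be drawn.
    mask = torch.full_like(logits, float("-inf"))
    mask[..., valid_ids] = 0.0
    return logits + mask

logits = torch.randn(1, 8, 10)  # (batch, length, vocab)
valid = [3, 4, 5, 6]            # e.g. all_token_ids minus special_token_ids minus invalid_ids
probs = torch.softmax(mask_invalid_ids(logits, valid), dim=-1)
assert torch.allclose(probs.sum(-1), torch.ones(1, 8))  # still a proper distribution
assert float(probs[..., 0].max()) == 0.0                # masked ids get zero probability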
6 changes: 3 additions & 3 deletions esm/utils/misc.py
@@ -176,7 +176,7 @@ def unbinpack(
return stack_variable_length_tensors(unpacked_tensors, pad_value)


def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]:
def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]: # type: ignore
"""
Returns an autocast context manager that disables downcasting by AMP.
@@ -187,9 +187,9 @@ def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast
An autocast context manager with the specified behavior.
"""
if device_type == "cpu":
return torch.amp.autocast(device_type, enabled=False)
return torch.amp.autocast(device_type, enabled=False) # type: ignore
elif device_type == "cuda":
return torch.amp.autocast(device_type, dtype=torch.float32)
return torch.amp.autocast(device_type, dtype=torch.float32) # type: ignore
else:
raise ValueError(f"Unsupported device type: {device_type}")
