Skip to content

Commit

Permalink
update esmc tokenizer and return hidden states (#160)
Browse files Browse the repository at this point in the history
Signed-off-by: tina-z-jia <145156075+tina-z-jia@users.noreply.github.com>
  • Loading branch information
tina-z-jia authored Dec 6, 2024
1 parent 5604523 commit 8127b99
Show file tree
Hide file tree
Showing 13 changed files with 88 additions and 49 deletions.
2 changes: 1 addition & 1 deletion esm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = "3.1.0"
__version__ = "3.1.1"

7 changes: 5 additions & 2 deletions esm/layers/transformer_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def forward(
affine: Affine3D | None = None,
affine_mask: torch.Tensor | None = None,
chain_id: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass of the TransformerStack.
Expand All @@ -85,6 +85,9 @@ def forward(
*batch_dims, _ = x.shape
if chain_id is None:
chain_id = torch.ones(size=batch_dims, dtype=torch.int64, device=x.device)
hiddens = []
for block in self.blocks:
x = block(x, sequence_id, affine, affine_mask, chain_id)
return self.norm(x), x
hiddens.append(x)
hiddens = torch.stack(hiddens, dim=0)
return self.norm(x), x, hiddens
4 changes: 3 additions & 1 deletion esm/models/esm3.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,9 @@ def forward(
function_tokens,
residue_annotation_tokens,
)
x, embedding = self.transformer(x, sequence_id, affine, affine_mask, chain_id)
x, embedding, _ = self.transformer(
x, sequence_id, affine, affine_mask, chain_id
)
return self.output_heads(x, embedding)

# The following methods are for the ESM3InferenceClient interface
Expand Down
31 changes: 25 additions & 6 deletions esm/models/esmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
from esm.utils import encoding
from esm.utils.constants.models import ESMC_600M
from esm.utils.decoding import decode_sequence
from esm.utils.misc import stack_variable_length_tensors
from esm.utils.sampling import _BatchedESMProteinTensor


@dataclass
class ESMCOutput:
sequence_logits: torch.Tensor
embeddings: torch.Tensor | None
hidden_states: torch.Tensor | None


class ESMC(nn.Module, ESMCInferenceClient):
Expand Down Expand Up @@ -73,6 +75,23 @@ def device(self):
def raw_model(self):
return self

def _tokenize(self, sequence: list[str]) -> torch.Tensor:
pad = self.tokenizer.pad_token_id
assert pad is not None
return stack_variable_length_tensors(
[
encoding.tokenize_sequence(x, self.tokenizer, add_special_tokens=True)
for x in sequence
],
constant_value=pad,
).to(next(self.parameters()).device)

def _detokenize(self, sequence: torch.Tensor) -> list[str]:
pad = self.tokenizer.pad_token_id
assert pad is not None
assert sequence.ndim == 2
return [decode_sequence(x[x != pad][1:-1], self.tokenizer) for x in sequence]

def forward(
self,
sequence_tokens: torch.Tensor | None = None,
Expand All @@ -93,19 +112,19 @@ def forward(
sequence_id = sequence_tokens == self.tokenizer.pad_token_id

x = self.embed(sequence_tokens)
x, _ = self.transformer(x, sequence_id=sequence_id)
x, _, hiddens = self.transformer(x, sequence_id=sequence_id)
sequence_logits = self.sequence_head(x)
output = ESMCOutput(sequence_logits=sequence_logits, embeddings=x)
output = ESMCOutput(
sequence_logits=sequence_logits, embeddings=x, hidden_states=hiddens
)
return output

def encode(self, input: ESMProtein) -> ESMProteinTensor:
input = attr.evolve(input) # Make a copy
sequence_tokens = None

if input.sequence is not None:
sequence_tokens = encoding.tokenize_sequence(
input.sequence, self.tokenizer, add_special_tokens=True
)
sequence_tokens = self._tokenize([input.sequence])[0]
return ESMProteinTensor(sequence=sequence_tokens).to(
next(self.parameters()).device
)
Expand All @@ -114,7 +133,7 @@ def decode(self, input: ESMProteinTensor) -> ESMProtein:
input = attr.evolve(input) # Make a copy

assert input.sequence is not None
sequence = decode_sequence(input.sequence[1:-1], self.tokenizer)
sequence = self._detokenize(input.sequence)[0]

return ESMProtein(sequence=sequence)

Expand Down
2 changes: 1 addition & 1 deletion esm/models/function_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def forward(self, token_ids: torch.Tensor) -> dict[str, torch.Tensor]:
inputs = token_ids + vocab_offsets[None, :]

embed = self.embedding(inputs)
encoding, _ = self.decoder(embed)
encoding, _, _ = self.decoder(embed)
pooled = torch.mean(encoding, dim=1)

return {name: head(pooled) for name, head in self.heads.items()}
Expand Down
4 changes: 2 additions & 2 deletions esm/models/vqvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def encode_local_structure(

z = self.relative_positional_embedding(res_idxs[:, 0], res_idxs)

z, _ = self.transformer.forward(
z, _, _ = self.transformer.forward(
x=z,
sequence_id=knn_sequence_id,
affine=affine,
Expand Down Expand Up @@ -397,7 +397,7 @@ def decode(

x = self.embed(structure_tokens)
# !!! NOTE: Attention mask is actually unused here so watch out
x, _ = self.decoder_stack.forward(
x, _, _ = self.decoder_stack.forward(
x, affine=None, affine_mask=None, sequence_id=sequence_id, chain_id=chain_id
)

Expand Down
17 changes: 7 additions & 10 deletions esm/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
StructureTokenDecoder,
StructureTokenEncoder,
)
from esm.tokenization import get_model_tokenizers
from esm.tokenization import (
get_esm3_model_tokenizers,
get_esmc_model_tokenizers,
)
from esm.utils.constants.esm3 import data_root
from esm.utils.constants.models import (
ESM3_FUNCTION_DECODER_V0,
Expand Down Expand Up @@ -62,10 +65,7 @@ def ESM3_function_decoder_v0(device: torch.device | str = "cpu"):
def ESMC_300M_202412(device: torch.device | str = "cpu"):
with torch.device(device):
model = ESMC(
d_model=960,
n_heads=15,
n_layers=30,
tokenizer=get_model_tokenizers(ESM3_OPEN_SMALL).sequence,
d_model=960, n_heads=15, n_layers=30, tokenizer=get_esmc_model_tokenizers()
).eval()
state_dict = torch.load(
data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
Expand All @@ -79,10 +79,7 @@ def ESMC_300M_202412(device: torch.device | str = "cpu"):
def ESMC_600M_202412(device: torch.device | str = "cpu"):
with torch.device(device):
model = ESMC(
d_model=1152,
n_heads=18,
n_layers=36,
tokenizer=get_model_tokenizers(ESM3_OPEN_SMALL).sequence,
d_model=1152, n_heads=18, n_layers=36, tokenizer=get_esmc_model_tokenizers()
).eval()
state_dict = torch.load(
data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
Expand All @@ -103,7 +100,7 @@ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
structure_encoder_fn=ESM3_structure_encoder_v0,
structure_decoder_fn=ESM3_structure_decoder_v0,
function_decoder_fn=ESM3_function_decoder_v0,
tokenizers=get_model_tokenizers(ESM3_OPEN_SMALL),
tokenizers=get_esm3_model_tokenizers(ESM3_OPEN_SMALL),
).eval()
state_dict = torch.load(
data_root("esm3") / "data/weights/esm3_sm_open_v1.pth", map_location=device
Expand Down
4 changes: 2 additions & 2 deletions esm/sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import esm.utils.constants.api as C
from esm.tokenization import (
TokenizerCollectionProtocol,
get_model_tokenizers,
get_esm3_model_tokenizers,
)
from esm.utils import encoding
from esm.utils.constants.models import ESM3_OPEN_SMALL
Expand Down Expand Up @@ -226,7 +226,7 @@ def empty(
device: torch.device | str = "cpu",
) -> ESMProteinTensor:
if tokenizers is None:
tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)
tokenizers = get_esm3_model_tokenizers(ESM3_OPEN_SMALL)

return ESMProteinTensor(
sequence=encoding.get_default_sequence_tokens(
Expand Down
6 changes: 5 additions & 1 deletion esm/tokenization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class TokenizerCollection:
residue_annotations: ResidueAnnotationsTokenizer


def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
def get_esm3_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
if normalize_model_name(model) == ESM3_OPEN_SMALL:
return TokenizerCollection(
sequence=EsmSequenceTokenizer(),
Expand All @@ -48,6 +48,10 @@ def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
raise ValueError(f"Unknown model: {model}")


def get_esmc_model_tokenizers() -> EsmSequenceTokenizer:
return EsmSequenceTokenizer()


def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]:
if isinstance(tokenizer, EsmSequenceTokenizer):
return [
Expand Down
4 changes: 2 additions & 2 deletions esm/utils/generation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from evolutionaryscale.utils.remote_inference.api_v1 import (
ESM3RemoteModelInferenceClient,
)
from projects.forge.fastapi.utils.model import _load_esm3
from projects.forge.fastapi.utils.model import _load_esm_model


@pytest.fixture()
def esm3_remote_inference_client():
model = _load_esm3(ModelName.ESM3_TINY_DEV, distributed_model=False)
model = _load_esm_model(ModelName.ESM3_TINY_DEV, distributed_model=False)
client = ESM3RemoteModelInferenceClient(
model,
tokenizers=model.tokenizers,
Expand Down
48 changes: 31 additions & 17 deletions examples/esmc_examples.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,48 @@
from esm.models.esmc import ESMC
from examples.local_generate import get_sample_protein
from esm.sdk.api import (
ESMCInferenceClient,
LogitsConfig,
LogitsOutput,
)
from esm.sdk.api import ESMCInferenceClient, ESMProtein, LogitsConfig, LogitsOutput


def main(client: ESMCInferenceClient):
# ================================================================
# Example usage: one single protein
# ================================================================
protein = get_sample_protein()
protein.coordinates = None
protein.function_annotations = None
protein.sasa = None
protein = ESMProtein(sequence="AAAAA")

# Use logits endpoint. Using bf16 for inference optimization
protein_tensor = client.encode(protein)
logits_output = client.logits(
output = client.logits(
protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)
assert isinstance(
logits_output, LogitsOutput
), f"LogitsOutput was expected but got {logits_output}"
assert (
logits_output.logits is not None and logits_output.logits.sequence is not None
output, LogitsOutput
), f"LogitsOutput was expected but got {output}"
assert output.logits is not None and output.logits.sequence is not None
assert output.embeddings is not None and output.embeddings is not None
print(
f"Client returned logits with shape: {output.logits.sequence.shape} and embeddings with shape: {output.embeddings.shape}"
)


def raw_forward(model: ESMC):
protein = ESMProtein(sequence="AAAAA")
sequences = [protein.sequence, protein.sequence]

# ================================================================
# Example usage: directly use the model
# ================================================================
input_ids = model._tokenize(sequences)
output = model(input_ids)
logits, embeddings, hiddens = (
output.sequence_logits,
output.embeddings,
output.hidden_states,
)
print(
f"Raw model returned logits with shape: {logits.shape}, embeddings with shape: {embeddings.shape} and hidden states with shape {hiddens.shape}"
)
assert logits_output.embeddings is not None and logits_output.embeddings is not None


if __name__ == "__main__":
main(ESMC.from_pretrained("esmc_300m"))
model = ESMC.from_pretrained("esmc_300m")
main(model)
raw_forward(model)
4 changes: 2 additions & 2 deletions examples/raw_forwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ESM3_structure_decoder_v0,
ESM3_structure_encoder_v0,
)
from esm.tokenization import get_model_tokenizers
from esm.tokenization import get_esm3_model_tokenizers
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer as EsmFunctionTokenizer,
)
Expand Down Expand Up @@ -50,7 +50,7 @@ def inverse_folding_example():

@torch.no_grad()
def conditioned_prediction_example():
tokenizers = get_model_tokenizers()
tokenizers = get_esm3_model_tokenizers()

model = ESM3_sm_open_v0("cuda")

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "esm"
version = "3.1.0"
version = "3.1.1"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.10"
Expand All @@ -24,7 +24,7 @@ dependencies = [
"torch>=2.2.0",
"torchvision",
"torchtext",
"transformers",
"transformers<4.47.0",
"ipython",
"einops",
"biotite==0.41.2",
Expand Down

0 comments on commit 8127b99

Please sign in to comment.