diff --git a/esm/__init__.py b/esm/__init__.py index 8d1c862..8e10cb4 100644 --- a/esm/__init__.py +++ b/esm/__init__.py @@ -1 +1 @@ -__version__ = "3.0.3" +__version__ = "3.0.4" diff --git a/esm/tokenization/__init__.py b/esm/tokenization/__init__.py index 9d2e5a9..d22c0de 100644 --- a/esm/tokenization/__init__.py +++ b/esm/tokenization/__init__.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Protocol -from esm.utils.constants.esm3 import VQVAE_SPECIAL_TOKENS from esm.utils.constants.models import ESM3_OPEN_SMALL from .function_tokenizer import InterProQuantizedTokenizer @@ -36,7 +35,7 @@ def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection: if model == ESM3_OPEN_SMALL: return TokenizerCollection( sequence=EsmSequenceTokenizer(), - structure=StructureTokenizer(vq_vae_special_tokens=VQVAE_SPECIAL_TOKENS), + structure=StructureTokenizer(), secondary_structure=SecondaryStructureTokenizer(kind="ss8"), sasa=SASADiscretizingTokenizer(), function=InterProQuantizedTokenizer(), diff --git a/esm/tokenization/function_tokenizer.py b/esm/tokenization/function_tokenizer.py index 7c59ffb..e60d3da 100644 --- a/esm/tokenization/function_tokenizer.py +++ b/esm/tokenization/function_tokenizer.py @@ -348,6 +348,22 @@ def pad_token(self) -> str: def pad_token_id(self) -> int: return self.vocab_to_index[self.pad_token] + @property + def chain_break_token(self) -> str: + return "<pad>" + + @property + def chain_break_token_id(self) -> int: + return self.vocab_to_index[self.chain_break_token] + + @property + def all_token_ids(self): + return list(range(len(self.vocab))) + + @property + def special_token_ids(self): + return [self.vocab_to_index[token] for token in self.special_tokens] + def _texts_to_keywords(texts: list[str]) -> list[str]: """Breaks InterPro/GO free-text description set into bag-of-n-grams for n={1,2}.
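The hunk above gives the function tokenizer the same uniform id surface (`chain_break_token`, `all_token_ids`, `special_token_ids`) that the other track tokenizers gain below; the sampler combines these to decide which ids may be drawn. A minimal sketch of that consumption pattern, using a hypothetical stand-in tokenizer rather than the real `InterProQuantizedTokenizer`:

```python
# Sketch: how downstream sampling code derives the sampleable ids from the
# new tokenizer properties. ToyTokenizer is a hypothetical stand-in, not
# part of the esm library.
class ToyTokenizer:
    vocab = ["<cls>", "<pad>", "<mask>", "<eos>", "A", "B", "C"]
    special_tokens = ["<cls>", "<pad>", "<mask>", "<eos>"]
    vocab_to_index = {tok: i for i, tok in enumerate(vocab)}

    @property
    def all_token_ids(self):
        return list(range(len(self.vocab)))

    @property
    def special_token_ids(self):
        return [self.vocab_to_index[token] for token in self.special_tokens]


tok = ToyTokenizer()
config_invalid_ids = [6]  # e.g. from SamplingTrackConfig.invalid_ids
valid_ids = (
    set(tok.all_token_ids) - set(tok.special_token_ids) - set(config_invalid_ids)
)
assert sorted(valid_ids) == [4, 5]  # only ordinary residue-like ids survive
```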
diff --git a/esm/tokenization/residue_tokenizer.py b/esm/tokenization/residue_tokenizer.py index 1e67b10..5430a23 100644 --- a/esm/tokenization/residue_tokenizer.py +++ b/esm/tokenization/residue_tokenizer.py @@ -222,3 +222,19 @@ def pad_token(self) -> str: @property def pad_token_id(self) -> int: return self.vocab_to_index[self.pad_token] + + @property + def chain_break_token(self) -> str: + return "<pad>" + + @property + def chain_break_token_id(self) -> int: + return self.vocab_to_index[self.chain_break_token] + + @property + def all_token_ids(self): + return list(range(len(self.vocab))) + + @property + def special_token_ids(self): + return [self.vocab_to_index[token] for token in self.special_tokens] diff --git a/esm/tokenization/sasa_tokenizer.py b/esm/tokenization/sasa_tokenizer.py index 4d7221b..a07c830 100644 --- a/esm/tokenization/sasa_tokenizer.py +++ b/esm/tokenization/sasa_tokenizer.py @@ -127,3 +127,19 @@ def pad_token(self) -> str: @property def pad_token_id(self) -> int: return self.vocab_to_index[self.pad_token] + + @property + def chain_break_token(self) -> str: + return "<pad>" + + @property + def chain_break_token_id(self) -> int: + return self.vocab_to_index[self.chain_break_token] + + @property + def all_token_ids(self): + return list(range(len(self.vocab))) + + @property + def special_token_ids(self): + return [self.vocab_to_index[token] for token in self.special_tokens] diff --git a/esm/tokenization/sequence_tokenizer.py b/esm/tokenization/sequence_tokenizer.py index 0926aab..a7e840d 100644 --- a/esm/tokenization/sequence_tokenizer.py +++ b/esm/tokenization/sequence_tokenizer.py @@ -21,7 +21,7 @@ def __init__( pad_token="<pad>", mask_token="<mask>", eos_token="<eos>", - chainbreak_token="|", + chain_break_token="|", **kwargs, ): all_tokens = C.SEQUENCE_VOCAB @@ -30,8 +30,15 @@ # a character-level tokenizer is the same as BPE with no token merges bpe = BPE(token_to_id, merges=[], unk_token=unk_token) tokenizer = Tokenizer(bpe) - special_tokens = [cls_token, pad_token, mask_token, eos_token, chainbreak_token] - additional_special_tokens = [chainbreak_token] + special_tokens = [ + cls_token, + pad_token, + mask_token, + eos_token, + chain_break_token, + ] + self.cb_token = chain_break_token + additional_special_tokens = [chain_break_token] tokenizer.add_special_tokens( special_tokens, @@ -66,3 +73,19 @@ def bos_token(self): @property def bos_token_id(self): return self.cls_token_id + + @property + def chain_break_token(self): + return self.cb_token + + @property + def chain_break_token_id(self): + return self.convert_tokens_to_ids(self.chain_break_token) + + @property + def all_token_ids(self): + return list(range(self.vocab_size)) + + @property + def special_token_ids(self): + return self.all_special_ids diff --git a/esm/tokenization/ss_tokenizer.py b/esm/tokenization/ss_tokenizer.py index c540103..1b41b31 100644 --- a/esm/tokenization/ss_tokenizer.py +++ b/esm/tokenization/ss_tokenizer.py @@ -107,3 +107,19 @@ def pad_token(self) -> str: @property def pad_token_id(self) -> int: return self.vocab_to_index[self.pad_token] + + @property + def chain_break_token(self) -> str: + return "<pad>" + + @property + def chain_break_token_id(self) -> int: + return self.vocab_to_index[self.chain_break_token] + + @property + def all_token_ids(self): + return list(range(len(self.vocab))) + + @property + def special_token_ids(self): + return [self.vocab_to_index[token] for token in self.special_tokens] diff --git a/esm/tokenization/structure_tokenizer.py b/esm/tokenization/structure_tokenizer.py index 
76b91b2..8a072f9 100644 --- a/esm/tokenization/structure_tokenizer.py +++ b/esm/tokenization/structure_tokenizer.py @@ -1,12 +1,19 @@ from esm.tokenization.tokenizer_base import EsmTokenizerBase +from esm.utils.constants import esm3 as C class StructureTokenizer(EsmTokenizerBase): """A convenience class for accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.""" - def __init__(self, vq_vae_special_tokens: dict[str, int]): - self.vq_vae_special_tokens = vq_vae_special_tokens + def __init__(self, cookbook_size: int = C.VQVAE_CODEBOOK_SIZE): + self.vq_vae_special_tokens = { + "MASK": cookbook_size, + "EOS": cookbook_size + 1, + "BOS": cookbook_size + 2, + "PAD": cookbook_size + 3, + "CHAINBREAK": cookbook_size + 4, + } def mask_token(self) -> str: raise NotImplementedError( "Structure tokens are defined on 3D coordinates, not strings." ) @@ -44,10 +51,23 @@ def pad_token(self) -> str: def pad_token_id(self) -> int: return self.vq_vae_special_tokens["PAD"] + def chain_break_token(self) -> str: + raise NotImplementedError( + "Structure tokens are defined on 3D coordinates, not strings." + ) + @property - def chainbreak_token_id(self) -> int: + def chain_break_token_id(self) -> int: return self.vq_vae_special_tokens["CHAINBREAK"] + @property + def all_token_ids(self): + return list(range(C.VQVAE_CODEBOOK_SIZE + len(self.vq_vae_special_tokens))) + + @property + def special_token_ids(self): + return self.vq_vae_special_tokens.values() + def encode(self, *args, **kwargs): raise NotImplementedError( "The StructureTokenizer class is provided as a convenience for " diff --git a/esm/tokenization/tokenizer_base.py b/esm/tokenization/tokenizer_base.py index 7cbce34..a8032ea 100644 --- a/esm/tokenization/tokenizer_base.py +++ b/esm/tokenization/tokenizer_base.py @@ -40,3 +40,19 @@ def pad_token(self) -> str: @property def pad_token_id(self) -> int: ... + + @property + def chain_break_token(self) -> str: + ... + + @property + def chain_break_token_id(self) -> int: + ... + + @property + def all_token_ids(self): + ... + + @property + def special_token_ids(self): + ...
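The new default constructor above makes the special-token layout a pure function of the codebook size. A short sketch of the resulting id space, assuming `C.VQVAE_CODEBOOK_SIZE` is 4096 (the value shipped in `esm.utils.constants.esm3`):

```python
# Id layout implied by StructureTokenizer(): special ids sit directly after
# the VQ-VAE codebook. The 4096 codebook size is an assumption for
# illustration, not read from the library here.
codebook_size = 4096
vq_vae_special_tokens = {
    "MASK": codebook_size,            # 4096
    "EOS": codebook_size + 1,         # 4097
    "BOS": codebook_size + 2,         # 4098
    "PAD": codebook_size + 3,         # 4099
    "CHAINBREAK": codebook_size + 4,  # 4100
}
all_token_ids = list(range(codebook_size + len(vq_vae_special_tokens)))
assert all_token_ids[-1] == vq_vae_special_tokens["CHAINBREAK"]
```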
diff --git a/esm/utils/decoding.py b/esm/utils/decoding.py index 802bca9..0de37c7 100644 --- a/esm/utils/decoding.py +++ b/esm/utils/decoding.py @@ -40,7 +40,7 @@ def decode_protein_tensor( input: ESMProteinTensor, tokenizers: TokenizerCollectionProtocol, structure_token_decoder: StructureTokenDecoder, - function_token_decoder: FunctionTokenDecoder, + function_token_decoder: FunctionTokenDecoder | None = None, ) -> ESMProtein: input = attr.evolve(input) # Make a copy @@ -90,6 +90,10 @@ def decode_protein_tensor( if input.sasa is not None: sasa = decode_sasa(input.sasa, tokenizers.sasa) if input.function is not None: + if function_token_decoder is None: + raise ValueError( + "Cannot decode function annotations without a function token decoder" + ) function_track_annotations = decode_function_annotations( input.function, function_token_decoder=function_token_decoder, diff --git a/esm/utils/encoding.py b/esm/utils/encoding.py index 9395891..9112d46 100644 --- a/esm/utils/encoding.py +++ b/esm/utils/encoding.py @@ -174,11 +174,17 @@ def get_default_sequence_tokens( sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer, ) -> torch.Tensor: - return tokenize_sequence( - get_default_sequence(sequence_length), - sequence_tokenizer, - add_special_tokens=True, + assert sequence_tokenizer.mask_token_id is not None + assert sequence_tokenizer.bos_token_id is not None + assert sequence_tokenizer.eos_token_id is not None + + sequence_tokens = torch.full( + (sequence_length + 2,), sequence_tokenizer.mask_token_id ) + sequence_tokens[0] = sequence_tokenizer.bos_token_id + sequence_tokens[-1] = sequence_tokenizer.eos_token_id + + return sequence_tokens def get_default_structure_tokens( @@ -200,19 +206,22 @@ def get_default_structure_tokens( def get_default_secondary_structure_tokens( sequence_length: int, secondary_structure_tokenizer: SecondaryStructureTokenizer ) -> torch.Tensor: - return tokenize_secondary_structure( - get_default_secondary_structure(sequence_length), - secondary_structure_tokenizer, - add_special_tokens=True, + ss8_tokens = torch.full( + (sequence_length + 2,), secondary_structure_tokenizer.mask_token_id ) + ss8_tokens[0] = secondary_structure_tokenizer.bos_token_id + ss8_tokens[-1] = secondary_structure_tokenizer.eos_token_id + + return ss8_tokens def get_default_sasa_tokens( sequence_length: int, sasa_tokenizer: SASADiscretizingTokenizer ) -> torch.Tensor: - return tokenize_sasa( - get_default_sasa(sequence_length), sasa_tokenizer, add_special_tokens=True - ) + sasa_tokens = torch.full((sequence_length + 2,), sasa_tokenizer.mask_token_id) + sasa_tokens[0] = sasa_tokenizer.bos_token_id + sasa_tokens[-1] = sasa_tokenizer.eos_token_id + return sasa_tokens def get_default_function_tokens( diff --git a/esm/utils/generation.py b/esm/utils/generation.py index bccf119..d53d805 100644 --- a/esm/utils/generation.py +++ b/esm/utils/generation.py @@ -495,6 +495,7 @@ def _sample_per_prompt( sampling_config: SamplingConfig, tokenizers: TokenizerCollectionProtocol, decode_sasa_tokens: bool = True, + mask_logits_of_invalid_ids: bool = True, ) -> ForwardAndSampleOutput: assert logits_output.logits is not None @@ -513,11 +514,19 @@ def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None: if config is None: tokens_dir[track] = maybe_clone(getattr(protein, track)) continue + tokenizer = getattr(tokenizers, track) + valid_ids = ( + set(tokenizer.all_token_ids) + - set(tokenizer.special_token_ids) + - set(config.invalid_ids) + ) sampling_metadata = _sample_track( 
logits=getattr(logits_output.logits, track), tokens=getattr(protein, track), sampling_track_config=config, mask_idx=getattr(tokenizers, track).mask_token_id, + valid_ids=list(valid_ids), + mask_logits_of_invalid_ids=mask_logits_of_invalid_ids, ) tokens_dir[track] = sampling_metadata.pop("sampled_tokens") # (L,) track_sampling_metadata_dir[track] = sampling_metadata @@ -536,12 +545,19 @@ def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None: assert logits_output.logits.sasa is not None assert protein.sasa is not None + valid_ids = ( + set(tokenizers.sasa.all_token_ids) + - set(tokenizer.special_token_ids) + - set(config.invalid_ids) + ) sasa_logits = logits_output.logits.sasa sasa_value = sample_sasa_logits( sasa_logits, protein.sasa, sampling_track_config=config, mask_idx=tokenizers.sasa.mask_token_id, + valid_ids=list(valid_ids), + mask_logits_of_invalid_ids=mask_logits_of_invalid_ids, ) tokens_dir["sasa"] = sasa_value @@ -558,6 +574,9 @@ def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None: getattr(protein, "residue_annotations") ) else: + if config.invalid_ids is not None and len(config.invalid_ids) > 0: + warn("For function sampling, invalid_ids sampling config is not supported.") + sampling_metadata = _sample_function_track( tokenizers.function, tokens=getattr(protein, "function"), @@ -621,6 +640,8 @@ def _sample_track( tokens: torch.Tensor, sampling_track_config: SamplingTrackConfig, mask_idx: int, + valid_ids: list[int], + mask_logits_of_invalid_ids: bool = True, ) -> dict[str, torch.Tensor]: """Works with inputs that have batch dimension.""" # Sample in all positions @@ -629,7 +650,11 @@ def _sample_track( # since the logits may be computed with a longer padded batch, while tokens # are the original input sequence. sampled_tokens = sample_logits( - logits, temperature=temperature, top_p=sampling_track_config.top_p + logits, + temperature=temperature, + valid_ids=valid_ids, + top_p=sampling_track_config.top_p, + mask_logits_of_invalid_ids=mask_logits_of_invalid_ids, ) log_probs = logits.log_softmax(-1) sampling_mask = get_sampling_mask(tokens, sampling_track_config, mask_idx) diff --git a/esm/utils/misc.py b/esm/utils/misc.py index b65ba11..4562283 100644 --- a/esm/utils/misc.py +++ b/esm/utils/misc.py @@ -176,7 +176,7 @@ def unbinpack( return stack_variable_length_tensors(unpacked_tensors, pad_value) -def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]: +def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]: # type: ignore """ Returns an autocast context manager that disables downcasting by AMP. @@ -187,9 +187,9 @@ def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast An autocast context manager with the specified behavior. 
""" if device_type == "cpu": - return torch.amp.autocast(device_type, enabled=False) + return torch.amp.autocast(device_type, enabled=False) # type: ignore elif device_type == "cuda": - return torch.amp.autocast(device_type, dtype=torch.float32) + return torch.amp.autocast(device_type, dtype=torch.float32) # type: ignore else: raise ValueError(f"Unsupported device type: {device_type}") diff --git a/esm/utils/sampling.py b/esm/utils/sampling.py index b2081ca..3db37fd 100644 --- a/esm/utils/sampling.py +++ b/esm/utils/sampling.py @@ -153,7 +153,9 @@ def validate_sampling_config( def sample_logits( logits: torch.Tensor, temperature: float | torch.Tensor, + valid_ids: list[int] = [], top_p: float | torch.Tensor = 1.0, + mask_logits_of_invalid_ids: bool = True, ): """Default sampling from logits. @@ -167,15 +169,22 @@ def sample_logits( temperature = _tensorize_like(temperature, logits) + batch_dims = logits.size()[:-1] + logits = logits.reshape(-1, logits.shape[-1]) + + # Only sample from valid ids + # the /logits endpoint should receive unmodified logits + if mask_logits_of_invalid_ids: + mask = torch.ones_like(logits, dtype=torch.bool) + mask[:, valid_ids] = False + logits[mask] = -torch.inf + if torch.all(temperature == 0): ids = logits.argmax(-1) return ids assert not torch.any(temperature == 0), "Partial temperature 0 not supported." - batch_dims = logits.size()[:-1] - logits = logits.reshape(-1, logits.shape[-1]) - # Sample from all logits probs = F.softmax(logits / temperature[..., None], dim=-1) ids = torch.multinomial(probs, 1).squeeze(1) @@ -250,7 +259,16 @@ def sample_sasa_logits( tokens: torch.Tensor, sampling_track_config: SamplingTrackConfig, mask_idx: int, + valid_ids: list[int], + mask_logits_of_invalid_ids: bool = True, ) -> torch.Tensor: + # Only sample from valid ids + # the /logits endpoint should receive unmodified logits + if mask_logits_of_invalid_ids: + mask = torch.ones_like(logits, dtype=torch.bool) + mask[:, valid_ids] = False + logits[mask] = -torch.inf + sasa_probs = torch.nn.functional.softmax(logits, dim=-1) max_prob_idx = torch.argmax(sasa_probs, dim=-1) sasa_bins = torch.tensor([0] + SASA_DISCRETIZATION_BOUNDARIES, dtype=torch.float) diff --git a/esm/utils/structure/protein_chain.py b/esm/utils/structure/protein_chain.py index 2169aed..efe8a64 100644 --- a/esm/utils/structure/protein_chain.py +++ b/esm/utils/structure/protein_chain.py @@ -213,7 +213,6 @@ def cbeta_contacts(self, distance_threshold: float = 8.0) -> np.ndarray: distance = self.pdist_CB contacts = (distance < distance_threshold).astype(np.int64) contacts[np.isnan(distance)] = -1 - contacts = squareform(contacts) np.fill_diagonal(contacts, -1) return contacts @@ -391,15 +390,16 @@ def rmsd( def lddt_ca( self, - target: ProteinChain, + native: ProteinChain, mobile_inds: list[int] | np.ndarray | None = None, target_inds: list[int] | np.ndarray | None = None, **kwargs, ) -> float | np.ndarray: """Compute the LDDT between this protein chain and another. + NOTE: LDDT IS NOT SYMMETRIC. The call should always be prediction.lddt_ca(native). Arguments: - target (ProteinChain): The other protein chain to compare to. + native (ProteinChain): The ground truth protein chain mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. 
These are NOT residue indices @@ -407,11 +407,10 @@ def lddt_ca( float | np.ndarray: The LDDT score between the two protein chains, either a single float or per-residue LDDT scores if `per_residue` is True. """ - lddt = compute_lddt_ca( torch.tensor(self.atom37_positions[mobile_inds]).unsqueeze(0), - torch.tensor(target.atom37_positions[target_inds]).unsqueeze(0), - torch.tensor(self.atom37_mask[mobile_inds]).unsqueeze(0), + torch.tensor(native.atom37_positions[target_inds]).unsqueeze(0), + torch.tensor(native.atom37_mask[mobile_inds]).unsqueeze(0), **kwargs, ) return float(lddt) if lddt.numel() == 1 else lddt.numpy().flatten() diff --git a/esm/utils/structure/protein_structure.py b/esm/utils/structure/protein_structure.py index 6bea8ad..6779d3b 100644 --- a/esm/utils/structure/protein_structure.py +++ b/esm/utils/structure/protein_structure.py @@ -6,7 +6,7 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch.cuda.amp import autocast # type: ignore +from torch.amp import autocast # type: ignore from esm.utils import residue_constants from esm.utils.misc import unbinpack @@ -66,7 +66,7 @@ def normalize(x: ArrayOrTensor): @torch.no_grad() -@autocast(enabled=False) +@autocast("cuda", enabled=False) def compute_alignment_tensors( mobile: torch.Tensor, target: torch.Tensor, @@ -161,7 +161,7 @@ def compute_alignment_tensors( @torch.no_grad() -@autocast(enabled=False) +@autocast("cuda", enabled=False) def compute_rmsd_no_alignment( aligned: torch.Tensor, target: torch.Tensor, @@ -210,7 +210,7 @@ def compute_rmsd_no_alignment( @torch.no_grad() -@autocast(enabled=False) +@autocast("cuda", enabled=False) def compute_affine_and_rmsd( mobile: torch.Tensor, target: torch.Tensor, diff --git a/esm/widgets/components/results_visualizer.py b/esm/widgets/components/results_visualizer.py index 1e4f161..e2544c0 100644 --- a/esm/widgets/components/results_visualizer.py +++ b/esm/widgets/components/results_visualizer.py @@ -31,6 +31,7 @@ def create_results_visualizer( None, ] | None = None, + include_title: bool = True, ) -> widgets.Widget: if modality == "structure": # Sort structures by pTM @@ -118,13 +119,16 @@ def on_prev_button_clicked(b): update_page() - return widgets.VBox( - [ - widgets.HTML(value="
<h3>Generated Samples</h3>
"), - widgets.HBox([prev_button, next_button, page_label]), - output, - ] - ) + results_ui = widgets.VBox([]) + title = (widgets.HTML(value="
<h3>Generated Samples</h3>
"),) + nav_bar = widgets.HBox([prev_button, next_button, page_label]) + + if include_title: + results_ui.children += title + if total_pages > 1: + results_ui.children += (nav_bar,) + results_ui.children += (output,) + return results_ui def add_line_breaks(sequence: str, line_length: int = 120) -> str: @@ -148,15 +152,27 @@ def create_sequence_results_page( copy_to_prompt_button.on_click( lambda b: copy_to_prompt_callback(item.sequence) ) - entry = widgets.VBox( - [ - copy_to_prompt_button, - widgets.HTML( - value=f'
<pre>{add_line_breaks(item.sequence, line_length) if item.sequence else "No sequence"}</pre>
' - ), - ], - layout={"border": "1px solid gray"}, - ) + + if copy_to_prompt_callback: + entry = widgets.VBox( + [ + copy_to_prompt_button, + widgets.HTML( + value=f'
<pre>{add_line_breaks(item.sequence, line_length) if item.sequence else "No sequence"}</pre>
' + ), + ], + layout={"border": "1px solid gray"}, + ) + else: + entry = widgets.VBox( + [ + widgets.HTML( + value=f'
<pre>{add_line_breaks(item.sequence, line_length) if item.sequence else "No sequence"}</pre>
' + ) + ], + layout={"border": "1px solid gray"}, + ) + sequence_items.append(entry) return widgets.VBox(sequence_items) @@ -185,11 +201,17 @@ def create_sasa_results_page( data_array=sasa, cmap="Reds", ) - sasa_items.append( - widgets.VBox( - [copy_to_prompt_button, output], layout={"border": "1px solid gray"} + + if copy_to_prompt_callback: + sasa_items.append( + widgets.VBox( + [copy_to_prompt_button, output], layout={"border": "1px solid gray"} + ) + ) + else: + sasa_items.append( + widgets.VBox([output], layout={"border": "1px solid gray"}) ) - ) return widgets.VBox(sasa_items) @@ -236,11 +258,15 @@ def create_secondary_structure_results_page( highlighted_ranges=[], cmap="Set2", ) - ss_items.append( - widgets.VBox( - [copy_to_prompt_button, output], layout={"border": "1px solid gray"} + + if copy_to_prompt_callback: + ss_items.append( + widgets.VBox( + [copy_to_prompt_button, output], layout={"border": "1px solid gray"} + ) ) - ) + else: + ss_items.append(widgets.VBox([output], layout={"border": "1px solid gray"})) return widgets.VBox(ss_items) @@ -304,8 +330,14 @@ def confidence_to_color(confidence) -> str: ) row = i // grid_size col = i % grid_size + + if copy_to_prompt_callback: + header = widgets.HBox([copy_to_prompt_button, ptm_label]) + else: + header = widgets.HBox([ptm_label]) + grid[row, col] = widgets.VBox( - [widgets.HBox([copy_to_prompt_button, ptm_label]), output], + [header, output], layout={"border": "1px solid gray"}, ) return grid @@ -344,10 +376,12 @@ def create_function_annotations_results_page( interpro_annotations, sequence_length=len(item), ) - function_items.append( - widgets.VBox( + if copy_to_prompt_callback: + content = widgets.VBox( [copy_to_prompt_button, image], layout={"border": "1px solid gray"} ) - ) + else: + content = widgets.VBox([image], layout={"border": "1px solid gray"}) + function_items.append(content) return widgets.VBox(function_items) diff --git a/esm/widgets/utils/protein_import.py b/esm/widgets/utils/protein_import.py index 3c19574..9fdb4b2 100644 --- a/esm/widgets/utils/protein_import.py +++ b/esm/widgets/utils/protein_import.py @@ -9,9 +9,15 @@ class ProteinImporter: - def __init__(self) -> None: - self._protein_list = [] - self._protein_workspace = {} + def __init__( + self, + max_proteins: int | None = None, + autoload: bool = False, + ) -> None: + self._protein_list: list[tuple[str, ProteinChain]] = [] + self._protein_workspace: dict[str, str] = {} + self.max_proteins = max_proteins + self.autoload = autoload # Workspace section self.workspace_title = widgets.HTML( @@ -58,7 +64,7 @@ def __init__(self) -> None: self.error_output = widgets.Output() self.entries_box = widgets.VBox() - self.pdb_id_add_button.on_click(self.add_pdb_id) + self.pdb_id_add_button.on_click(self.on_click_add) self.pdb_uploader.observe(self.on_upload, names="value") self.delete_callbacks: list[Callable[[], None]] = [] @@ -77,11 +83,18 @@ def __init__(self) -> None: def protein_list(self): return self._protein_list - def add_pdb_id(self, _): + def on_click_add(self, _): pdb_id = self.pdb_id_input.value chain_id = self.pdb_chain_input.value or "detect" + self.add_pdb_id(pdb_id, chain_id) + + def add_pdb_id(self, pdb_id: str, chain_id: str): try: self.error_output.clear_output() + + if self.max_proteins and len(self._protein_list) >= self.max_proteins: + raise ValueError("Maximum number of proteins reached") + if not pdb_id: raise ValueError("PDB ID or Filename is required") if pdb_id.lower().endswith(".pdb"): @@ -124,11 +137,19 @@ def delete_entry(b): def 
on_upload(self, _): try: self.error_output.clear_output() + + if self.max_proteins and len(self._protein_list) >= self.max_proteins: + raise ValueError("Maximum number of proteins reached") + uploaded_file = next(iter(self.pdb_uploader.value)) filename: str = uploaded_file["name"] str_content = codecs.decode(uploaded_file["content"], encoding="utf-8") self._protein_workspace[filename] = str_content self.workspace.children += (widgets.Label(value=f"{filename}"),) + + if self.autoload: + self.add_pdb_id(filename, "detect") + except Exception as e: with self.error_output: wrapped_print(f"Error: {e}") diff --git a/esm/widgets/views/inverse_folding.py b/esm/widgets/views/inverse_folding.py new file mode 100644 index 0000000..903ceed --- /dev/null +++ b/esm/widgets/views/inverse_folding.py @@ -0,0 +1,98 @@ +from ipywidgets import widgets + +from esm.sdk.api import ( + ESM3InferenceClient, + ESMProtein, + ESMProteinError, + GenerationConfig, +) +from esm.widgets.components.results_visualizer import ( + create_results_visualizer, +) +from esm.widgets.utils.printing import wrapped_print +from esm.widgets.utils.protein_import import ( + ProteinImporter, +) + + +def create_inverse_folding_ui(client: ESM3InferenceClient) -> widgets.Widget: + # Allow a single protein and immediately load it from workspace + protein_importer = ProteinImporter(max_proteins=1, autoload=True) + output = widgets.Output() + inverse_folding_ui = widgets.VBox([protein_importer.importer_ui, output]) + + inverse_fold_button = widgets.Button( + description="Inverse Fold", + disabled=True, + tooltip="Click to predict the protein sequence from the structure", + style={"button_color": "lightgreen"}, + ) + + def get_protein() -> ESMProtein: + [first_protein] = protein_importer.protein_list + protein_id, protein_chain = first_protein + protein = ESMProtein.from_protein_chain(protein_chain) + + # NOTE: We ignore all properties except structure + protein.sequence = None + protein.secondary_structure = None + protein.sasa = None + protein.function_annotations = None + return protein + + def on_new_protein(_): + is_protein = len(protein_importer.protein_list) > 0 + inverse_fold_button.disabled = not is_protein + inverse_folding_ui.children = [ + protein_importer.importer_ui, + inverse_fold_button, + output, + ] + + def validate_inverse_fold(_): + if len(protein_importer.protein_list) == 0: + inverse_fold_button.disabled = True + else: + inverse_fold_button.disabled = False + + def on_click_inverse_fold(_): + try: + # Reset the output and results + output.clear_output() + inverse_folding_ui.children = [ + protein_importer.importer_ui, + inverse_fold_button, + output, + ] + # Predict the protein's sequence + protein = get_protein() + with output: + print("Predicting the protein sequence from the structure...") + protein = client.generate( + input=protein, + config=GenerationConfig(track="sequence", num_steps=1), + ) + if isinstance(protein, ESMProteinError): + wrapped_print(f"Protein Error: {protein.error_msg}") + elif isinstance(protein, ESMProtein): + sequence_results = create_results_visualizer( + modality="sequence", + samples=[protein], + items_per_page=1, + include_title=False, + ) + output.clear_output() + inverse_folding_ui.children = [ + protein_importer.importer_ui, + inverse_fold_button, + sequence_results, + ] + except Exception as e: + with output: + wrapped_print(e) + + inverse_fold_button.on_click(on_click_inverse_fold) + protein_importer.entries_box.observe(on_new_protein, names="children") + 
protein_importer.register_delete_callback(lambda: validate_inverse_fold(None)) + + return inverse_folding_ui diff --git a/esm/widgets/views/prediction.py b/esm/widgets/views/prediction.py new file mode 100644 index 0000000..fd6938f --- /dev/null +++ b/esm/widgets/views/prediction.py @@ -0,0 +1,177 @@ +from ipywidgets import widgets + +from esm.sdk.api import ( + ESM3InferenceClient, + ESMProtein, + ESMProteinError, + GenerationConfig, +) +from esm.widgets.components.results_visualizer import ( + create_results_visualizer, +) +from esm.widgets.utils.printing import wrapped_print +from esm.widgets.utils.protein_import import ( + ProteinImporter, +) + + +def create_prediction_ui(client: ESM3InferenceClient) -> widgets.Widget: + # Allow a single protein and immediately load it from workspace + protein_importer = ProteinImporter(max_proteins=1, autoload=True) + + sequence_input_ui = widgets.VBox( + [ + widgets.HTML(value="
<h3>Or enter a protein sequence:</h3>
"), + widgets.Textarea( + placeholder="Enter protein sequence", + layout=widgets.Layout(width="400px", height="100px"), + ), + ] + ) + + input_ui = widgets.Tab(children=[protein_importer.importer_ui, sequence_input_ui]) + input_ui.set_title(0, "Add Protein") + input_ui.set_title(1, "Enter Sequence") + + predict_button = widgets.Button( + description="Predict", + disabled=True, + tooltip="Click to predict the protein's properties", + style={"button_color": "lightgreen"}, + ) + + output = widgets.Output() + + prediction_ui = widgets.VBox([input_ui, output]) + + def get_protein() -> ESMProtein: + if input_ui.selected_index == 0: + [first_protein] = protein_importer.protein_list + protein_id, protein_chain = first_protein + protein = ESMProtein.from_protein_chain(protein_chain) + + # NOTE: We ignore all properties except sequence and structure + protein.secondary_structure = None + protein.sasa = None + protein.function_annotations = None + return protein + else: + sequence = sequence_input_ui.children[1].value + return ESMProtein(sequence=sequence) + + def on_new_sequence(_): + is_sequence = len(sequence_input_ui.children[1].value) > 0 + predict_button.disabled = not is_sequence + prediction_ui.children = [input_ui, predict_button, output] + + def on_new_protein(_): + is_protein = len(protein_importer.protein_list) > 0 + predict_button.disabled = not is_protein + prediction_ui.children = [input_ui, predict_button, output] + + def validate_predict(_): + if input_ui.selected_index == 0: + if len(protein_importer.protein_list) > 0: + predict_button.disabled = False + else: + predict_button.disabled = True + else: + if len(sequence_input_ui.children[1].value) == 0: + predict_button.disabled = True + else: + predict_button.disabled = False + + def on_click_predict(_): + try: + # Reset the output and results + output.clear_output() + prediction_ui.children = [ + input_ui, + predict_button, + output, + ] + # Predict the protein's properties + with output: + protein = get_protein() + + tracks = ["structure", "secondary_structure", "sasa", "function"] + + success = True + for track in tracks: + print(f"Predicting {track}...") + protein = client.generate( + protein, config=GenerationConfig(track=track) + ) + if isinstance(protein, ESMProteinError): + wrapped_print(f"Protein Error: {protein.error_msg}") + success = False + + assert isinstance(protein, ESMProtein) + + if success: + structure_results = create_results_visualizer( + modality="structure", + samples=[protein], + items_per_page=1, + include_title=False, + ) + secondary_structure_results = create_results_visualizer( + modality="secondary_structure", + samples=[protein], + items_per_page=1, + include_title=False, + ) + sasa_results = create_results_visualizer( + modality="sasa", + samples=[protein], + items_per_page=1, + include_title=False, + ) + function_results = create_results_visualizer( + modality="function", + samples=[protein], + items_per_page=1, + include_title=False, + ) + results_ui = widgets.Tab( + children=[ + structure_results, + secondary_structure_results, + sasa_results, + function_results, + ] + ) + results_ui.set_title(0, "Structure") + results_ui.set_title(1, "Secondary Structure") + results_ui.set_title(2, "SASA") + results_ui.set_title(3, "Function") + + output.clear_output() + prediction_ui.children = [ + input_ui, + predict_button, + output, + results_ui, + ] + + except Exception as e: + with output: + wrapped_print(e) + + predict_button.on_click(on_click_predict) + protein_importer.entries_box.observe( + on_new_protein, 
+ names="children", + ) + protein_importer.register_delete_callback(lambda: validate_predict(None)) + + sequence_input_ui.children[1].observe( + on_new_sequence, + names="value", + ) + input_ui.observe( + validate_predict, + names="selected_index", + ) + + return prediction_ui diff --git a/pyproject.toml b/pyproject.toml index 1b62455..b8bd44f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "esm" -version = "3.0.3" +version = "3.0.4" description = "EvolutionaryScale open model repository" readme = "README.md" requires-python = ">=3.10" @@ -27,7 +27,7 @@ dependencies = [ "transformers", "ipython", "einops", - "biotite", + "biotite==0.41.2", "msgpack-numpy", "biopython", "scikit-learn", diff --git a/tools/generation.ipynb b/tools/generate.ipynb similarity index 90% rename from tools/generation.ipynb rename to tools/generate.ipynb index bf6e5c5..c492ecd 100644 --- a/tools/generation.ipynb +++ b/tools/generate.ipynb @@ -20,16 +20,6 @@ "nest_asyncio.apply()" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/tools/invfold.ipynb b/tools/invfold.ipynb new file mode 100644 index 0000000..70fc0c9 --- /dev/null +++ b/tools/invfold.ipynb @@ -0,0 +1,79 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install esm pydssp pygtrie dna-features-viewer nest_asyncio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import nest_asyncio\n", + "\n", + "nest_asyncio.apply()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from esm.widgets.utils.types import ClientInitContainer\n", + "from esm.widgets.views.login import create_login_ui\n", + "\n", + "client_init = ClientInitContainer()\n", + "create_login_ui(client_init)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from esm.widgets.views.inverse_folding import (\n", + " create_inverse_folding_ui,\n", + ")\n", + "\n", + "client = client_init()\n", + "create_inverse_folding_ui(client)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tools/predict.ipynb b/tools/predict.ipynb new file mode 100644 index 0000000..bd5712d --- /dev/null +++ b/tools/predict.ipynb @@ -0,0 +1,77 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install esm pydssp pygtrie dna-features-viewer nest_asyncio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import nest_asyncio\n", + "\n", + "nest_asyncio.apply()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from esm.widgets.utils.types import 
ClientInitContainer\n", + "from esm.widgets.views.login import create_login_ui\n", + "\n", + "client_init = ClientInitContainer()\n", + "create_login_ui(client_init)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from esm.widgets.views.prediction import create_prediction_ui\n", + "\n", + "client = client_init()\n", + "create_prediction_ui(client)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
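Taken together, the release's core behavioral change is that sampling is now restricted to valid token ids unless `mask_logits_of_invalid_ids` is turned off. A standalone re-implementation of just that masking step, for illustration only (this is a sketch of the idea, not the library's `sample_logits` itself):

```python
# Sketch of the invalid-id masking added to esm.utils.sampling.sample_logits:
# logits of ids outside valid_ids are set to -inf so they can never be drawn.
import torch
import torch.nn.functional as F


def sample_with_valid_ids(
    logits: torch.Tensor, valid_ids: list[int], temperature: float = 1.0
) -> torch.Tensor:
    flat = logits.reshape(-1, logits.shape[-1]).clone()
    mask = torch.ones_like(flat, dtype=torch.bool)
    mask[:, valid_ids] = False  # keep only the valid columns
    flat[mask] = -torch.inf     # invalid ids get zero probability
    probs = F.softmax(flat / temperature, dim=-1)
    return torch.multinomial(probs, 1).squeeze(1)


ids = sample_with_valid_ids(torch.randn(4, 10), valid_ids=[2, 3, 5])
assert all(i in (2, 3, 5) for i in ids.tolist())
```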