
Embeddings Functionality #111

Open
noahho opened this issue Jan 10, 2025 · 1 comment
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

noahho (Collaborator) commented Jan 10, 2025

The embedding functionality was removed during the recent code refactor. Previously, the unsupervised model had a get_embeddings() method that exposed the transformer's learned internal representations for the input data.
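For context, a sketch of the contract the removed method provided: one embedding vector per input row. The class below is a hypothetical stand-in, not the library's actual model.

```python
# Hypothetical mock illustrating the get_embeddings() contract; the real
# model returns transformer representations, not these toy vectors.
class FakeUnsupervisedModel:
    def get_embeddings(self, X):
        # Embed each row as [number of features, sum of features].
        return [[float(len(row)), float(sum(row))] for row in X]


model = FakeUnsupervisedModel()
X = [[1, 2], [3, 4]]
print(model.get_embeddings(X))  # [[2.0, 3.0], [2.0, 7.0]]
```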

noahho added the enhancement and help wanted labels on Jan 10, 2025
noahho (Collaborator, Author) commented Jan 12, 2025

Here are the most relevant connected pieces before the refactor:

  1. Main embedding method:
def get_embeddings(
    self, X: torch.Tensor, additional_y: dict | None = None
) -> torch.Tensor:
    """
    Get the embeddings for the input data `X`.

    Parameters:
        X (torch.Tensor): The input data tensor.
        additional_y (dict, optional): Additional labels to use during prediction.

    Returns:
        torch.Tensor: The computed embeddings.
    """
    return self.predict_full(
        X, additional_y=additional_y, get_additional_outputs=["test_embeddings"]
    )["test_embeddings"]
  2. The predict_full path that processes embeddings:
def predict_full(self, X, additional_y=None, get_additional_outputs=None) -> dict:
    """
    Predicts the target y given the input X.

    Parameters:
        X:
        additional_y: Additional inputs
        get_additional_outputs: Keys for additional outputs to return.

    Returns:
         (dict: The predictions, dict: Additional outputs)
    """
    X_full, y_full, additional_y, eval_position = self.predict_common_setup(
        X, additional_y_eval=additional_y
    )

    prediction, additional_outputs = self.transformer_predict(
        eval_xs=X_full,
        eval_ys=y_full, 
        eval_position=eval_position,
        additional_ys=additional_y,
        get_additional_outputs=get_additional_outputs,
        **get_params_from_config(self.c_processed_),
    )

    return {"proba": prediction, **additional_outputs}
  3. The transformer_predict method that generates embeddings:
def transformer_predict(
    self,
    eval_xs: torch.Tensor,  # shape (num_examples, [1], num_features)
    eval_ys: torch.Tensor,  # shape (num_examples, [1], [1])
    eval_position: int,
    bar_distribution: FullSupportBarDistribution | None = None,
    reweight_probs_based_on_train=False,
    additional_ys=None,
    cache_trainset_representations: bool = False,
    get_additional_outputs: list[str] | None = None,
) -> tuple[torch.Tensor, dict]:
    """
    Generates predictions from the transformer model.

    This method builds the ensemble configurations, applies preprocessing, runs the transformer
    model on the preprocessed data, and then aggregates the predictions from each configuration.
    """
    # ... initialization code ...

    outputs, additional_outputs = self._batch_predict(
        inputs=inputs,
        labels=labels,
        additional_ys=additional_ys_list,
        categorical_inds=categorical_inds,
        eval_position=eval_position,
        cache_trainset_representations=cache_trainset_representations,
        get_additional_outputs=get_additional_outputs,
    )

    return outputs, additional_outputs
  4. The batch predict method that handles embeddings:
def _batch_predict(
    self,
    inputs: list[torch.Tensor],
    labels: list[torch.Tensor],
    additional_ys: dict[str, list[torch.Tensor]],
    categorical_inds: list[list[int]],
    eval_position: int,
    cache_trainset_representations: bool = False,
    get_additional_outputs: list[str] | None = None,
) -> tuple[torch.Tensor, dict[str, list[torch.Tensor]]]:
    """Handles batch processing and embedding generation"""
    
    outputs = []
    additional_outputs = (
        {}
        if get_additional_outputs is None
        else {k: [] for k in get_additional_outputs}
    )

    for batch_input, batch_label, batch_categorical_inds, batch_additional_ys in zip(
        inputs, labels, categorical_inds, additional_ys
    ):
        output = model(
            (
                style_expanded,
                {"main": batch_input.to(self.device_)},
                batch_label,
            ),
            single_eval_pos=eval_position,
            only_return_standard_out=only_return_standard_out,
            categorical_inds=batch_categorical_inds,
        )

        if isinstance(output, tuple):
            output, output_once = output
            
        if additional_outputs:
            standard_prediction_output = output["standard"]
            for k in additional_outputs:
                additional_outputs[k].append(output[k].cpu())
        else:
            standard_prediction_output = output

        outputs += [standard_prediction_output.detach().cpu()]

    outputs = torch.cat(outputs, 1)

    if additional_outputs:
        for k in additional_outputs:
            additional_outputs[k] = torch.cat(additional_outputs[k], dim=1)[
                :, torch.argsort(implied_permutation), :
            ]

    return outputs[:, torch.argsort(implied_permutation), :], additional_outputs

The full embedding pipeline works like this:

  1. get_embeddings() is called with input data X
  2. This calls predict_full() with get_additional_outputs=["test_embeddings"]
  3. predict_full() sets up the data and calls transformer_predict()
  4. transformer_predict() preprocesses the data and calls _batch_predict()
  5. _batch_predict() runs the model on batches and collects embeddings in additional_outputs
  6. The embeddings are returned up through the chain back to the user
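The steps above can be sketched as a runnable toy, with plain lists standing in for tensors and a stub in place of the real transformer; the names mirror the excerpts, but nothing here is the actual implementation.

```python
# Toy version of the call chain: get_embeddings -> predict_full ->
# _batch_predict, showing how get_additional_outputs=["test_embeddings"]
# threads the embeddings back up to the caller. All names are illustrative.

def _batch_predict(batches, get_additional_outputs=None):
    outputs = []
    additional_outputs = (
        {}
        if get_additional_outputs is None
        else {k: [] for k in get_additional_outputs}
    )
    for batch in batches:
        # Stub "model": a standard output plus one embedding per row.
        output = {
            "standard": [x * 2 for x in batch],
            "test_embeddings": [[x, x + 1] for x in batch],
        }
        for k in additional_outputs:
            additional_outputs[k].extend(output[k])
        outputs.extend(output["standard"])
    return outputs, additional_outputs


def predict_full(X, get_additional_outputs=None):
    prediction, additional = _batch_predict([X], get_additional_outputs)
    return {"proba": prediction, **additional}


def get_embeddings(X):
    return predict_full(X, get_additional_outputs=["test_embeddings"])[
        "test_embeddings"
    ]


print(get_embeddings([1, 2, 3]))  # [[1, 2], [2, 3], [3, 4]]
```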

Key points:

  • Embeddings are generated as "test_embeddings" in the additional outputs
  • The model processes batches sequentially to handle large datasets
  • The embeddings capture the learned internal representations from the transformer
  • The embeddings maintain the ordering of the input data
  • Additional labels can be provided to influence embedding generation
  • The process is integrated with the prediction pipeline for efficiency
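On the ordering point: the excerpted `_batch_predict` re-indexes its outputs with `torch.argsort(implied_permutation)`, i.e. the inverse of whatever permutation the pipeline applied while batching. A minimal pure-Python sketch of that fix-up, with a hypothetical permutation and string labels in place of embedding tensors:

```python
# Restoring input order after permuted processing: argsort of a permutation
# is its inverse, so indexing with it undoes the shuffle. The permutation
# and labels below are illustrative.

def inverse_permutation(perm):
    """Equivalent of torch.argsort for a permutation given as a list."""
    inv = [0] * len(perm)
    for out_pos, src in enumerate(perm):
        inv[src] = out_pos
    return inv


implied_permutation = [2, 0, 1]      # row 2 was processed first, then 0, then 1
permuted = ["emb2", "emb0", "emb1"]  # embeddings in processing order
inv = inverse_permutation(implied_permutation)
restored = [permuted[i] for i in inv]
print(restored)  # ['emb0', 'emb1', 'emb2']
```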
