Embeddings Functionality #111
noahho added the enhancement (New feature or request) and help wanted (Extra attention is needed) labels on Jan 10, 2025
Here are the most relevant connected pieces before the refactor:
def get_embeddings(
self, X: torch.Tensor, additional_y: dict = None
) -> torch.Tensor:
"""
Get the embeddings for the input data `X`.
Parameters:
X (torch.Tensor): The input data tensor.
additional_y (dict, optional): Additional labels to use during prediction.
Returns:
torch.Tensor: The computed embeddings.
"""
return self.predict_full(
X, additional_y=additional_y, get_additional_outputs=["test_embeddings"]
)["test_embeddings"]
def predict_full(self, X, additional_y=None, get_additional_outputs=None) -> dict:
"""
Predicts the target y given the input X.
    Parameters:
        X: The input data tensor.
        additional_y: Additional inputs.
        get_additional_outputs: Keys for additional outputs to return.

    Returns:
        dict: The predictions under "proba", plus one entry for each key
        requested via `get_additional_outputs`.
"""
X_full, y_full, additional_y, eval_position = self.predict_common_setup(
X, additional_y_eval=additional_y
)
prediction, additional_outputs = self.transformer_predict(
eval_xs=X_full,
eval_ys=y_full,
eval_position=eval_position,
additional_ys=additional_y,
get_additional_outputs=get_additional_outputs,
**get_params_from_config(self.c_processed_),
)
return {"proba": prediction, **additional_outputs}
def transformer_predict(
self,
eval_xs: torch.Tensor, # shape (num_examples, [1], num_features)
eval_ys: torch.Tensor, # shape (num_examples, [1], [1])
eval_position: int,
bar_distribution: FullSupportBarDistribution | None = None,
reweight_probs_based_on_train=False,
additional_ys=None,
cache_trainset_representations: bool = False,
get_additional_outputs: list[str] = None,
) -> tuple[torch.Tensor, dict]:
"""
Generates predictions from the transformer model.
This method builds the ensemble configurations, applies preprocessing, runs the transformer
model on the preprocessed data, and then aggregates the predictions from each configuration.
"""
# ... initialization code ...
outputs, additional_outputs = self._batch_predict(
inputs=inputs,
labels=labels,
additional_ys=additional_ys_list,
categorical_inds=categorical_inds,
eval_position=eval_position,
cache_trainset_representations=cache_trainset_representations,
get_additional_outputs=get_additional_outputs,
)
    return outputs, additional_outputs
def _batch_predict(
self,
inputs: list[torch.Tensor],
labels: list[torch.Tensor],
    additional_ys: dict[str, list[torch.Tensor]],
categorical_inds: list[list[int]],
eval_position: int,
cache_trainset_representations: bool = False,
get_additional_outputs: list[str] = None,
) -> tuple[torch.Tensor, dict[str, list[torch.Tensor]]]:
"""Handles batch processing and embedding generation"""
outputs = []
additional_outputs = (
{}
if get_additional_outputs is None
else {k: [] for k in get_additional_outputs}
)
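    # Note: the per-configuration setup is elided in this snippet; `model`,
    # `style_expanded`, `only_return_standard_out`, `implied_permutation`, and
    # `additional_ys_inputs` are prepared earlier in the method.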
for batch_input, batch_label, batch_categorical_inds, batch_additional_ys in zip(
inputs, labels, categorical_inds, additional_ys_inputs
):
output = model(
(
style_expanded,
{"main": batch_input.to(self.device_)},
labels,
),
single_eval_pos=eval_position,
only_return_standard_out=only_return_standard_out,
categorical_inds=batch_categorical_inds,
)
if isinstance(output, tuple):
output, output_once = output
if additional_outputs:
standard_prediction_output = output["standard"]
for k in additional_outputs:
additional_outputs[k].append(output[k].cpu())
else:
standard_prediction_output = output
outputs += [standard_prediction_output.detach().cpu()]
outputs = torch.cat(outputs, 1)
if additional_outputs:
for k in additional_outputs:
additional_outputs[k] = torch.cat(additional_outputs[k], dim=1)[
:, torch.argsort(implied_permutation), :
]
    return outputs[:, torch.argsort(implied_permutation), :], additional_outputs

The full embedding pipeline works like this:

1. `get_embeddings(X)` calls `predict_full(X, ..., get_additional_outputs=["test_embeddings"])` and returns the `"test_embeddings"` entry of the result.
2. `predict_full` runs `predict_common_setup` and forwards the request to `transformer_predict`, merging the standard predictions (`"proba"`) and the additional outputs into one dict.
3. `transformer_predict` builds the ensemble configurations, preprocesses the data, and delegates the forward passes to `_batch_predict`.
4. `_batch_predict` runs the transformer on each batch; when additional outputs are requested, it collects `output[k]` for every requested key alongside the standard predictions, concatenates them along dimension 1, and undoes the input permutation with `torch.argsort(implied_permutation)` (see the small sketch below).
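As a quick aside on the last step, here is a tiny standalone sketch of why indexing with `torch.argsort(implied_permutation)` restores the original order along the permuted dimension. The tensor and shapes are made up for illustration and are not library code:

```python
import torch

# Toy stand-in for the concatenated outputs: (rows, permuted axis, features).
x = torch.randn(7, 4, 3)
perm = torch.randperm(x.shape[1])               # plays the role of implied_permutation
permuted = x[:, perm, :]                        # order in which the forward passes produced results
restored = permuted[:, torch.argsort(perm), :]  # same indexing pattern as in _batch_predict
assert torch.equal(restored, x)                 # argsort(perm) is the inverse permutation
```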
Key points:

- There is no dedicated embedding code path: embeddings are requested through the generic `get_additional_outputs` mechanism under the `"test_embeddings"` key.
- When additional outputs are requested, the model's forward pass returns a dict with a `"standard"` entry for the usual predictions plus one entry per requested key.
- The additional outputs are moved to CPU, concatenated across batches, and reordered exactly like the standard predictions, so they stay aligned with them.
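For reference, the pre-refactor API could be exercised roughly like this. Everything outside the `get_embeddings()` call is an assumption made to keep the sketch self-contained (class name, constructor arguments, dummy data); only the routing of `get_embeddings()` through `predict_full(..., get_additional_outputs=["test_embeddings"])` is taken from the code above:

```python
import numpy as np
import torch
from tabpfn import TabPFNClassifier  # assumed import path for the pre-refactor release

# Dummy data just to make the sketch self-contained.
rng = np.random.default_rng(0)
X_train = rng.random((100, 5)).astype(np.float32)
y_train = rng.integers(0, 2, size=100)
X_test = rng.random((20, 5)).astype(np.float32)

model = TabPFNClassifier(device="cpu")
model.fit(X_train, y_train)

# Embeddings were served by the generic additional-outputs path:
# get_embeddings() -> predict_full(..., get_additional_outputs=["test_embeddings"])
embeddings = model.get_embeddings(torch.as_tensor(X_test))
print(embeddings.shape)  # one embedding per test row (possibly per ensemble member)
```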
The embedding functionality was removed during the recent code refactor. Previously, the unsupervised model had a `get_embeddings()` method (shown above) that hooked into the model's embedding mechanism.
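One lightweight way to bring this back on top of the refactored code would be a thin wrapper like the sketch below. This is only a proposal, not the existing API: it assumes the refactored estimator still exposes a `predict_full()`-style entry point that accepts `get_additional_outputs`, as in the pre-refactor code quoted above.

```python
import torch

def get_embeddings(estimator, X: torch.Tensor, additional_y: dict | None = None) -> torch.Tensor:
    """Proposed wrapper restoring the removed embedding access.

    Assumes `estimator.predict_full()` still accepts `get_additional_outputs`
    and returns the requested keys alongside "proba", as before the refactor.
    """
    result = estimator.predict_full(
        X,
        additional_y=additional_y,
        get_additional_outputs=["test_embeddings"],
    )
    return result["test_embeddings"]
```

Exposing the same thing as a method on the estimator itself would let old `get_embeddings()` call sites keep working unchanged.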