Skip to content

Commit

Permalink
Documentation updates
Browse files Browse the repository at this point in the history
  • Loading branch information
jpwchang committed May 20, 2024
1 parent d2e9f51 commit 6d5ac17
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 8 deletions.
8 changes: 4 additions & 4 deletions convokit/forecaster/CRAFT/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ def train(input_variable, dialog_lengths, dialog_lengths_list, utt_lengths, batc
return loss.item()

def evaluateBatch(encoder, context_encoder, predictor, voc, input_batch, dialog_lengths,
dialog_lengths_list, utt_lengths, batch_indices, dialog_indices, batch_size, device, max_length):
dialog_lengths_list, utt_lengths, batch_indices, dialog_indices, batch_size, device, max_length, threshold=0.5):
# Set device options
input_batch = input_batch.to(device)
dialog_lengths = dialog_lengths.to(device)
utt_lengths = utt_lengths.to(device)
# Predict future attack using predictor
scores = predictor(input_batch, dialog_lengths, dialog_lengths_list, utt_lengths, batch_indices, dialog_indices, batch_size, max_length)
predictions = (scores > 0.5).float()
predictions = (scores > threshold).float()
return predictions, scores

def validate(dataset, encoder, context_encoder, predictor, voc, batch_size, device, max_length, batch_iterator_func):
Expand Down Expand Up @@ -211,7 +211,7 @@ def trainIters(voc, pairs, val_pairs, encoder, context_encoder, attack_clf,

return best_model

def evaluateDataset(dataset, encoder, context_encoder, predictor, voc, batch_size, device, max_length, batch_iterator_func, pred_col_name, score_col_name):
def evaluateDataset(dataset, encoder, context_encoder, predictor, voc, batch_size, device, max_length, batch_iterator_func, threshold, pred_col_name, score_col_name):
# create a batch iterator for the given data
batch_iterator = batch_iterator_func(voc, dataset, batch_size, shuffle=False)
# find out how many iterations we will need to cover the whole dataset
Expand All @@ -229,7 +229,7 @@ def evaluateDataset(dataset, encoder, context_encoder, predictor, voc, batch_siz
# run the model
predictions, scores = evaluateBatch(encoder, context_encoder, predictor, voc, input_variable,
dialog_lengths, dialog_lengths_list, utt_lengths, batch_indices, dialog_indices,
true_batch_size, device, max_length)
true_batch_size, device, max_length, threshold)

# format the output as a dataframe (which we can later re-join with the corpus)
for i in range(true_batch_size):
Expand Down
23 changes: 23 additions & 0 deletions convokit/forecaster/CRAFTModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ class CRAFTModel(ForecasterModel):
(craft-wiki-pretrained and craft-cmv-pretrained, respectively), which provide trained versions of the
underlying utterance and conversation encoder layers but leave the classification layers at their
random initializations so that they can be fitted to your data.
:param initial_weights: Specifies where to find the saved model to be loaded to initialize CRAFT. To use ConvoKit's provided models, use "craft-wiki-pretrained" for the model pretrained on Wikipedia data, or "craft-wiki-finetuned" for the model already fine-tuned on CGA-WIKI. Replace "wiki" with "cmv" for the Reddit CMV equivalents. Alternatively, if you have a custom model you want to use, you can pass in the full path to the saved PyTorch checkpoint file.
:param vocab_index2word: File containing the mapping from vocabulary index to raw string tokens. If you are using a provided model, you MUST leave this as the default value of "auto" (other values will be ignored and overridden to "auto"). Conversely, if using a custom model, you CANNOT leave this as "auto" and you must provide a full path to the vocabulary file that you made for your custom model.
:param vocab_word2index: File containing the mapping from raw string tokens to vocabulary index. If you are using a provided model, you MUST leave this as the default value of "auto" (other values will be ignored and overridden to "auto"). Conversely, if using a custom model, you CANNOT leave this as "auto" and you must provide a full path to the vocabulary file that you made for your custom model.
:param decision_threshold: Output probability beyond which a forecast should be considered "positive"/"True". Highly recommended to leave this at auto, which will use published values for the provided models, or 0.5 for custom models.
:param torch_device: "cpu" or "cuda" (for GPUs). If you have access to a GPU it is strongly recommended to set this to "cuda"; the default is "cpu" only for compatibility with non-GPU setups.
:param config: Allows overwriting of CRAFT hyperparameters. Strongly recommended to keep this at default unless you know what you're doing!
"""

def __init__(
Expand Down Expand Up @@ -170,6 +177,12 @@ def _init_craft(self):
return embedding, encoder, context_encoder, attack_clf

def fit(self, contexts, val_contexts=None):
"""
Fine-tune the CRAFT model, and save the best model according to validation performance.
:param contexts: an iterator over context tuples, provided by the Forecaster framework
:param val_contexts: an iterator over context tuples to be used only for validation. IMPORTANT: this is marked Optional only for compatibility with the generic Forecaster API; CRAFT actually REQUIRES a validation set so leaving this parameter at None will raise an error!
"""
# convert the input contexts into CRAFT's data format
train_pairs = self._context_to_craft_data(contexts)
print("Processed", len(train_pairs), "context tuples for model training")
Expand Down Expand Up @@ -209,6 +222,15 @@ def fit(self, contexts, val_contexts=None):
self._model = best_model

def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name):
"""
Run a fine-tuned CRAFT model on the provided data
:param contexts: context tuples from the Forecaster framework
:param forecast_attribute_name: Forecaster will use this to look up the table column containing your model's discretized predictions (see output specification below)
:param forecast_prob_attribute_name: Forecaster will use this to look up the table column containing your model's raw forecast probabilities (see output specification below)
:return: a Pandas DataFrame, with one row for each context, indexed by the ID of that context's current utterance. Contains two columns, one with raw probabilities named according to forecast_prob_attribute_name, and one with discretized (binary) forecasts named according to forecast_attribute_name
"""
# convert the input contexts into CRAFT's data format
test_pairs = self._context_to_craft_data(contexts)
print("Processed", len(test_pairs), "context tuples for model evaluation")
Expand All @@ -235,6 +257,7 @@ def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_n
self._device,
MAX_LENGTH,
batchIterator,
self._decision_threshold,
forecast_attribute_name,
forecast_prob_attribute_name
)
Expand Down
2 changes: 2 additions & 0 deletions convokit/forecaster/forecasterModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,7 @@ def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_n
in the form of a DataFrame indexed by (current) utterance ID
:param contexts: an iterator over context tuples
:return: a Pandas DataFrame, with one row for each context, indexed by the ID of that context's current utterance. Contains two columns, one with raw probabilities named according to forecast_prob_attribute_name, and one with discretized (binary) forecasts named according to forecast_attribute_name. Subclass implementations of ForecasterModel MUST adhere to this return value specification!
"""
pass
44 changes: 40 additions & 4 deletions examples/forecaster/CRAFT Forecaster demo.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a76caad0-a440-43cf-bfcd-af0ad6d68042",
"metadata": {},
"source": [
"# ConvoKit Forecaster framework: CRAFT demo\n",
"\n",
"The `Forecaster` class provides a generic interface to *conversational forecasting models*, a class of models designed to computationally capture the trajectory of conversations in order to predict future events. Though individual conversational forecasting models can get quite complex, the `Forecaster` API abstracts away the implementation details into a standard fit-transform interface. To demonstrate the power of this framework, this notebook walks through an example of fine-tuning the CRAFT conversational forecasting model (Chang and Danescu-Niculescu-Mizil, 2019) on the CGA-CMV corpus. You will see how the `Forecaster` API allows us to load the data, select training, validation, and testing samples, train the CRAFT model, and perform evaluation - replicating the original paper's full pipeline (minus pre-training, which is considered outside the scope of ConvoKit) all in only a few lines of code!\n",
"\n",
"Let's start by importing the necessary ConvoKit classes and functions, and loading the CGA-CMV corpus."
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand Down Expand Up @@ -34,7 +46,15 @@
"id": "a4d27b1a-3d1f-4039-b10f-21b6dd230c10",
"metadata": {},
"source": [
"## Define selectors and filters for the Forecaster"
"## Define selectors for the Forecaster\n",
"\n",
"Core to the flexibility of the `Forecaster` framework is the concept of *selectors*. \n",
"\n",
"To capture the temporal dimension of the conversational forecasting task, `Forecaster` iterates through conversations in chronological utterance order, at each step presenting to the backend forecasting model a \"context tuple\" containing both the comment itself and the full \"context\" preceding that comment. As a general framework, `Forecaster` on its own does not try to make any further assumptions about what \"context\" should contain or look like; it simply presents context as a chronologically ordered list of all utterances up to and including the current one. \n",
"\n",
"But in practice, we often want to be pickier about what we mean by \"context\". At a basic level, we might want to select only specific contexts during training versus during evaluation. The simplest version of this is the desire to split the conversations by training and testing splits, but more specifically, we might also want to select only certain contexts within a conversation. This is necessary for CRAFT training, which works by taking only the chronologically last context (i.e., all utterances up to and not including the toxic comment, or up to the end of the conversation) as a labeled training instance. This is where selectors come in! A selector is a user-provided function that takes in a context and returns a boolean representing whether or not that context should be used. You can provide separate selectors for `fit` and `transform`, and `fit` also takes in a second selector that you can use to define validation data.\n",
"\n",
"Here we show how to implement the necessary selectors for CRAFT."
]
},
{
Expand Down Expand Up @@ -69,7 +89,11 @@
"id": "9614aff5-843e-4b3b-b03f-6f57f8e76b8a",
"metadata": {},
"source": [
"## Initialize the Forecaster and CRAFTModel backend"
"## Initialize the Forecaster and CRAFTModel backend\n",
"\n",
"Now the rest of the process is pretty straightforward! We simply need to:\n",
"1. Initialize a backend `ForecasterModel` for the `Forecaster` to use, in this case we choose ConvoKit's implementation of CRAFT.\n",
"2. Initialize a `Forecaster` instance to wrap that `ForecasterModel` in a generic fit-transform API"
]
},
{
Expand Down Expand Up @@ -108,7 +132,9 @@
"id": "b840f526-dafd-4022-b5a1-90adecbd1591",
"metadata": {},
"source": [
"## Fine-tune the model using Forecaster.fit"
"## Fine-tune the model using Forecaster.fit\n",
"\n",
"And now, just like any other ConvoKit Transformer, model training is done simply by calling `fit` (note how we pass in the selectors we previously defined!)..."
]
},
{
Expand Down Expand Up @@ -1081,7 +1107,9 @@
"id": "3238e632-4e1b-4ba1-938a-b9aeef07f0cf",
"metadata": {},
"source": [
"## Run the fitted model on the test set and perform evaluation"
"## Run the fitted model on the test set and perform evaluation\n",
"\n",
"...and inference is done simply by calling `transform`! (again, note the selector)"
]
},
{
Expand Down Expand Up @@ -1218,6 +1246,14 @@
"corpus = craft_forecaster.transform(corpus, transform_selector)"
]
},
{
"cell_type": "markdown",
"id": "9f517cfb-3a1d-4279-bf09-96695938d30c",
"metadata": {},
"source": [
"Finally, to get a human-readable interpretation of model performance, we can use `summarize` to generate a table of standard performance metrics. It also returns a table of conversation-level predictions in case you want to do more complex analysis!"
]
},
{
"cell_type": "code",
"execution_count": 8,
Expand Down

0 comments on commit 6d5ac17

Please sign in to comment.