-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathperplexity.h
139 lines (113 loc) · 4.79 KB
/
perplexity.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#pragma once
// In-place softmax over x[0..size):
//   x[i] = expf(x[i] - max) / sum_j expf(x[j] - max)
// Subtracting the maximum before exponentiating keeps expf from overflowing
// for large inputs without changing the result.
// Does nothing when size <= 0 (there is no x[0] to read).
void softmax(float* x, int size) {
    if (size <= 0) {
        return;
    }
    // find max value (for numerical stability)
    float max_val = x[0];
    for (int i = 1; i < size; i++) {
        if (x[i] > max_val) {
            max_val = x[i];
        }
    }
    // exp and sum
    float sum = 0.0f;
    for (int i = 0; i < size; i++) {
        x[i] = expf(x[i] - max_val);
        sum += x[i];
    }
    // normalize
    for (int i = 0; i < size; i++) {
        x[i] /= sum;
    }
}
// Computes the perplexity of a token sequence from the model's raw logits:
//   PPL = exp( -(1/N) * sum_i log softmax(logits row i)[tokens[i]] )
// tokens:     N target token ids; tokens[i] is the token that logits row i is
//             scored against.
// logits:     N rows of vocab_size pre-softmax scores, stored row-major
//             (row i starts at logits[i * vocab_size]). Only read, never modified.
// num_tokens: N, the number of rows / targets to score.
// vocab_size: width of each logits row.
// Returns the perplexity, or NaN when num_tokens <= 0 (mean over zero tokens).
// NOTE(review): tokens[i] is assumed to lie in [0, vocab_size); it is not range-checked.
float compute_perplexity(int* tokens, float* logits, int num_tokens, int vocab_size) {
    if (num_tokens <= 0) {
        // the average log-probability over zero tokens is undefined
        return NAN;
    }
    // accumulate in double: summing many small negative log-probs in float loses precision
    double sum_log_prob = 0.0;
    for (int i = 0; i < num_tokens; i++) {
        // size_t offset: i * vocab_size can exceed INT_MAX for long sequences
        // with large vocabularies
        const float* row = logits + (size_t)i * (size_t)vocab_size;
        // Log-softmax via log-sum-exp, shifted by the row max for stability.
        // Staying in log space avoids log(0) = -inf when the target's float
        // probability would underflow to zero after a plain softmax.
        float max_val = row[0];
        for (int j = 1; j < vocab_size; j++) {
            if (row[j] > max_val) {
                max_val = row[j];
            }
        }
        double sum_exp = 0.0;
        for (int j = 0; j < vocab_size; j++) {
            sum_exp += expf(row[j] - max_val);
        }
        // sum_exp >= 1 (the max element contributes expf(0) = 1), so log() is finite
        double log_prob = ((double)row[tokens[i]] - (double)max_val) - log(sum_exp);
        sum_log_prob += log_prob;
    }
    // perplexity is the exponential of the negative average log-probability
    double avg_log_prob = sum_log_prob / num_tokens;
    return float(exp(-avg_log_prob));
}
void run_transformer(bool gen_token, Config* p, RunState* s, TransformerWeights* w, bool copyLogits, Sampler* pSampler);
// ----------------------------------------------------------------------------
float get_dataset_perplexity(char* dataset, Tokenizer* tokenizer, Config* config, RunState* state, TransformerWeights* weights, Sampler *pSampler) {
int bytes = strlen(dataset);
int* datasetTokens = &(state->shared_data->tokens[1]);
printf("\nTokenizing Dataset...");
int totalTokens;
encode(tokenizer, dataset, 0, 0, datasetTokens, &totalTokens);
printf("done!\n");
printf("Found %d characters, %d tokens", bytes, totalTokens);
int numTokens = totalTokens;
if (numTokens >= config->seq_len) {
numTokens = config->seq_len - 1;
printf("\nTruncated to %d tokens", numTokens);
}
printf("\nRunning the network to get logits...");
// run the transformer model to get logits
cudaMemset(state->pos, 0, sizeof(int));
state->shared_data->pos = 0;
state->shared_data->tokens[0] = bos_token;
for (int pos = 0; pos < numTokens; pos++) {
run_transformer(false, config, state, weights, true, pSampler);
cudaDeviceSynchronize();
}
printf("done!\n");
printf("Computing perplexity...");
// copy the logits and compute preplexity
float* logits_arr = (float*)malloc(numTokens * config->vocab_size * sizeof(float));
cudaMemcpy(logits_arr, state->logits_array, numTokens * config->vocab_size * sizeof(float), cudaMemcpyDeviceToHost);
float pplx = compute_perplexity(datasetTokens, logits_arr, numTokens, config->vocab_size);
printf("\nPerplexity computed on %d tokens: %f\n\n", numTokens, pplx);
free(logits_arr);
return pplx;
}
// Loads a text file, splits it on "<|endoftext|>" into independent sequences,
// evaluates each with get_dataset_perplexity() and prints the geometric mean of
// the per-sequence perplexities.
// textFileName: path of the dataset file (read as raw bytes).
// Remaining parameters are forwarded unchanged to get_dataset_perplexity().
// Sequences whose perplexity is not finite are excluded from the mean and
// reported. One example is a sequence that tokenizes to zero tokens, for which
// the mean log-probability is 0/0 = NaN.
void parseDataSetAndComputePreplexity(char* textFileName, Tokenizer* tokenizer, Config* config, RunState* state, TransformerWeights* weights, Sampler *pSampler)
{
    static const char kSeparator[] = "<|endoftext|>";
    const size_t kSeparatorLen = sizeof(kSeparator) - 1;

    // "rb": read-only access is sufficient ("rb+" also required write permission)
    FILE* fp = fopen(textFileName, "rb");
    printf("\nLoading Dataset...");
    if (!fp) {
        fprintf(stderr, "\nFailed to open dataset file: %s\n", textFileName);
        return;
    }
    // find the number of bytes in the file
    fseek(fp, 0, SEEK_END);
    long fileSize = ftell(fp);
    fseek(fp, 0, SEEK_SET);
    if (fileSize < 0) {
        fprintf(stderr, "\nFailed to determine size of dataset file: %s\n", textFileName);
        fclose(fp);
        return;
    }
    char* dataset = (char*)malloc((size_t)fileSize + 1);
    if (!dataset) {
        fprintf(stderr, "\nFailed to allocate %ld bytes for the dataset\n", fileSize + 1);
        fclose(fp);
        return;
    }
    // terminate at the number of bytes actually read so a short read never
    // exposes uninitialized memory
    size_t bytesRead = fread(dataset, 1, (size_t)fileSize, fp);
    fclose(fp);
    printf("done!\n");
    dataset[bytesRead] = 0;

    int count = 0;
    int skipped = 0;
    // Sum of log-perplexities instead of a running product: the product of
    // many perplexities overflows double after a few hundred sequences.
    double logPplxSum = 0.0;
    char* currentSeq = dataset;
    for (;;) {
        char* separator = strstr(currentSeq, kSeparator);
        if (separator) {
            *separator = 0;  // terminate the current sequence at the delimiter
        }
        float pplx = get_dataset_perplexity(currentSeq, tokenizer, config, state, weights, pSampler);
        // NaN fails both comparisons, so this keeps only finite, positive values
        if (pplx > 0.0f && pplx < INFINITY) {
            logPplxSum += log((double)pplx);
            count++;
        } else {
            skipped++;
        }
        if (!separator) {
            break;
        }
        currentSeq = separator + kSeparatorLen;
    }
    free(dataset);
    if (skipped > 0) {
        printf("\nSkipped %d sequence(s) with a non-finite perplexity\n", skipped);
    }
    if (count == 0) {
        printf("\nNo sequences with a finite perplexity\n\n");
        return;
    }
    printf("\nGeomean perplexity on %d sequences: %f\n\n", count, exp(logPplxSum / count));
}