-
Notifications
You must be signed in to change notification settings - Fork 0
/
correlation.py
82 lines (61 loc) · 2.68 KB
/
correlation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
def create_correlation_matrices(
    data,
    environments=("babyai", "babaisai", "crafter", "textworld", "minihack", "nle"),
):
    """Build one environment-by-environment correlation matrix per leaderboard.

    For every leaderboard in ``data["leaderboards"]``, each agent contributes
    one observation: for every environment key present on the agent, the
    first element of ``agent[env]`` (treated here as the mean score). The
    correlation between environments is then computed across agents with
    ``DataFrame.corr`` (pairwise, so agents missing an environment are simply
    excluded from the pairs involving it).

    Args:
        data: Parsed leaderboard JSON. Must contain a ``"leaderboards"`` list
            whose entries have ``"name"`` and ``"results"`` keys; each result
            has a ``"name"`` and, per reported environment, a sequence whose
            first element is the score.
        environments: Environment keys to include. Also fixes the row/column
            order of every returned matrix. Defaults to the six environments
            that were previously hard-coded here.

    Returns:
        dict mapping leaderboard name to a square ``pandas.DataFrame`` of
        correlations, indexed and labelled by environment. Leaderboards with
        fewer than two agents are skipped and a message is printed instead.
    """
    environments = list(environments)
    correlation_matrices = {}

    for leaderboard in data["leaderboards"]:
        lb_name = leaderboard["name"]

        # One row per agent: its name plus the score for every environment
        # it reports.
        agent_performances = []
        for agent in leaderboard["results"]:
            performance = {"agent": agent["name"]}
            for env in environments:
                if env in agent:
                    # Use the mean score (first element of the list).
                    performance[env] = agent[env][0]
            agent_performances.append(performance)

        # A correlation needs at least two observations (agents).
        if len(agent_performances) < 2:
            print(f"Not enough data to compute correlation matrix for {lb_name}.")
            continue

        df = pd.DataFrame(agent_performances).set_index("agent")
        # pandas orders columns by first appearance across the row dicts, so an
        # agent missing an environment would otherwise shuffle the column order.
        # Pin it to ``environments`` so every matrix/heatmap is laid out alike.
        df = df[[env for env in environments if env in df.columns]]

        correlation_matrices[lb_name] = df.corr()

    return correlation_matrices
# Script entry: load leaderboard data, plot one correlation heatmap per
# leaderboard, and report the environment whose scores correlate most, on
# average, with the other environments.

# Load the data from 'template/data.json' (relative to the working directory).
# JSON is UTF-8 by specification, so do not depend on the platform's locale.
with open("template/data.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# Create the correlation matrices
corr_matrices = create_correlation_matrices(data)

# Plot the correlation matrices and find the environment with the highest average correlation
for lb_name in ["LLM", "VLM"]:
    if lb_name not in corr_matrices:
        print(f"No correlation matrix available for {lb_name}.")
        continue

    corr_matrix = corr_matrices[lb_name]
    if corr_matrix.empty:
        # No environment columns had data; seaborn cannot draw an empty heatmap.
        print(f"No correlation matrix available for {lb_name}.")
        continue

    # Plot the correlation matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(corr_matrix, annot=True, cmap="coolwarm", fmt=".2f", linewidths=0.5)
    plt.title(f"Correlation Matrix for {lb_name}")
    plt.tight_layout()
    plt.show()

    # Average correlation per environment, excluding the self-correlation on
    # the diagonal. Drop NaNs: an environment with constant scores, or a
    # matrix with a single environment, yields no usable correlations, and
    # idxmax on an all-NaN Series returns NaN (deprecated / raising in newer
    # pandas), which would then fail as a lookup key below.
    avg_corr = corr_matrix.apply(lambda x: x.drop(labels=x.name).mean(), axis=1).dropna()
    if avg_corr.empty:
        print(
            f"Cannot determine the environment with the highest average correlation in {lb_name}: no finite correlations."
        )
        continue

    # Find the environment with the highest average correlation
    max_env = avg_corr.idxmax()
    print(
        f"The environment with the highest average correlation in {lb_name} is '{max_env}' with an average correlation of {avg_corr[max_env]:.2f}"
    )