-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
111 lines (90 loc) · 2.86 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import numpy as np
import argparse
import csv
from sklearn.cluster import Birch
import matplotlib.pyplot as plt
import pandas as pd
import plotly.plotly
import plotly.graph_objs as go
import seaborn as sns
from typing import Tuple, Dict, List
def load_data(file_name) -> List[List]:
    """Read a comma-separated file and return its data rows.

    The first row is treated as a header: its column names are printed and
    it is not included in the result. Every remaining row is returned as the
    list of string fields produced by ``csv.reader`` (no type conversion).

    Args:
        file_name: path of the CSV file to read.

    Returns:
        A list of rows, each a list of strings. Empty when the file holds
        only a header or nothing at all.
    """
    print("--->Loading csv file")
    # newline="" is required by the csv module so quoted fields containing
    # line breaks are parsed correctly.
    with open(file_name, newline="") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        header = next(csv_reader, None)
        if header is not None:
            print(f'Column names: [{", ".join(header)}]')
        data = list(csv_reader)
    # Report data rows only; the header row is not a record.
    print(f'Loaded {len(data)} records')
    return data
def compute_clusters(
    data: List,
    *,
    n_clusters: int = 5,
    threshold: float = 0.3,
    branching_factor: int = 50,
) -> np.ndarray:
    """Cluster ``data`` with scikit-learn's BIRCH and return the labels.

    Args:
        data: array-like of samples to cluster, one sample per row.
        n_clusters: number of clusters produced by BIRCH's final global
            clustering step (default 5, as before).
        threshold: BIRCH subcluster radius threshold (default 0.3).
        branching_factor: maximum number of CF subclusters per node
            (default 50).

    Returns:
        A NumPy array holding one cluster label per row of ``data``.
    """
    print("--->Computing clusters")
    birch = Birch(
        branching_factor=branching_factor,
        n_clusters=n_clusters,
        threshold=threshold,
        copy=True,
        compute_labels=True,
    )
    # With compute_labels=True, fitting already labels the training samples,
    # so fit_predict returns those labels without a second predict pass.
    return np.asarray(birch.fit_predict(data))
def show_results(data: np.ndarray, labels: np.ndarray, plot_handler: str = "seaborn") -> None:
    """Scatter-plot two-column data coloured by cluster label and show it.

    Args:
        data: array with two columns; column 0 is drawn on the "Income" axis
            and column 1 on the "Spending" axis.
        labels: one cluster label per row of ``data``.
        plot_handler: plotting library to use, ``"seaborn"`` or
            ``"matplotlib"``.

    Raises:
        ValueError: if ``plot_handler`` is not a supported value, or if
            ``labels`` does not have one entry per row of ``data`` (raised by
            ``np.concatenate``).
    """
    supported = ("seaborn", "matplotlib")
    if plot_handler not in supported:
        raise ValueError(
            f"unsupported plot_handler {plot_handler!r}; expected one of {supported}"
        )
    # Append the labels as an extra column so both back-ends read one array.
    data = np.concatenate((data, np.asarray(labels).reshape(-1, 1)), axis=1)
    if plot_handler == "seaborn":
        # legend_out is left at its default (True); newer seaborn versions
        # deprecate passing it directly to lmplot.
        sns.lmplot(
            data=pd.DataFrame(data, columns=["Income", "Spending", "Label"]),
            x="Income",
            y="Spending",
            hue="Label",
            fit_reg=False,
            legend=True,
        )
    else:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        scatter = ax.scatter(data[:, 0], data[:, 1], c=data[:, 2], s=50)
        ax.set_title("Clusters")
        ax.set_xlabel("Income")
        ax.set_ylabel("Spending")
        plt.colorbar(scatter)
    plt.show()
def show_data_corelation(data=None, csv_file_name=None):
    """Print and plot a correlation matrix from in-memory data or a CSV file.

    If ``csv_file_name`` is given, the file is read with pandas, its summary
    statistics are printed, and the correlation between its "Age",
    "Annual Income (k$)" and "Spending Score (1-100)" columns is computed.
    Otherwise the correlation between the columns of ``data`` is computed
    and printed. In both cases the matrix is drawn as a seaborn heatmap and
    shown with matplotlib.

    Args:
        data: 2-D array-like with one observation per row and one variable
            per column; used only when ``csv_file_name`` is None.
        csv_file_name: path of a CSV file to describe.

    Returns:
        The three-column DataFrame read from the CSV file, or None when
        ``data`` was used.

    Raises:
        ValueError: if neither ``data`` nor ``csv_file_name`` is given.
    """
    if csv_file_name is None:
        if data is None:
            raise ValueError("either data or csv_file_name must be given")
        # rowvar=False treats columns as the variables, matching
        # DataFrame.corr() in the CSV branch; numpy's default (rowvar=True)
        # would correlate every pair of rows (observations) instead.
        cor = np.corrcoef(np.asarray(data, dtype=np.float64), rowvar=False)
        print("Corelation matrix:")
        print(cor)
        data_set = None
    else:
        data_set = pd.read_csv(csv_file_name)
        print(data_set.describe())
        data_set = data_set[["Age", "Annual Income (k$)", "Spending Score (1-100)"]]
        cor = data_set.corr()
    sns.heatmap(cor, square=True)
    plt.show()
    return data_set
def main(args) -> None:
    """Run the clustering pipeline configured by the parsed CLI ``args``.

    Loads ``args.data_file``, optionally describes it when ``args.describe``
    is truthy, clusters two of its columns with BIRCH, and plots the result
    with the library named by ``args.plot_handler``.

    Raises:
        ValueError: if the data file contains no data rows.
    """
    data = load_data(args.data_file)
    if args.describe:
        # Called for its printed summary and heatmap; the returned frame is
        # not needed here.
        show_data_corelation(csv_file_name=args.data_file)
    # Columns 3 and 4 are clustered and later plotted as "Income" and
    # "Spending" -- presumably "Annual Income (k$)" and
    # "Spending Score (1-100)" of Mall_Customers.csv; verify if the schema changes.
    filtered_data = np.array([[row[3], row[4]] for row in data], dtype=np.float64)
    if filtered_data.size == 0:
        raise ValueError(f"no data rows found in {args.data_file}")
    labels = compute_clusters(filtered_data)
    show_results(filtered_data, labels, args.plot_handler)
def _str_to_bool(value: str) -> bool:
    """Convert a command-line string such as "true"/"false" to a bool.

    argparse's ``type=bool`` treats every non-empty string -- including
    "False" -- as True, so boolean flag values are parsed explicitly.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(description="Do some clustering")
    parser.add_argument("--data-file", type=str, default="Mall_Customers.csv", help="dataset file name")
    # nargs="?" with const=True: a bare "--describe" enables the flag, while
    # the explicit "--describe True" / "--describe False" forms keep working.
    parser.add_argument(
        "--describe",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="describe the dataset",
    )
    parser.add_argument(
        "--plot-handler",
        type=str,
        default="seaborn",
        choices=["seaborn", "matplotlib"],
        help="what library to use for data visualisation",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)