-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjhu_process_data.py
93 lines (83 loc) · 3.07 KB
/
jhu_process_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Paths (relative to the current working directory) of the time-series CSV
# files stored under ./COVID-19/csse_covid_19_data/csse_covid_19_time_series/.

# "global" files: confirmed cases and deaths.
GLOBALCONFIRMED = (
    './COVID-19/csse_covid_19_data/csse_covid_19_time_series/'
    'time_series_covid19_confirmed_global.csv'
)
GLOBALDEATHS = (
    './COVID-19/csse_covid_19_data/csse_covid_19_time_series/'
    'time_series_covid19_deaths_global.csv'
)

# "19-covid" files: confirmed, deaths and recovered.
# NOTE(review): the DEPR prefix presumably means "deprecated" -- confirm.
DEPRCONFIRMED = (
    './COVID-19/csse_covid_19_data/csse_covid_19_time_series/'
    'time_series_19-covid-Confirmed.csv'
)
DEPRDEATHS = (
    './COVID-19/csse_covid_19_data/csse_covid_19_time_series/'
    'time_series_19-covid-Deaths.csv'
)
DEPRRECOVERED = (
    './COVID-19/csse_covid_19_data/csse_covid_19_time_series/'
    'time_series_19-covid-Recovered.csv'
)
import csv
import pdb
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
# takes in a list of curves and labels and plots them
def plot_curves(curves,colors=None,labels=None,fig=None,log=True):
    """Draw every curve in ``curves`` as a line plot on one figure.

    Args:
        curves: sequence of 1-D array-likes, one per line to draw.
        colors: optional sequence with one color per curve. ``None`` (or a
            ``None`` entry) leaves the color choice to the plotting library.
        labels: optional sequence with one legend label per curve.
        fig: optional figure identifier handed to ``plt.figure``; when
            ``None`` the current figure is used.
        log: when True, plot ``log10`` of each curve instead of raw values.

    Raises:
        IndexError: if ``colors`` or ``labels`` is shorter than ``curves``.
    """
    if labels is None:
        labels = [None] * len(curves)
    if colors is None:
        colors = [None] * len(curves)
    if fig is not None:
        plt.figure(fig)
    for i, curve in enumerate(curves):
        if log:
            curve = np.log10(curve)
        # The per-curve color must go through ``color``: seaborn's ``hue`` is
        # a grouping variable, not a color, so colors given via ``hue`` were
        # never applied to the line. Omit the keyword when no color is given.
        line_kwargs = {} if colors[i] is None else {"color": colors[i]}
        sns.lineplot(data=curve, label=labels[i], **line_kwargs)
    plt.xlabel("time (days)")
    plt.ylabel("#")
def to_daily(data):
    """Return the first-order differences between consecutive entries.

    For a cumulative series this yields the per-step increments; the result
    is one element shorter than ``data`` along the last axis.
    """
    increments = np.diff(data, n=1)
    return increments
def naive_est(recd,dead,conf):
    """Naive fatality-ratio estimate: deaths divided by confirmed cases.

    Only entries where ``conf`` exceeds 20 are kept, so the result may be
    shorter than the inputs.

    Args:
        recd: unused by this estimator; accepted so the estimators in this
            file share one call signature.
        dead: array-like of death counts.
        conf: array-like of confirmed-case counts, aligned with ``dead``.

    Returns:
        1-D array of ``dead / conf`` at the kept entries.
    """
    dead = np.asarray(dead)
    conf = np.asarray(conf)
    # The inputs are deliberately left untouched: overwriting zeros in
    # ``conf`` would corrupt the caller's array and cannot change the result,
    # because the > 20 filter already excludes every zero entry.
    keep = np.nonzero(conf > 20)
    return dead[keep] / conf[keep]
def resolved_est(recd,dead,conf):
    """Fatality-ratio estimate over resolved cases: dead / (recd + dead).

    Entries are kept only where ``conf`` exceeds 20. A resolved count of
    zero is replaced by one before dividing, so such entries yield 0 rather
    than a division by zero.
    """
    resolved = recd + dead
    resolved[resolved == 0] = 1  # avoid 0/0 where nothing has resolved yet
    keep = np.nonzero(conf > 20)
    return dead[keep] / resolved[keep]
def lagged_est(recd,dead,conf,T=5):
    """Lagged fatality-ratio estimate: deaths at step t over cases at t - T.

    Entries are first restricted to those where ``conf`` exceeds 20. Each
    remaining death count is divided by the confirmed count ``T`` steps
    earlier, and the result is left-padded with ``T`` zeros.

    Args:
        recd: unused by this estimator; accepted so the estimators in this
            file share one call signature.
        dead: array-like of death counts.
        conf: array-like of confirmed-case counts, aligned with ``dead``.
        T: non-negative integer lag, in array steps.

    Returns:
        1-D array: ``T`` zeros followed by the lagged ratios.

    Raises:
        ValueError: if ``T`` is negative.
    """
    if T < 0:
        raise ValueError("T must be non-negative")
    dead = np.asarray(dead)
    conf = np.asarray(conf)
    # The caller's ``conf`` is not modified: zero entries never survive the
    # > 20 filter, so rewriting them in place would only corrupt the input.
    keep = np.nonzero(conf > 20)
    dead = dead[keep]
    conf = conf[keep]
    # Slice by explicit length instead of ``conf[:-T]``, which is empty when
    # T == 0 and would make a zero lag unusable.
    n_pairs = max(len(conf) - T, 0)
    est = dead[T:] / conf[:n_pairs]
    return np.concatenate((np.zeros(T, dtype=est.dtype), est))
def plot_estimator(estimate, fig=None,label=None):
    """Draw one estimate series as a line plot with a "CFR" y-axis label.

    Args:
        estimate: data handed to ``sns.lineplot`` as ``data``.
        fig: optional figure identifier handed to ``plt.figure``; when
            ``None`` the current figure is used.
        label: optional legend label for the line.
    """
    # Identity test, so only a genuinely omitted ``fig`` skips figure selection.
    if fig is not None:
        plt.figure(fig)
    sns.lineplot(data=estimate, label=label)
    plt.xlabel("time (days)")
    plt.ylabel("CFR")
def csv2dict(data_file,colcut=4):
    """Load a time-series CSV into a nested dictionary of numeric arrays.

    The first CSV row is treated as the header. In every later row, column 0
    is the state/province name, column 1 is the country name, and the columns
    from ``colcut`` onward are the numeric series.

    Args:
        data_file: path of the CSV file to read.
        colcut: number of leading columns to drop before the numeric data.

    Returns:
        A dict with:
          * ``'Dates'``: array of the header cells from ``colcut`` onward
            (absent when the file is empty);
          * one entry per country, itself a dict mapping state name to a
            float array. Rows with an empty state name are stored under the
            country name again.
    """
    lab_data = {}
    with open(data_file, newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter=',')
        header = next(reader, None)
        if header is None:
            # Empty file: nothing to label and nothing to load.
            return lab_data
        lab_data['Dates'] = np.array(header)[colcut:]
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; they carry no data
                # and would otherwise fail on the row[1] lookup below.
                continue
            state_str = row[0]
            country_str = row[1]
            if state_str == '':
                state_str = country_str
            # Sometimes the csv has missing cells; count them as zero.
            cells = ['0' if cell == '' else cell for cell in row[colcut:]]
            arr = np.array(cells).astype(np.double)
            lab_data.setdefault(country_str, {})[state_str] = arr
    return lab_data
if __name__ == "__main__":
    # Load every time series. In each CSV a row is one region and the
    # columns after the leading label columns are time points.
    # C / D: "global" confirmed and deaths.
    # C_D / D_D / R_D: "19-covid" confirmed, deaths and recovered.
    C, D, C_D, D_D, R_D = map(
        csv2dict,
        (GLOBALCONFIRMED, GLOBALDEATHS, DEPRCONFIRMED, DEPRDEATHS, DEPRRECOVERED),
    )