-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathray_app.py
63 lines (49 loc) · 2.12 KB
/
ray_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from collections import Counter
import socket
import time
import ray
import argparse
# ---------------------------------------------------------------------------
# Command-line interface.
#
# Parsing happens at module level, so the constants defined at the bottom of
# this section are available to the script body that follows.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Set parameters')

# Connection settings. The port stays a string because it is joined with the
# host IP into a single "host:port" address.
parser.add_argument("--port", type=str, default='1297',
                    help="Port number used by ray clusters")
parser.add_argument("--hostip", type=str, default='127.0.0.1',
                    help="Host IP in ray cluster")
# NOTE(review): a password is hard-coded as the default value here. Confirm it
# is not a real credential; if it is, supply it via the command line or the
# environment instead of keeping it in source control.
parser.add_argument("--rpass", type=str, default='9673532156',
                    help="Redis password")

# Run settings. LOCAL_DIR, MODELNAME, DATAPATH and CPUS_PER_TRIAL are parsed
# here but are not referenced by the visible part of this script.
# NOTE(review): the --localdir default looks like a script name rather than a
# directory -- confirm the intended value.
parser.add_argument("--localdir", type=str, default='run_raytune.py',
                    help="Path to run directory")
# The defaults below are the literal string 'None', not the None object.
parser.add_argument("--modelname", type=str, default='None',
                    help="name of model")
parser.add_argument("--datapath", type=str, default='None',
                    help="Path to data file")

# Numeric options are converted by argparse itself, so an invalid value is
# reported as a usage error instead of an unhandled ValueError traceback.
parser.add_argument("--cpus_per_trial", type=float, default=1.0,
                    help="Number of cpus to utilize per trial")
parser.add_argument("--dashport", type=int, default=8265,
                    help="Port used for the dashboard")

parser.add_argument("-L", "--runlocal", action='store_true', default=False,
                    help="Run ray locally")

args = parser.parse_args()

# Module-level constants derived from the parsed arguments.
PORT = args.port                        # str
HOST_IP = args.hostip                   # str
RPASS = args.rpass                      # str
LOCAL_DIR = args.localdir               # str
MODELNAME = args.modelname              # str
DATAPATH = args.datapath                # str
RUNLOCAL = args.runlocal                # bool
CPUS_PER_TRIAL = args.cpus_per_trial    # float
DASHPORT = args.dashport                # int
if __name__ == '__main__':
    # Number of short tasks submitted to the cluster.
    NUM_TASKS = 100000

    if RUNLOCAL:
        # Start Ray in this process with an explicit object store size
        # (10**9; the ray.init documentation gives this value in bytes) and
        # the requested dashboard port.
        ray.init(object_store_memory=10**9, dashboard_port=DASHPORT)
    else:
        # Connect to an already running cluster at "host:port".
        # NOTE(review): `_redis_password` is an underscore-prefixed keyword;
        # confirm it is accepted by the installed Ray version.
        ray.init(address=":".join([HOST_IP, PORT]), _redis_password=RPASS)

    try:
        print(('This cluster consists of\n'
               '{} nodes in total\n'
               '{} CPU resources in total\n'
               ).format(len(ray.nodes()), ray.cluster_resources()['CPU']))

        start = time.time()

        @ray.remote
        def f():
            """Sleep briefly, then return the IP address of the executing host."""
            time.sleep(0.001)
            # Return IP address.
            return socket.gethostbyname(socket.gethostname())

        # Submit every task first, then block until all results are available.
        object_ids = [f.remote() for _ in range(NUM_TASKS)]
        ip_addresses = ray.get(object_ids)
        end = time.time()

        # Report how many tasks each IP address executed.
        print('Tasks executed')
        for ip_address, num_tasks in Counter(ip_addresses).items():
            print(' {} tasks on {}'.format(num_tasks, ip_address))
        print('Total elapsed time was: {} seconds'.format(end - start))
    finally:
        # Shut down the locally started instance even when the run above
        # fails. As in the original script, a remote-cluster connection is
        # not explicitly shut down.
        if RUNLOCAL:
            ray.shutdown()