-
Notifications
You must be signed in to change notification settings - Fork 22
/
swarm_run.py
31 lines (21 loc) · 941 Bytes
/
swarm_run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""Train a SwarmCharTransformer on data/enwik8 with swarm_jax on a Ray runtime.

Run as a script (``python swarm_run.py``). Importing this module only sets the
XLA/JAX environment variables below and defines ``main``; it does not train.
"""
import os

# XLA/JAX read these variables when JAX is first imported, so they are set at
# module level, before ``main`` imports anything that may load JAX
# (swarm_jax presumably does -- TODO confirm).
# Machine-specific: points XLA at a CUDA 10.1 installation.
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/opt/cuda-10.1"
# Allocate GPU memory on demand instead of preallocating most of the device,
# and free it when no longer needed (slower, but minimal memory footprint).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# NOTE: NaN checking adds noticeable runtime overhead; consider disabling it
# for throughput-sensitive runs.
os.environ["JAX_DEBUG_NANS"] = "True"


def main():
    """Build the dataset, optimizer and swarm, then run training.

    The Ray runtime started here is always shut down, even if setup or
    training raises.
    """
    # Imported here rather than at module top so the environment configured
    # above is guaranteed to be in place before JAX is loaded, and so that
    # importing this module does not pull in the training stack.
    from swarm_jax.swarm_layer import NetworkPrecision
    from loader import TextLoader
    from swarm_jax.model import SwarmCharTransformer
    from swarm_jax.swarm import Swarm
    import ray
    import optax

    # "tpu" is a custom Ray resource. Advertise an arbitrarily large amount of
    # it so anything requesting it can be scheduled on this node.
    ray.init(resources={"tpu": 999})
    try:
        # NOTE(review): semantics of batchsize=(1, 16) are defined by
        # loader.TextLoader -- confirm before changing.
        train_dataset = TextLoader(
            "data/enwik8", batchsize=(1, 16), sample_size=128, length=90000000
        )
        # Clip gradients to a global norm of 0.25, then apply Adam.
        optimizer = optax.chain(
            optax.clip_by_global_norm(0.25),
            optax.adam(2e-4, b1=0.9, b2=0.99, eps=1e-5),
        )
        prec = NetworkPrecision(fwd_act="uint16", rev_act="uint16", grad="uint16")
        # Swarm receives the model class (not an instance) and the bound
        # train_dataset.get_samples, presumably called to fetch batches.
        # NOTE(review): meaning of the positional 2 ** 16 is defined in
        # swarm_jax.swarm.Swarm -- confirm.
        swarm = Swarm(
            SwarmCharTransformer,
            optimizer,
            2 ** 16,
            train_dataset.get_samples,
            prec,
        )
        # NOTE(review): presumably (step count, log dir, checkpoint dir);
        # confirm against Swarm.run.
        swarm.run(100000, "runs/512_30L", "ckpt/512_30L")
    finally:
        ray.shutdown()


if __name__ == "__main__":
    main()