-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain_config.yaml
47 lines (37 loc) · 1.33 KB
/
train_config.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Task to run.
# Options: ae, contrastive_lecun_classes, ae_contrastive_lecun_classes,
#          ae_triplet_classes, ae_triplet_hierarchical_classes
# task: contrastive_lecun_classes
task: ae_contrastive_lecun_classes

# Model name (e.g. IncrFeatStridedConvFCUpsampReflectPadAE)
model_name: IncrFeatStridedConvFCUpsampReflectPadAE

# Input locations; explicitly null (unset) in this file.
# NOTE(review): presumably a tractogram (.trk) path and a pickled dataset
# path supplied elsewhere (e.g. on the command line) - confirm.
trk_path: null
pickle_path: null
# NOTE(review): meaning of `rbx_classes` is not documented here - confirm.
rbx_classes: false

# Dataset name (e.g. `fibercup`, `ismrm2015_phantom`, etc.)
dataset_name: mni_brain

# Number of points for streamline resampling
num_points: 256

# Randomly apply a flip to streamlines
random_flip: true

# Normalize streamlines to the volume's isocenter and volume
normalize: false

# Training parameters
# `batch_size` is not used for contrastive learning; see
# `contrastive_num_pairs` instead.
batch_size: 128
num_steps_per_train_epoch: 100
num_steps_per_valid_epoch: 100
# E.g. if this value is 4, there will be 4 positive pairs, 8 pairs total,
# giving a batch size of 16.
contrastive_num_pairs: 96
contrastive_margin: 1.25
# Use 4000 for triplet and 400 for contrastive.
contrastive_loss_weight: 400
# Only used with triplet loss (options: l2, cosine_similarity).
distance_function: "l2"
# Only used with triplet loss.
to_swap: false
epochs: 3
# NOTE(review): presumably the number of samples kept in memory at once;
# unit is not stated in this file - confirm.
data_in_memory: 204800

# Auto-encoder common parameters
latent_space_dims: 32

# Log interval
log_interval: 10

# Weights filename; explicitly null (unset) in this file.
weights: null

# Visualization
viz: false
viz_num_batches: 10

# NOTE(review): presumably the number of data-loading workers - confirm.
num_workers: 24