-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
96 lines (74 loc) · 2.89 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Copyright 2025 Universitat Politècnica de Catalunya
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from typing import Callable, List

# Must be set before TensorFlow is imported so that TF does not see any GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import tensorflow as tf
def seg_to_global_reshape(tensor, num_dims: int = 3):
    """Swap the first two axes of ``tensor`` and merge them into a single axis.

    The tensor is transposed with permutation ``[1, 0, 2, ..., num_dims - 1]`` and
    then reshaped so that the (swapped) first two axes collapse into one axis of
    size ``shape[0] * shape[1]``; any remaining axes are kept unchanged.

    NOTE(review): the original local name ``total_flows`` suggests the merged axis
    enumerates flows (e.g. input laid out per segment) — confirm against callers.

    Parameters
    ----------
    tensor : tf.Tensor
        Input tensor; its rank must equal ``num_dims``.
    num_dims : int
        Rank of ``tensor``. Must be at least 2. Defaults to 3.

    Returns
    -------
    tf.Tensor
        Tensor of rank ``num_dims - 1`` whose first axis has size
        ``shape[0] * shape[1]``.

    Raises
    ------
    ValueError
        If ``num_dims`` is smaller than 2.
    """
    # Explicit check (not ``assert``) so the validation survives ``python -O``,
    # and done before touching ``tensor`` so misuse fails fast.
    if num_dims < 2:
        raise ValueError(f"num_dims must be at least 2, got {num_dims}")

    perms = [1, 0, *range(2, num_dims)]
    shape = tf.shape(tensor)
    merged = shape[0] * shape[1]
    if num_dims == 2:
        new_shape = (merged,)
    else:
        # Keep the trailing dimensions as they are after the merged axis.
        new_shape = tf.concat([[merged], shape[2:]], axis=0)
    return tf.reshape(tf.transpose(tensor, perms), new_shape)
def prepare_targets_and_mask(targets: List[str], mask: str) -> Callable:
    """Obtains map function to prepare the targets of the dataset for the current model.

    The returned function leaves the feature dict ``x`` untouched and replaces the
    dataset labels with one tensor built from the ``targets`` features of ``x``:
    the mask (as a rank-2 feature) and every target (as a rank-3 feature) are
    flattened with ``seg_to_global_reshape``, the targets are stacked column-wise in
    the given order, and only the rows where the mask is true are kept.

    Parameters
    ----------
    targets : List[str]
        List of features to be selected as targets, in that order. Must not be
        empty.
    mask : str
        Mask feature used to determine which windows in the temporal dimension are
        valid.

    Returns
    -------
    Callable
        Map function to be called by the tf.data.Dataset.map method. It returns
        ``(x, labels)`` where ``labels`` has one column per target.

    Raises
    ------
    ValueError
        If ``targets`` is empty (otherwise this would only fail later, inside the
        traced map function).
    """
    # Materialize once so that an iterator argument survives repeated tracing.
    targets = list(targets)
    if not targets:
        raise ValueError("At least one target feature must be provided")

    def modified_target_map(x, y):
        # ``y`` (the original labels) is discarded and replaced by the targets.
        flat_mask = seg_to_global_reshape(x[mask], num_dims=2)
        # NOTE(review): assumes each target flattens to a single column (trailing
        # dim of size 1), which the previous per-target masking required — confirm.
        stacked = tf.concat(
            [seg_to_global_reshape(x[target]) for target in targets], axis=1
        )
        # The mask is loop-invariant: apply it once to all columns at the same time.
        return x, tf.boolean_mask(stacked, flat_mask)

    return modified_target_map
def load_dataset(name: str, data_path: str = "data") -> tf.data.Dataset:
    """Function to unshard and load a dataset from the data folder.

    The directory ``{data_path}/{name}`` is expected to hold one entry per shard,
    named with integers (``0``, ``1``, ...), each readable by
    ``tf.data.Dataset.load`` with GZIP compression. Shards are concatenated in
    ascending numeric order; entries whose name is not purely numeric are ignored.

    Parameters
    ----------
    name : str
        Name of the dataset and partition [training/validation/test] to load, in format
        '{name}/{partition}'.
    data_path : str
        Path to the data folder. By default, it is 'data', which assumes the working
        directory is the root of the project.

    Returns
    -------
    tf.data.Dataset
        The dataset loaded from the shards, with prefetching enabled.

    Raises
    ------
    FileNotFoundError
        If the dataset directory does not exist (raised by ``os.listdir``).
    ValueError
        If the dataset directory contains no shard.
    """
    path = os.path.join(data_path, name)
    # Only numeric entries are shards. Counting every entry (e.g. a stray
    # ``.DS_Store``) would make the loader request a shard index that does not exist.
    shards = sorted(
        (entry for entry in os.listdir(path) if entry.isdecimal()), key=int
    )
    # Explicit check (not ``assert``) so the validation survives ``python -O``.
    if not shards:
        raise ValueError(f"Invalid dataset: {name}")

    ds = tf.data.Dataset.load(os.path.join(path, shards[0]), compression="GZIP")
    for shard in shards[1:]:
        ds = ds.concatenate(
            tf.data.Dataset.load(os.path.join(path, shard), compression="GZIP")
        )
    return ds.prefetch(tf.data.AUTOTUNE)