diff --git a/polytope_feature/datacube/tensor_index_tree.py b/polytope_feature/datacube/tensor_index_tree.py index 604aa744..e97ad6e0 100644 --- a/polytope_feature/datacube/tensor_index_tree.py +++ b/polytope_feature/datacube/tensor_index_tree.py @@ -108,6 +108,12 @@ def add_value(self, value): new_values.sort() self.values = tuple(new_values) + def add_values(self, values): + new_values = list(self.values) + new_values.extend(values) + new_values.sort() + self.values = tuple(new_values) + def create_child(self, axis, value, next_nodes): node = TensorIndexTree(axis, (value,)) existing_child = self.find_child(node) diff --git a/polytope_feature/engine/hullslicer.py b/polytope_feature/engine/hullslicer.py index 9a99682a..26a8b4c5 100644 --- a/polytope_feature/engine/hullslicer.py +++ b/polytope_feature/engine/hullslicer.py @@ -97,7 +97,7 @@ def find_values_between(self, polytope, ax, node, datacube, lower, upper): self.axis_values_between[(flattened_tuple, ax.name, lower, upper, method)] = values return values - def remap_values(self, ax, value): + def remap_value(self, ax, value): remapped_val = self.remapped_vals.get((value, ax.name), None) if remapped_val is None: remapped_val = value @@ -109,6 +109,12 @@ def remap_values(self, ax, value): self.remapped_vals[(value, ax.name)] = remapped_val return remapped_val + def remap_values(self, ax, values): + if not isinstance(values, List): + return self.remap_value(ax, values) + remapped_vals = [self.remap_value(ax, val) for val in values] + return remapped_vals + def _build_sliceable_child(self, polytope, ax, node, datacube, values, next_nodes, slice_axis_idx): for i, value in enumerate(values): if i == 0 or ax.name not in self.compressed_axes: @@ -122,18 +128,17 @@ def _build_sliceable_child(self, polytope, ax, node, datacube, values, next_node child["unsliced_polytopes"].add(new_polytope) next_nodes.append(child) else: - remapped_val = self.remap_values(ax, value) - child.add_value(remapped_val) + remapped_val = self.remap_values(ax, values[1:]) + child.add_values(remapped_val) + break def _build_branch(self, ax, node, datacube, next_nodes): if ax.name not in self.compressed_axes: - parent_node = node.parent right_unsliced_polytopes = [] for polytope in node["unsliced_polytopes"]: if ax.name in polytope._axes: right_unsliced_polytopes.append(polytope) for i, polytope in enumerate(right_unsliced_polytopes): - node._parent = parent_node lower, upper, slice_axis_idx = polytope.extents(ax.name) # here, first check if the axis is an unsliceable axis and directly build node if it is # NOTE: we should have already created the ax_is_unsliceable cache before @@ -154,9 +159,7 @@ def _build_branch(self, ax, node, datacube, next_nodes): all_lowers = [] first_polytope = False first_slice_axis_idx = False - parent_node = node.parent for polytope in node["unsliced_polytopes"]: - node._parent = parent_node if ax.name in polytope._axes: # keep track of the first polytope defined on the given axis if not first_polytope: