-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Yasha Ektefaie
committed
Dec 21, 2024
1 parent
7c02c4b
commit 2aa1fd0
Showing
5 changed files
with
498 additions
and
257 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .spectra import Spectra | ||
from .dataset import SpectraDataset | ||
from .dataset import SpectraDataset | ||
from .utils import Spectral_Property_Graph, FlattenedAdjacency |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,40 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import List, Dict | ||
|
||
class SpectraDataset(ABC): | ||
|
||
def __init__(self, input_file, name): | ||
self.input_file = input_file | ||
self.name = name | ||
self.samples = self.parse(input_file) | ||
|
||
@abstractmethod | ||
def sample_to_index(self, idx): | ||
""" | ||
Given a sample, return the data idx | ||
""" | ||
pass | ||
|
||
self.sample_to_index = self.parse(input_file) | ||
self.samples = list(self.sample_to_index.keys()) | ||
self.samples.sort() | ||
|
||
@abstractmethod | ||
def parse(self, input_file): | ||
def parse(self, input_file: str) -> Dict: | ||
""" | ||
Given a dataset file, parse the dataset file. | ||
Make sure there are only unique entries! | ||
Given a dataset file, parse the dataset file to return a dictionary mapping a sample ID to the data | ||
""" | ||
pass | ||
raise NotImplementedError("Must implement parse method to use SpectraDataset, see documentation for more information") | ||
|
||
@abstractmethod | ||
def __len__(self): | ||
""" | ||
Return the length of the dataset | ||
""" | ||
pass | ||
return len(self.samples) | ||
|
||
@abstractmethod | ||
def __getitem__(self, idx): | ||
""" | ||
Given a dataset idx, return the element at that index | ||
""" | ||
pass | ||
if isinstance(idx, int): | ||
return self.sample_to_index[self.samples[idx]] | ||
return self.sample_to_index[idx] | ||
|
||
def index(self, value): | ||
""" | ||
Given a value, return the index of that value | ||
""" | ||
if value not in self.samples: | ||
raise ValueError(f"{value} not in the dataset") | ||
return self.samples.index(value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,83 +1,91 @@ | ||
import random | ||
import networkx as nx | ||
from .utils import is_clique, connected_components, is_integer | ||
from scipy import stats | ||
import random | ||
import numpy as np | ||
from tqdm import tqdm | ||
import torch | ||
from .utils import FlattenedAdjacency, Spectral_Property_Graph, cross_split_overlap | ||
|
||
def run_independent_set(spectral_parameter, input_G, seed = None, | ||
debug=False, distribution = None, binary = True): | ||
total_deleted = 0 | ||
independent_set = [] | ||
|
||
if seed is not None: | ||
random.seed(seed) | ||
def run_independent_set(spectral_parameter: int, | ||
input_G: Spectral_Property_Graph, | ||
seed: int = 42, | ||
binary: bool = True, | ||
minimum: int = None, | ||
degree_choosing: bool = False, | ||
num_splits: int = None): | ||
|
||
G = input_G.copy() | ||
|
||
if binary: | ||
#First check if any connected component of the graph is a clique, if so, add it as one unit to the independent set | ||
components = list(connected_components(G)) | ||
deleted = 0 | ||
for i, component in enumerate(components): | ||
subgraph = G.subgraph(component) | ||
if is_clique(subgraph): | ||
print(f"Component {i} is too densly connected, adding samples as a single unit to independent set and deleting them from the graph") | ||
independent_set.append(list(subgraph.nodes())) | ||
G.remove_nodes_from(subgraph.nodes()) | ||
else: | ||
for node in list(subgraph.nodes()): | ||
if subgraph.degree(node) == len(subgraph.nodes()) - 1: | ||
deleted += 1 | ||
G.remove_node(node) | ||
|
||
print(f"Deleted {deleted} nodes from the graph since they were connected to all other nodes") | ||
|
||
iterations = 0 | ||
total_num_deleted = 0 | ||
independent_set = [] | ||
random.seed(seed) | ||
|
||
n = input_G.num_nodes() | ||
indices_to_scan = list(range(n)) | ||
if spectral_parameter == 0: | ||
return indices_to_scan | ||
pbar = tqdm(total = len(indices_to_scan)) | ||
|
||
#Trying a non-percentile approach | ||
#Note this assumes there are 20 | ||
if not binary: | ||
if num_splits is None: | ||
raise Exception("Num splits must be specified for non-binary graphs, see documentation for more information") | ||
threshold = spectral_parameter*(torch.max(input_G) - torch.min(input_G))/num_splits | ||
else: | ||
threshold = 0 | ||
print(f"Threshold is {threshold}") | ||
indices_deleted = [] | ||
|
||
expected_number_delete = int(n * spectral_parameter) | ||
print(expected_number_delete) | ||
|
||
while not nx.is_empty(G): | ||
chosen_node = random.sample(list(G.nodes()), 1)[0] | ||
while len(indices_to_scan) > 0: | ||
print(len(indices_deleted)) | ||
indices_deleted = [] | ||
if degree_choosing: | ||
chosen_node, _ = input_G.get_minimum_degree_node(indices_to_scan) | ||
else: | ||
chosen_node = random.sample(indices_to_scan, 1)[0] | ||
|
||
indices_to_scan.remove(chosen_node) | ||
|
||
to_iterate = indices_to_scan[:] | ||
|
||
independent_set.append(chosen_node) | ||
neighbors = G.neighbors(chosen_node) | ||
neighbors_to_delete = [] | ||
indices_to_gather = [] | ||
|
||
for index in to_iterate: | ||
indices_to_gather.append((chosen_node, index)) | ||
|
||
values = input_G.get_weights(indices_to_gather) | ||
|
||
indices_deleted.extend(list(torch.tensor(to_iterate).cuda()[values > threshold].cpu().numpy())) | ||
|
||
indices_deleted = list(set(indices_deleted)) | ||
indices_to_scan = set(indices_to_scan) | ||
|
||
for neighbor in neighbors: | ||
if not binary: | ||
if spectral_parameter == 1.0: | ||
neighbors_to_delete.append(neighbor) | ||
else: | ||
edge_weight = G[chosen_node][neighbor]['weight'] | ||
if distribution is None: | ||
raise Exception("Distribution must be provided if binary is set to False, must precompute similarities") | ||
if random.random() < spectral_parameter and (1-spectral_parameter)*100 < stats.percentileofscore(distribution, edge_weight): | ||
neighbors_to_delete.append(neighbor) | ||
else: | ||
if spectral_parameter == 1.0: | ||
neighbors_to_delete.append(neighbor) | ||
elif spectral_parameter != 0.0: | ||
if len(indices_deleted) > expected_number_delete: | ||
indices_deleted = [chosen_node] | ||
total_num_deleted += 1 | ||
else: | ||
independent_set.append(chosen_node) | ||
for i in indices_deleted: | ||
if binary: | ||
if random.random() < spectral_parameter: | ||
neighbors_to_delete.append(neighbor) | ||
indices_to_scan.remove(i) | ||
total_num_deleted += 1 | ||
else: | ||
indices_to_scan.remove(i) | ||
total_num_deleted += 1 | ||
|
||
if minimum is not None: | ||
if n - total_num_deleted <= minimum - len(independent_set): | ||
independent_set.extend(indices_to_scan) | ||
return independent_set | ||
|
||
if debug: | ||
print(f"Iteration {iterations} Stats") | ||
print(f"Deleted {len(neighbors_to_delete)} nodes from {G.degree(chosen_node)} neighbors of node {chosen_node}") | ||
total_deleted += len(neighbors_to_delete) | ||
|
||
for neighbor in neighbors_to_delete: | ||
G.remove_node(neighbor) | ||
|
||
if chosen_node not in neighbors_to_delete: | ||
G.remove_node(chosen_node) | ||
|
||
iterations += 1 | ||
|
||
for node in list(G.nodes()): | ||
#Append the nodes left to G | ||
independent_set.append(node) | ||
indices_deleted.append(chosen_node) | ||
indices_to_scan = list(indices_to_scan) | ||
pbar.update(len(indices_deleted)) | ||
|
||
if debug: | ||
print(f"{len(input_G.nodes())} nodes in the original graph") | ||
print(f"Total deleted {total_deleted}") | ||
print(f"{len(independent_set)} nodes in the independent set") | ||
|
||
pbar.close() | ||
|
||
return independent_set | ||
return independent_set |
Oops, something went wrong.