diff --git a/src/scripts/proximity.py b/src/scripts/proximity.py index 4352630..ad16dc9 100644 --- a/src/scripts/proximity.py +++ b/src/scripts/proximity.py @@ -7,15 +7,16 @@ from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Tuple - +from datetime import datetime import numpy as np import pandas as pd import pedpy from joblib import Parallel, delayed from tqdm import tqdm +from enum import Enum parent_dir = str(Path(__file__).resolve().parent.parent) -print(parent_dir) +print("parent dir", parent_dir) if parent_dir not in sys.path: print(parent_dir) sys.path.append(parent_dir) @@ -34,12 +35,23 @@ # print(f"distances {len(path_distances)}") +class Method(Enum): + """Method to calculate distances. vect and merge are Euc.""" + + VECT = "vect" + ARC = "arc" + MERGE = "merge" + + @dataclass class InitData: + """Class to hold data.""" + countries: List[str] files: Dict[str, List[str]] result_csv: Path fps: int + method: str def load_file(file: str) -> pedpy.TrajectoryData: @@ -74,7 +86,6 @@ def filter_frames(data: pd.DataFrame, nagents: int) -> pd.DataFrame: With this method with basically skip all these frames and keep only these with agents """ - # Group data by frame and count the unique IDs in each frame frame_counts = data.groupby("frame")["id"].nunique() @@ -314,7 +325,7 @@ def init_gender_code(filename: str) -> str: return name -def calculate_proximity_analysis(country: str, filename: str, data: pd.DataFrame, fps: int = 25) -> List[Dict[str, Any]]: +def calculate_proximity_analysis(country: str, filename: str, data: pd.DataFrame, fps: int = 25, method: Method = Method.VECT) -> List[Dict[str, Any]]: """ Perform proximity analysis on given data of agents. @@ -337,25 +348,24 @@ def calculate_proximity_analysis(country: str, filename: str, data: pd.DataFrame 'gender_of_next_neighbor', 'gender_of_prev_neighbor', 'distance_to_next_neighbor', and 'distance_to_prev_neighbor' columns. """ - pp = Path(filename).parts country_short = Path(pp[-2]).name if country_short != "pal": # print(f"WARNINNG in calculate_proximity_analysis(): method={method} is hard-coded") # merge and vect are basically the same, just testing which implementation is faster # arc uses a different calculation of distances based on arc. - method = "arc" - if method == "merge": + if method == method.MERGE: processed_data = compute_distance_merge(data) - elif method == "vect": + elif method == method.VECT: nagents = extract_first_number(filename) cleaned_data = filter_frames(data, nagents) processed_data = calculate_circular_distance_and_gender_vect(cleaned_data) - else: # arc + elif method == method.ARC: nagents = extract_first_number(filename) cleaned_data = filter_frames(data, nagents) processed_data = calculate_circular_distance_and_gender_arc(cleaned_data) - fps = 25 + else: + print(f"Method {method} is not recognised.") else: processed_data = calculate_circular_distance_and_gender_pal(data) fps = 1 @@ -363,7 +373,7 @@ def calculate_proximity_analysis(country: str, filename: str, data: pd.DataFrame proximity_analysis_res = [] if processed_data.empty: print("========") - print(filename) + print("Processed_data empty", filename) print(processed_data) print("========") frames_to_include = set(range(0, processed_data["frame"].max(), fps)) @@ -420,11 +430,11 @@ def unpack_and_process(args: Any) -> List[Dict[str, Any]]: return calculate_proximity_analysis(*args) -def prepare_data(country: str, selected_file: str, fps: int) -> Tuple[str, str, pd.DataFrame, int]: +def prepare_data(country: str, selected_file: str, fps: int, method: Method) -> Tuple[str, str, pd.DataFrame, int]: """Load file and make pedpy.datatrajectory.""" trajectory_data = load_file(selected_file) data = trajectory_data.data - return country, selected_file, data, fps + return country, selected_file, data, fps, method def calculate_with_joblib(init_data: InitData) -> pd.DataFrame: @@ -435,7 +445,7 @@ def calculate_with_joblib(init_data: InitData) -> pd.DataFrame: print(f"prepare tasks: {country} with {key}") for filename in init_data.files[key]: - tasks.append(prepare_data(country, filename, init_data.fps)) + tasks.append(prepare_data(country, filename, init_data.fps, init_data.method)) # Define a function to be executed in parallel def process_task(task: List[Any]) -> List[Dict[str, Any]]: @@ -456,13 +466,15 @@ def process_task(task: List[Any]) -> List[Dict[str, Any]]: def init() -> InitData: """ - Initializes the application by setting up the countries, file paths, result CSV path, and FPS (frames per second). + Initialize the application by setting up the countries, file paths, result CSV path, and FPS (frames per second). Returns: InitData: A data class containing initialized data including countries, files dictionary, result CSV path, and FPS. """ print("Enter Init") - result_csv = Path(f"{parent_dir}/app_data/proximity_analysis_results_arc.csv") + method = Method.VECT + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + result_csv = Path(f"{parent_dir}/app_data/proximity_analysis_results_{method}_{timestamp}.csv") result_csv.parent.mkdir(parents=True, exist_ok=True) print(f"Created file: {result_csv}") fps = 25 # For distance calculations, calculate every fps-frame @@ -478,7 +490,7 @@ def init() -> InitData: key = Path(country).name files[key] = [str(path) for path in Path(country).glob("*.csv")] - return InitData(countries=countries, files=files, result_csv=result_csv, fps=fps) + return InitData(countries=countries, files=files, result_csv=result_csv, fps=fps, method=method) if __name__ == "__main__":