-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval_ego_forward_model.py
107 lines (90 loc) · 3.21 KB
/
eval_ego_forward_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
import logging
import torch
from torch.utils.data import DataLoader, Subset
from carla_env.dataset.instance import InstanceDataset
from carla_env.models.dynamic.vehicle import KinematicBicycleModel
from carla_env.evaluator.ego_model import Evaluator
from utils.train_utils import get_device, seed_everything
from utils.model_utils import fetch_checkpoint_from_wandb_link
import wandb
# Root logging configuration for this evaluation script: INFO and above,
# each record stamped with wall-clock time, logger name, level, and the
# emitting function/line so evaluation progress can be traced.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d ==> %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def main(config):
    """Evaluate a trained ego forward model on a held-out test dataset.

    The model weights and the hyper-parameters used to build the test
    windows are both taken from the W&B run referenced by
    ``config.ego_forward_model_wandb_link``. In particular the sequence
    length (``num_time_step_previous + num_time_step_future``) and the
    dataset dilation are read from ``run.config``, not from CLI arguments
    of the same name.

    Args:
        config: Parsed argument namespace. Required attributes: ``seed``,
            ``ego_forward_model_wandb_link``,
            ``ego_forward_model_checkpoint_number``, ``data_path_test`` and
            ``save_path``. Optional attributes ``batch_size`` (default 20)
            and ``num_workers`` (default 0) configure the test DataLoader.

    Raises:
        ValueError: If the run's ``num_time_step_*`` / ``dataset_dilation``
            values give a non-positive stride for sub-sampling the test set.
    """
    seed_everything(config.seed)
    device = get_device()

    # Restore the trained model together with the run that produced it; the
    # run object also carries the training-time hyper-parameters used below.
    run = wandb.Api().run(config.ego_forward_model_wandb_link)
    checkpoint = fetch_checkpoint_from_wandb_link(
        wandb_link=config.ego_forward_model_wandb_link,
        checkpoint_number=config.ego_forward_model_checkpoint_number,
    )
    model = KinematicBicycleModel.load_model_from_wandb_run(
        run=run, checkpoint=checkpoint, device=device
    )
    model.to(device).eval()

    # Window length (previous + future steps) taken from the training run.
    sequence_length = (
        run.config["num_time_step_previous"] + run.config["num_time_step_future"]
    )
    dilation = run.config["dataset_dilation"]

    dataset_test = InstanceDataset(
        data_path=config.data_path_test,
        sequence_length=sequence_length,
        read_keys=["ego"],
        dilation=dilation,
    )
    logger.info("Test dataset size: %d", len(dataset_test))

    # Evaluate only every (sequence_length * dilation)-th sample index;
    # presumably so that evaluated windows do not overlap -- confirm against
    # InstanceDataset's indexing.
    window_stride = sequence_length * dilation
    if window_stride <= 0:
        raise ValueError(
            "Test subset stride must be positive, got "
            f"sequence_length={sequence_length} * dilation={dilation} "
            f"= {window_stride}"
        )
    dataloader_test = DataLoader(
        dataset=Subset(dataset_test, range(0, len(dataset_test), window_stride)),
        batch_size=getattr(config, "batch_size", 20),
        shuffle=False,
        num_workers=getattr(config, "num_workers", 0),
    )

    evaluator = Evaluator(
        model=model,
        dataloader=dataloader_test,
        device=device,
        sequence_length=sequence_length,
        save_path=str(config.save_path),
    )
    evaluator.evaluate(render=False, save=True)
if __name__ == "__main__":
    # Command-line entry point: parse evaluation settings and run main().
    parser = argparse.ArgumentParser(
        description="Evaluate a trained ego forward model on a test dataset."
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Seed passed to seed_everything()."
    )
    # NOTE(review): machine-specific default path; override on other hosts.
    parser.add_argument(
        "--data_path_test",
        type=str,
        default="/home/volkan/Documents/Codes/carla_env/data/ground_truth_bev_model_test_data_10Hz_multichannel_bev_dense_traffic/",
        help="Directory containing the test dataset.",
    )
    # NOTE(review): the default's last path component looks like a
    # placeholder; confirm the intended output directory.
    parser.add_argument(
        "--save_path",
        type=str,
        default="figures/ego_forward_model_evaluation_extensive/ashdgjhsagdjsa/",
        help="Directory passed to the Evaluator as its save path.",
    )
    # These two options are parsed but main() never reads them: the sequence
    # length is taken from the W&B run config instead. They are kept so
    # existing command lines keep working.
    parser.add_argument(
        "--num_time_step_previous",
        type=int,
        default=1,
        help="Unused: main() reads num_time_step_previous from the W&B run config.",
    )
    parser.add_argument(
        "--num_time_step_future",
        type=int,
        default=10,
        help="Unused: main() reads num_time_step_future from the W&B run config.",
    )
    parser.add_argument(
        "--ego_forward_model_wandb_link",
        type=str,
        default="vaydingul/mbl/ssifa1go",
        help="W&B run path (entity/project/run_id) of the trained model.",
    )
    parser.add_argument(
        "--ego_forward_model_checkpoint_number",
        type=int,
        default=459,
        help="Checkpoint number to fetch from the W&B run.",
    )
    config = parser.parse_args()
    main(config)