-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_no_encoding.py
87 lines (76 loc) · 2.1 KB
/
train_no_encoding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import NeRF
from dataset import DummyCubeDataset
from train_utils import *
from torch.utils.data import DataLoader
import cv2
import os
import numpy as np
def pick_device():
    """Return the name of the best available torch device.

    Prefers Apple's Metal backend ("mps"), which this script originally
    hard-coded, then CUDA, and finally falls back to the CPU so the script
    also runs on machines without an Apple GPU.
    """
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def main():
    """Train a NeRF on a small dummy cube dataset, then render a sweep.

    Builds a ``DummyCubeDataset``, trains ``NeRF`` with ``train_nerf`` while
    writing renders/checkpoints into folders whose names encode the
    hyper-parameters, and finally calls ``visualize`` over 50 azimuths
    between 48 and 52.
    """
    seed_everything(42)

    device = pick_device()
    if device == "mps":
        # MPS-specific cache flush; skipped entirely on other backends.
        torch.mps.empty_cache()

    # Image size, focal length and samples per ray passed to the renderer.
    H = 64
    W = 64
    FOCAL = 64
    NUM_SAMPLES = 16
    # NOTE(review): near/far are negative; presumably train_utils uses a
    # signed-depth convention along the ray -- confirm before changing.
    NEAR = -0.1
    FAR = -11.0
    D = 16  # forwarded to NeRF(D=...); meaning is defined in model.py
    L = 0  # forwarded to NeRF(L=...); presumably 0 = no positional encoding (cf. file name) -- confirm

    dataset = DummyCubeDataset(
        num_images=5,
        H=H,
        W=W,
        focal=FOCAL,
        output_dir="dataset_images",
        distance_min=0.0,
        distance_max=5.0,
        azimuth_list=[48, 49, 51, 52],
        elevation_list=[50],
    )
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )

    model = NeRF(D=D, L=L)
    model.to(device)

    # Shared suffix so each hyper-parameter combination gets its own folders.
    run_tag = f"{H}_{W}_{FOCAL}_{NUM_SAMPLES}_{NEAR}_{FAR}_{D}_{L}"
    INFERENCE_FOLDER = f"inference_{run_tag}"
    CHECKPOINT_FOLDER = f"checkpoints_{run_tag}"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(INFERENCE_FOLDER, exist_ok=True)
    os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)

    train_nerf(
        model,
        dataloader,
        epochs=200,
        H=H,
        W=W,
        focal=FOCAL,
        num_samples=NUM_SAMPLES,
        near=NEAR,
        far=FAR,
        inference_folder=INFERENCE_FOLDER,
        checkpoint_folder=CHECKPOINT_FOLDER,
        device=device,
    )
    visualize(
        model,
        H,
        W,
        FOCAL,
        NUM_SAMPLES,
        NEAR,
        FAR,
        INFERENCE_FOLDER,
        azimuth_list=np.linspace(48, 52, 50),
        device=device,
    )


if __name__ == "__main__":
    main()