Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

renderer, garments and scenes #8

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions NDF_combine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import models.local_model as model
from models.data import dataloader_garments, voxelized_data_shapenet

from models import generation
import torch
from torch.nn import functional as F


def rot_YZ(points):
points_rot = points.copy()
points_rot[:, 1], points_rot[:, 2] = points[:, 2], points[:, 1]
return points_rot

def to_grid(points):
grid_points = points.copy()
grid_points[:, 0], grid_points[:, 2] = points[:, 2], points[:, 0]

return 2 * grid_points

def from_grid(grid_points):
points = grid_points.copy()
points[:, 0], points[:, 2] = grid_points[:, 2], grid_points[:, 0]

return 0.5 * points

# 'test', 'val', 'train'
def loadNDF(index, pointcloud_samples, exp_name, data_dir, split_file, sample_distribution, sample_sigmas, res, mode = 'test'):

global encoding
global net
global device

net = model.NDF()

device = torch.device("cuda")


if 'garments' in exp_name.lower() :

dataset = dataloader_garments.VoxelizedDataset(mode = mode, data_path = data_dir, split_file = split_file,
res = res, density =0, pointcloud_samples = pointcloud_samples,
sample_distribution=sample_distribution,
sample_sigmas=sample_sigmas,
)



checkpoint = 'checkpoint_127h:6m:33s_457593.9149734974'

generator = generation.Generator(net,exp_name, checkpoint = checkpoint, device = device)

if 'cars' in exp_name.lower() :

dataset = voxelized_data_shapenet.VoxelizedDataset( mode = mode, res = res, pointcloud_samples = pointcloud_samples,
data_path = data_dir, split_file = split_file,
sample_distribution = sample_distribution, sample_sigmas = sample_sigmas,
batch_size = 1, num_sample_points = 1024, num_workers = 1
)



checkpoint = 'checkpoint_108h:5m:50s_389150.3971107006'

generator = generation.Generator(net, exp_name, checkpoint=checkpoint, device=device)


example = dataset[index]

print('Object: ',example['path'])
inputs = torch.from_numpy(example['inputs']).unsqueeze(0).to(device) # lead inputs and samples including one batch channel

for param in net.parameters():
param.requires_grad = False

encoding = net.encoder(inputs)



def predictRotNDF(points):

points = rot_YZ(points)
points = to_grid(points)
points = torch.from_numpy(points).unsqueeze(0).float().to(device)
return torch.clamp(net.decoder(points,*encoding), max=0.1).squeeze(0).cpu().numpy()


def predictRotGradientNDF(points):
points = rot_YZ(points)
points = to_grid(points)
points = torch.from_numpy(points).unsqueeze(0).float().to(device)
points.requires_grad = True

df_pred = torch.clamp(net.decoder(points,*encoding), max=0.1)

df_pred.sum().backward()

gradient = F.normalize(points.grad, dim=2)[0].detach().cpu().numpy()

df_pred = df_pred.detach().squeeze(0).cpu().numpy()
return df_pred, rot_YZ( 2 * from_grid(gradient))
50 changes: 49 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,36 @@ but replacing `configs/shapenet_cars.txt` in the commands with the desired confi
> with execution of this command. Y needs to be an integer between 0 to X-1, including O and X-1. In case you have SLURM
> available you can use `slurm_scripts/run_preprocessing.sh`

## Downloading Garments and Scenes

To run the garment processing run

```
cd dataprocessing
chmod u+x garment_process.sh
source garment_process.sh
```

## Scenes
Download gibson dataset from [here](https://docs.google.com/forms/d/e/1FAIpQLScWlx5Z1DM1M-wTSXaa6zV8lTFkPmTHW1LqMsoCBDWsTDjBkQ/viewform)
Specifically we used this version:
All scenes, 572 scenes (108GB): gibson_v2_all.tar.gz

Then run
```
python scene_process.py --input_path <PATH to Gibson> --output_path <Where data is to be stored> --sigmas 0.01 0.04 0.16 --res 256 --density 0.001708246

```

This also creates a split file of scenes to be later used for training. This is stored in ./datasets alongside the garments split file


```
./datasets/split_scenes.npz
./datasets/split_garments.npz
```


## Training and generation
To train NDF use
```
Expand All @@ -98,9 +128,27 @@ python generate.py --config configs/shapenet_cars.txt
Again, replacing `configs/shapenet_cars.txt` in the above commands with the desired configuration and `EXP_NAME` with
the experiment name defined in the configuration.

## Rendering

To render garments, run

```
python renderer.py --config configs/garments.txt
```

To render cars, run

```
python renderer.py --config configs/shapenet_cars_pretrained.txt
```

To render from different perspectives, change the `cam_position` and `cam_orientation` variables in the config files

## Contact

For questions and comments please contact [Julian Chibane](http://virtualhumans.mpi-inf.mpg.de/people/Chibane.html) via mail.
For questions and comments about the training and generation, please contact [Julian Chibane](http://virtualhumans.mpi-inf.mpg.de/people/Chibane.html) via mail.

For questions and comments about the rendering code, please contact [Aymen Mir](http://virtualhumans.mpi-inf.mpg.de/people/Mir.html) via mail.

## License
Copyright (c) 2020 Julian Chibane, Max-Planck-Gesellschaft
Expand Down
33 changes: 33 additions & 0 deletions configs/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import os


def str2bool(inp):
return inp.lower() in 'true'

def config_parser():
parser = configargparse.ArgumentParser()

Expand Down Expand Up @@ -95,6 +98,36 @@ def config_parser():
help='Optimizer used during training.')



## Rendering arguments
parser.add_argument("--pc_samples", type=int, help='input pointcloud size')
parser.add_argument("--index", type=int, help='index to be rendered')

###
parser.add_argument("--size", type=int, help="the size of image", default=512)
parser.add_argument("--max_depth", type=float, help="the max depth of projected rays", default=2)
parser.add_argument("--alpha", type=float, help="the value by which the stepping distance should be multiplied",
default=0.6)
parser.add_argument("--step_back", type=float, default=0.005, help="the value by which we step back after stopping criteria met")
parser.add_argument("--epsilon", type=float, default=0.0026, help="epsilon ball - stopping criteria")
parser.add_argument("--screen_bound", type=float, default=0.4)
parser.add_argument("--screen_depth", type=float, default=-1)

parser.add_argument('--cam_position', nargs='+', type=float, help='3D position of camera', default=[0, 0, -1])
parser.add_argument('--light_position', nargs='+', type=float, help='3D position of light source',
default=[-1, -1, -1])
parser.add_argument("--cam_orientation", nargs='+', type=float,
help="Camera Orientation in xyz euler angles (degrees)", default=[180.0, 0.0, -180.0])

parser.add_argument("--folder", type=str, default='./save',
help="location where images are to be saved")
parser.add_argument("--shade", type=str2bool, default=True, help="whether to save shade image")
parser.add_argument("--depth", type=str2bool, default=True, help="whether to save depth image")
parser.add_argument("--normal", type=str2bool, default=True, help="whether to save normal image")

parser.add_argument("--debug_mode", type=str2bool, default=True,
help="to visualize everything in debug mode or not")

return parser


Expand Down
12 changes: 12 additions & 0 deletions configs/garments.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
exp_name = garments_pretrained
data_dir = datasets/garments_data/
split_file = datasets/split_garments.npz
sample_std_dev = [0.08, 0.02, 0.01]
sample_ratio = [0.02, 0.48, 0.50]

input_res = 256
pc_samples = 3000
index = 10
num_points=1000
cam_position=[0, 1, 0]
cam_orientation=[90.0, 0.0, 180.0]
7 changes: 7 additions & 0 deletions configs/shapenet_cars.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,10 @@ split_file = datasets/shapenet/data/split_cars.npz
input_data_glob = /*/model.obj
sample_std_dev = [0.08, 0.02, 0.003]
sample_ratio = [0.01, 0.49, 0.5]

input_res = 256
pc_samples = 3000
index = 1
num_points=1000
cam_position=[0, 1, 0]
cam_orientation=[90.0, 0.0, 180.0]
7 changes: 7 additions & 0 deletions configs/shapenet_cars_pretrained.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,10 @@ split_file = datasets/shapenet/data/split_cars.npz
input_data_glob = /*/model.obj
sample_std_dev = [0.08, 0.02, 0.003]
sample_ratio = [0.01, 0.49, 0.5]

input_res = 256
pc_samples = 3000
index = 1
num_points=1000
cam_position=[0, 1, 0]
cam_orientation=[90.0, 0.0, 180.0]
41 changes: 41 additions & 0 deletions dataprocessing/garment_normalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
import numpy as np
import trimesh
import argparse

def get_dirs_paths(d):
paths = [os.path.join(d, o) for o in os.listdir(d) if os.path.isdir(os.path.join(d, o))]
dirs = [ o for o in os.listdir(d) if os.path.isdir(os.path.join(d, o))]
return sorted(dirs), sorted(paths)



if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input_folder", type = str)
parser.add_argument("--output_folder", type = str)
args = parser.parse_args()

lists = ['TShirtNoCoat.obj', 'ShortPants.obj', 'Pants.obj', 'ShirtNoCoat.obj', 'LongCoat.obj']

all_dirs, all_paths = get_dirs_paths(args.input_folder)
for index in range(len(all_paths)):
path = all_paths[index]
for file in os.listdir(path):
if file in lists:
class_name = file.replace('.obj', '')
mesh_path = os.path.join(path, file)
out_dir = os.path.join(args.output_folder, class_name + '_' + all_dirs[index])
if not os.path.isdir(out_dir):
os.makedirs(out_dir)

out_file = os.path.join(out_dir, 'mesh.off')

mesh = trimesh.load(mesh_path)

new_verts = mesh.vertices - np.mean(mesh.vertices, axis = 0)
new_verts_sc = new_verts / 0.9748783846
new_verts_sc = new_verts_sc * 0.5
new_mesh = trimesh.Trimesh(vertices = new_verts_sc, faces = mesh.faces)
new_mesh.export(out_file)
print("Processed {} {}".format(all_dirs[index], class_name))
Loading