7. NeRF, a Disruption of Traditional 3D Reconstruction (Part 9) --- PyTorch NeRF Model Creation 2: Positional Encoding

Reading the get_embedder function

Questions

  • Why are there two positional encodings?
  • The first positional encoding takes a 3-channel input and outputs 63 channels.
  • The second positional encoding takes a 3-channel input and outputs 27 channels.
  • The 63 and 27 channels are both passed into the deep network (see the sketch right after this list).
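
In short: the first encoding is applied to the 3D sample position, the second to the 3D viewing direction (positions feed the MLP trunk, directions are concatenated before the color head). The channel counts follow one formula: out_dim = input_dims + input_dims × 2 × num_freqs. A minimal sketch to check the numbers (`encoded_dim` is a hypothetical helper for illustration, not part of the repo):

```python
# Channels = raw input (if kept) + one sin and one cos per frequency per input dim.
def encoded_dim(input_dims, num_freqs, include_input=True):
    return input_dims * int(include_input) + input_dims * 2 * num_freqs

print(encoded_dim(3, 10))  # 63 -> position encoding (multires = 10)
print(encoded_dim(3, 4))   # 27 -> direction encoding (multires_views = 4)
```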

1. Positional encoding: get_embedder

```python
embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
if args.use_viewdirs:
    embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
```
  1. Inputs
    • use_viewdirs (default=True): use the full 5D input (position + viewing direction) instead of 3D
    • multires (default=10): log2 of max freq for positional encoding (3D location)
    • i_embed (default=0): set 0 for default positional encoding, -1 for none
    • multires_views (default=4): log2 of max freq for positional encoding (2D direction)
  2. Outputs
    • embed_fn: the encoding function
    • input_ch = 63: number of channels the encoding outputs
    • input_ch_views = 27: 3 + 3 × 2 × 4 = 27
  3. Analysis of input_ch:
    • The 3-channel input is run through multires bands of torch.sin and torch.cos transforms.
    • Counting the output: the raw input contributes 3 channels, plus 3 × 10 × 2 = 60 encoded channels, giving 3 + 60 = 63 (10 sin bands + 10 cos bands per input channel).
    • Each sin and cos is applied to x * freq, where freq runs over the powers of two from 2^0 to 2^9: freq = [1, 2, 4, ..., 512].
    • The full expansion (in the order the code below produces it) is:
```text
embed_fn = [x[0], x[1], x[2],
            sin(1*x[0]),   sin(1*x[1]),   sin(1*x[2]),   cos(1*x[0]),   cos(1*x[1]),   cos(1*x[2]),
            sin(2*x[0]),   sin(2*x[1]),   sin(2*x[2]),   cos(2*x[0]),   cos(2*x[1]),   cos(2*x[2]),
            sin(4*x[0]),   sin(4*x[1]),   sin(4*x[2]),   cos(4*x[0]),   cos(4*x[1]),   cos(4*x[2]),
            sin(8*x[0]),   sin(8*x[1]),   sin(8*x[2]),   cos(8*x[0]),   cos(8*x[1]),   cos(8*x[2]),
            sin(16*x[0]),  sin(16*x[1]),  sin(16*x[2]),  cos(16*x[0]),  cos(16*x[1]),  cos(16*x[2]),
            sin(32*x[0]),  sin(32*x[1]),  sin(32*x[2]),  cos(32*x[0]),  cos(32*x[1]),  cos(32*x[2]),
            sin(64*x[0]),  sin(64*x[1]),  sin(64*x[2]),  cos(64*x[0]),  cos(64*x[1]),  cos(64*x[2]),
            sin(128*x[0]), sin(128*x[1]), sin(128*x[2]), cos(128*x[0]), cos(128*x[1]), cos(128*x[2]),
            sin(256*x[0]), sin(256*x[1]), sin(256*x[2]), cos(256*x[0]), cos(256*x[1]), cos(256*x[2]),
            sin(512*x[0]), sin(512*x[1]), sin(512*x[2]), cos(512*x[0]), cos(512*x[1]), cos(512*x[2])]
```
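
A standalone sketch (not the repo's code) that reproduces this expansion for a single point and confirms the 63 channels:

```python
import torch

x = torch.tensor([0.1, 0.2, 0.3])    # one 3D sample point
freqs = 2.0 ** torch.arange(10)      # [1, 2, 4, ..., 512] for multires = 10
# For each frequency: sin over all 3 dims, then cos over all 3 dims.
parts = [x] + [fn(f * x) for f in freqs for fn in (torch.sin, torch.cos)]
encoded = torch.cat(parts)
print(encoded.shape)                 # torch.Size([63])
```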

2. The concrete implementation (optional reading; you can copy it as-is)

```python
import torch
import torch.nn as nn


def get_embedder(multires, i=0):
    # i == -1 disables positional encoding: pass the raw 3D input through.
    if i == -1:
        return nn.Identity(), 3

    embed_kwargs = {
        'include_input': True,           # keep the raw input as the first 3 channels
        'input_dims': 3,
        'max_freq_log2': multires - 1,   # frequencies go from 2^0 up to 2^(multires-1)
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)
    embed = lambda x, eo=embedder_obj: eo.embed(x)
    return embed, embedder_obj.out_dim


# Positional encoding (section 5.1 of the NeRF paper)
class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            # frequencies spaced as powers of two: 2^0, 2^1, ..., 2^max_freq
            freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
        else:
            # frequencies spaced linearly between 2^0 and 2^max_freq
            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                # one sin/cos channel group per frequency, applied to every input dim
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        # concatenate all channel groups along the last dimension
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
```
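
A quick usage check of the functions above (the sample tensors are made up for illustration):

```python
embed_fn, input_ch = get_embedder(10)             # positions: multires = 10
embeddirs_fn, input_ch_views = get_embedder(4)    # directions: multires_views = 4

pts = torch.rand(1024, 3)                         # dummy 3D sample points
dirs = torch.rand(1024, 3)                        # dummy viewing directions

print(input_ch, embed_fn(pts).shape)              # 63 torch.Size([1024, 63])
print(input_ch_views, embeddirs_fn(dirs).shape)   # 27 torch.Size([1024, 27])
```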