-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpoint-cloud-classification.html
4052 lines (4031 loc) · 522 KB
/
point-cloud-classification.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
<!DOCTYPE html>
<html xmlns="http://www.w3.org/1999/xhtml" lang="" xml:lang="">
<head>
<meta charset="utf-8" />
<meta name="generator" content="pandoc" />
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes" />
<meta name="author" content="SternGerlach" />
<title>点群処理のFPGA高速化</title>
<style>
code{white-space: pre-wrap;}
span.smallcaps{font-variant: small-caps;}
div.columns{display: flex; gap: min(4vw, 1.5em);}
div.column{flex: auto; overflow-x: auto;}
div.hanging-indent{margin-left: 1.5em; text-indent: -1.5em;}
ul.task-list{list-style: none;}
ul.task-list li input[type="checkbox"] {
width: 0.8em;
margin: 0 0.8em 0.2em -1.6em;
vertical-align: middle;
}
pre > code.sourceCode { white-space: pre; position: relative; }
pre > code.sourceCode > span { display: inline-block; line-height: 1.25; }
pre > code.sourceCode > span:empty { height: 1.2em; }
.sourceCode { overflow: visible; }
code.sourceCode > span { color: inherit; text-decoration: inherit; }
div.sourceCode { margin: 1em 0; }
pre.sourceCode { margin: 0; }
@media screen {
div.sourceCode { overflow: auto; }
}
@media print {
pre > code.sourceCode { white-space: pre-wrap; }
pre > code.sourceCode > span { text-indent: -5em; padding-left: 5em; }
}
pre.numberSource code
{ counter-reset: source-line 0; }
pre.numberSource code > span
{ position: relative; left: -4em; counter-increment: source-line; }
pre.numberSource code > span > a:first-child::before
{ content: counter(source-line);
position: relative; left: -1em; text-align: right; vertical-align: baseline;
border: none; display: inline-block;
-webkit-touch-callout: none; -webkit-user-select: none;
-khtml-user-select: none; -moz-user-select: none;
-ms-user-select: none; user-select: none;
padding: 0 4px; width: 4em;
color: #aaaaaa;
}
pre.numberSource { margin-left: 3em; border-left: 1px solid #aaaaaa; padding-left: 4px; }
div.sourceCode
{ }
@media screen {
pre > code.sourceCode > span > a:first-child::before { text-decoration: underline; }
}
code span.al { color: #ff0000; font-weight: bold; } /* Alert */
code span.an { color: #60a0b0; font-weight: bold; font-style: italic; } /* Annotation */
code span.at { color: #7d9029; } /* Attribute */
code span.bn { color: #40a070; } /* BaseN */
code span.bu { color: #008000; } /* BuiltIn */
code span.cf { color: #007020; font-weight: bold; } /* ControlFlow */
code span.ch { color: #4070a0; } /* Char */
code span.cn { color: #880000; } /* Constant */
code span.co { color: #60a0b0; font-style: italic; } /* Comment */
code span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } /* CommentVar */
code span.do { color: #ba2121; font-style: italic; } /* Documentation */
code span.dt { color: #902000; } /* DataType */
code span.dv { color: #40a070; } /* DecVal */
code span.er { color: #ff0000; font-weight: bold; } /* Error */
code span.ex { } /* Extension */
code span.fl { color: #40a070; } /* Float */
code span.fu { color: #06287e; } /* Function */
code span.im { color: #008000; font-weight: bold; } /* Import */
code span.in { color: #60a0b0; font-weight: bold; font-style: italic; } /* Information */
code span.kw { color: #007020; font-weight: bold; } /* Keyword */
code span.op { color: #666666; } /* Operator */
code span.ot { color: #007020; } /* Other */
code span.pp { color: #bc7a00; } /* Preprocessor */
code span.sc { color: #4070a0; } /* SpecialChar */
code span.ss { color: #bb6688; } /* SpecialString */
code span.st { color: #4070a0; } /* String */
code span.va { color: #19177c; } /* Variable */
code span.vs { color: #4070a0; } /* VerbatimString */
code span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } /* Warning */
</style>
<link rel="stylesheet" href="style.css" />
<script
src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-chtml-full.js"
type="text/javascript"></script>
<!--[if lt IE 9]>
<script src="//cdnjs.cloudflare.com/ajax/libs/html5shiv/3.7.3/html5shiv-printshiv.min.js"></script>
<![endif]-->
</head>
<body>
<header id="title-block-header">
<h1 class="title">点群処理のFPGA高速化</h1>
<p class="author">SternGerlach</p>
</header>
<!--
pandoc -s -f markdown -t html5 --mathjax --css style.css point-cloud-classification.md -o point-cloud-classification.html
-->
<p><a href="./index.html">ホームに戻る</a></p>
<h1 id="このページについて">このページについて</h1>
<p>このページは、<a
href="https://adventar.org/calendars/7773">慶應理工アドベントカレンダー2022</a>の22日目の記事です。
去年の記事は<a
href="./scan-matching-branch-and-bound.html">こちら</a>と<a
href="./scan-matching-branch-and-bound-impl.html">こちら</a>です。</p>
<p>早速余談ですが、1983年12月22日は、Yellow Magic Orchestra (YMO)
が行った最後の国内ツアーの最終日で、開催場所は日本武道館でした。
今日は、その散開ツアーからちょうど39年目の記念すべき日です。
1984年2月22日発売の「アフター・サーヴィス」や、1992年11月21日発売の「コンプリート・サーヴィス」に音源が収録されているので、みなさん是非聴いてみてください。
また余談ですが、普段は(研究そっちのけで)CDを集めています。
70年代から80年代にかけてのアーティストが好きです。
最近は、専らオフコースを聴いています。
オフコースの旧規格盤のコレクションは<a
href="./off-course-ca35-series.html">こちら</a>にあります。
また、コレクションは<a href="./cds.html">こちら</a>と<a
href="./toshiba-emi.html">こちら</a>にまとめてあります。
暇なときにご覧ください。</p>
<p>もう一つ余談。 今年聴いたなかで最も良かったアルバム。</p>
<ol type="1">
<li>チューリップ「Halo」(1983年 / VICL-62399 / 2007年盤)
<ul>
<li>特によかった曲:
🥇「丘に吹く風」🥈「愛を抱きしめて」🥉「輝く星」「想い出のランドスケープ」「コスモスの咲く郷」「星空の伝言」「セルリアン・ブルー」</li>
</ul></li>
<li>オフコース「この道をゆけば オフ・コース・ラウンド2」(1974年 /
CA35-1033 / 1983年盤)
<ul>
<li>特によかった曲:
🥇「はたちの頃」🥈「別れの情景(1)」🥉「首輪のない犬」「あの角をまがれば」「日曜日のたいくつ」</li>
</ul></li>
<li>オフコース「I Love You」(1982年 / CA35-1002 / 1982年盤)
<ul>
<li>特によかった曲:
🥇「哀しき街」🥈「決して彼等のようではなく」🥉「Yes-Yes-Yes」「愛のゆくえ」</li>
</ul></li>
<li>オフコース「ワインの匂い」(1975年 / CA35-1032 / 1983年盤)
<ul>
<li>特によかった曲:
🥇「幻想」🥈「老人のつぶやき」🥉「憂き世に」「雨よ激しく」「倖せなんて」「ワインの匂い」「眠れぬ夜」</li>
</ul></li>
<li>オフコース「Song Is Love」(1976年 / CA35-1041 / 1983年盤)
<ul>
<li>特によかった曲:
🥇「冬が来るまえに」🥈「青空と人生と」🥉「歌を捧げて」「青春」「ひとりで生きてゆければ」</li>
</ul></li>
<li>チューリップ「New Tune」(1985年 / 35FD-1005 / 1985年盤)
<ul>
<li>特によかった曲:
🥇「もっと幸せに素直になれたら」🥈「ロベリア」🥉「Our
Song」「ふたつめのクリスマス」「そんな男になれたら」</li>
</ul></li>
<li>大滝詠一「Each Time」(1984年 / 35DH 78 / 1984年盤)
<ul>
<li>特によかった曲: 🥇「Bachelor
Girl」🥈「ペパーミント・ブルー」🥉「魔法の瞳」「恋のナックルボール」</li>
</ul></li>
<li>麗美「“R”」(1984年 / 35C31-7250 / 1984年盤)
<ul>
<li>特によかった曲:
🥇「星のクライマー」🥈「風は明日へ」🥉「空が一面海に見えた日」「恋の一時間は孤独の千年」「青春のリグレット」「ポニーテイル」</li>
</ul></li>
<li>ハイ・ファイ・セット「Sweet Locomotion」(1986年 / 32DH 393 /
1986年盤)
<ul>
<li>特によかった曲:
🥇「ひときれの恋」🥈「たった一枚のフォトグラフ」🥉「Elevator Town」「Do
You Remember Me?」</li>
</ul></li>
<li>和久井映見「Flora」(1990年 / PSCR-1006 / 1990年盤)
<ul>
<li>特によかった曲:
🥇「マイ・ロンリィ・グッバイ・クラブ」🥈「偶然の旅人」🥉「夢で会いましょう」「神様がいない土曜日」</li>
</ul></li>
<li>鈴木康博「Sincerely」(1983年 / CA35-1043 / 1983年盤)
<ul>
<li>特によかった曲: 🥇「瑠璃色の夜明け」🥈「僕と海へ」🥉「ラララ
~愛の世界へ~」「入り江」「君の誕生日」</li>
</ul></li>
<li>岡田有希子「ヴィーナス誕生」(1986年 / D32A0164 / 1986年盤)
<ul>
<li>特によかった曲:
🥇「ヴィーナス誕生」🥈「銀河のバカンス」🥉「眠れぬ夜のAquarius」「Wonder
Trip Lover」「Spring Accident」</li>
</ul></li>
<li>尾崎亜美「Kids」(1986年 / D32A0235 / 1986年盤)
<ul>
<li>特によかった曲:
🥇「流れ星が好き」🥈「シャイネスボーイ」🥉「St.Valentine’s Day
Rhapsody」「Com’on Mamy」</li>
</ul></li>
<li>久保田早紀「夜の底は柔らかな幻」(1984年 / DYCL-17 / 2005年盤)
<ul>
<li>特によかった曲:
🥇「ピアニッシモで…」🥈「寒い絵葉書」🥉「月の浜辺ボタンがひとつ」「メランコリーのテーブルクロス」</li>
</ul></li>
<li>薬師丸ひろ子「花図鑑」(1986年 / CA32-1260 / 1986年盤)
<ul>
<li>特によかった曲:
🥇「紅い花、青い花」🥈「寒椿、咲いた」🥉「ローズ・ティーはいかが?」「哀しみの種」「透明なチューリップ」「麦わら帽子のアン」</li>
</ul></li>
</ol>
<p>イントロが良い曲 (おまけ)。</p>
<ol type="1">
<li>チューリップ「Shooting Star」(1981年)</li>
<li>井上鑑「Karsavina ~ニジンスキーの翼」(1983年)</li>
<li>井上鑑「Running Fence -Ode A Christo」(1982年)</li>
</ol>
<p>今年は、点群処理 (点群分類タスク)
向けニューラルネットのFPGA高速化を試してみます。
LeNetやResNetなど、画像処理向けニューラルネットのFPGA高速化も面白いのですが、既にたくさんの素晴らしい記事が出ているのでやめました。
音楽の話も、誰にも通じないし、ウケないと思ったのでやめました。
コンピュータで閲覧されることをお勧めします。</p>
<h1 id="ニューラルネットの準備">ニューラルネットの準備</h1>
<p>点群の分類、セグメンテーション、レジストレーションなど、様々なタスクに対応した代表的なモデルとして、2017年にCVPRで発表されたPointNetが挙げられます。
PointNetは、MLPとMaxプーリング層からなる、シンプルかつ強力なモデルです。
分類タスク向けのPointNetの構造を、以下に示します。</p>
<p><a
href="point-cloud-classification-images/pointnet-layers.svg"><img src="point-cloud-classification-images/pointnet-layers.svg" width="100%" /></a></p>
<p>モデルは、点群からの特徴抽出と、特徴に基づく分類の、2つの部分に分けられます
(図のFeature extractionとClassification)。</p>
<p>図の左端に示すように、<span
class="math inline">\(N\)</span>個の点を含む、3次元の点群<span
class="math inline">\(\mathcal{P} = \left\{ \boldsymbol{p}_1, \ldots,
\boldsymbol{p}_N \right\} \in \mathbb{R}^{N \times
3}\)</span>が入力です。 MLPを用いて、各点<span
class="math inline">\(\boldsymbol{p}_i \in
\mathbb{R}^3\)</span>に対して、1024次元のローカルな特徴<span
class="math inline">\(\boldsymbol{\psi}_i \in
\mathbb{R}^{1024}\)</span>を計算します。
全ての点に対してローカルな特徴量<span
class="math inline">\(\boldsymbol{\Psi} = \left\{ \boldsymbol{\psi}_1,
\ldots, \boldsymbol{\psi}_N \right\} \in \mathbb{R}^{N \times
1024}\)</span>を計算したら、それらをMaxプーリング層により集約して、点群全体を表すグローバルな特徴量<span
class="math inline">\(\boldsymbol{\phi} \in
\mathbb{R}^{1024}\)</span>を得ます (<span
class="math inline">\(\boldsymbol{\phi} \gets \max(\boldsymbol{\psi}_1,
\ldots, \boldsymbol{\psi}_N)\)</span>)。</p>
<p>分類用のネットワークは、この特徴量<span
class="math inline">\(\boldsymbol{\phi}\)</span>を入力として、各物体のクラスに対するロジット
(スコア)を出力します。 物体のクラス数を<span
class="math inline">\(K\)</span>とすれば、出力は<span
class="math inline">\(K\)</span>次元のベクトルとなります。</p>
<p>図のInput TransformおよびFeature
Transformは、点群の特徴に対してアフィン変換を施し、剛体変換に対して不変な特徴量を得るためのネットワークですが、実装が面倒なので取り除きます(<strong>最適化その1:
モデルの簡略化</strong>)。
従って、今回FPGA上に実装するPointNetは、以下のようになります。</p>
<p>画像認識向けのモデルとは異なり、畳み込み層がありません。
また、MLPは、全結合層、ReLU活性化層、バッチ正規化層をまとめたものとします。</p>
<p><a
href="point-cloud-classification-images/pointnet-layers2.svg"><img src="point-cloud-classification-images/pointnet-layers2.svg" width="80%" /></a></p>
<p>PyTorchによるモデルの定義は、次のようになります
(<code>net/model.py</code>)。 ソースコード全体は<a
href="https://github.com/sterngerlach/advent_2022_point_cloud_classification">こちらのリポジトリ</a>に置かれているので、適宜ご参照ください。</p>
<div class="sourceCode" id="cb1"><pre
class="sourceCode python"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PointNetFeat(torch.nn.Module):</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a> <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a> <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.conv1 <span class="op">=</span> torch.nn.Conv1d(<span class="dv">3</span>, <span class="dv">64</span>, <span class="dv">1</span>)</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.conv2 <span class="op">=</span> torch.nn.Conv1d(<span class="dv">64</span>, <span class="dv">64</span>, <span class="dv">1</span>)</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.conv3 <span class="op">=</span> torch.nn.Conv1d(<span class="dv">64</span>, <span class="dv">64</span>, <span class="dv">1</span>)</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.conv4 <span class="op">=</span> torch.nn.Conv1d(<span class="dv">64</span>, <span class="dv">128</span>, <span class="dv">1</span>)</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.conv5 <span class="op">=</span> torch.nn.Conv1d(<span class="dv">128</span>, <span class="dv">1024</span>, <span class="dv">1</span>)</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.bn1 <span class="op">=</span> torch.nn.BatchNorm1d(<span class="dv">64</span>)</span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.bn2 <span class="op">=</span> torch.nn.BatchNorm1d(<span class="dv">64</span>)</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.bn3 <span class="op">=</span> torch.nn.BatchNorm1d(<span class="dv">64</span>)</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.bn4 <span class="op">=</span> torch.nn.BatchNorm1d(<span class="dv">128</span>)</span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.bn5 <span class="op">=</span> torch.nn.BatchNorm1d(<span class="dv">1024</span>)</span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a> <span class="kw">def</span> forward(<span class="va">self</span>, x: torch.Tensor):</span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a> <span class="co"># `x` is of size [B, N, 3]</span></span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a> N <span class="op">=</span> x.shape[<span class="dv">1</span>]</span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a> <span class="co"># `x` is of size [B, 3, N]</span></span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a> x <span class="op">=</span> x.transpose(<span class="dv">1</span>, <span class="dv">2</span>)</span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a> <span class="co"># `x` is of size [B, 1024, N]</span></span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a> x <span class="op">=</span> F.relu(<span class="va">self</span>.bn1(<span class="va">self</span>.conv1(x)))</span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a> x <span class="op">=</span> F.relu(<span class="va">self</span>.bn2(<span class="va">self</span>.conv2(x)))</span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a> x <span class="op">=</span> F.relu(<span class="va">self</span>.bn3(<span class="va">self</span>.conv3(x)))</span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a> x <span class="op">=</span> F.relu(<span class="va">self</span>.bn4(<span class="va">self</span>.conv4(x)))</span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a> x <span class="op">=</span> F.relu(<span class="va">self</span>.bn5(<span class="va">self</span>.conv5(x)))</span>
<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a> <span class="co"># `x` is of size [B, 1024]</span></span>
<span id="cb1-30"><a href="#cb1-30" aria-hidden="true" tabindex="-1"></a> x <span class="op">=</span> torch.<span class="bu">max</span>(x, dim<span class="op">=</span><span class="dv">2</span>)[<span class="dv">0</span>]</span>
<span id="cb1-31"><a href="#cb1-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-32"><a href="#cb1-32" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> x</span>
<span id="cb1-33"><a href="#cb1-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-34"><a href="#cb1-34" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PointNetCls(torch.nn.Module):</span>
<span id="cb1-35"><a href="#cb1-35" aria-hidden="true" tabindex="-1"></a> <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes: <span class="bu">int</span>):</span>
<span id="cb1-36"><a href="#cb1-36" aria-hidden="true" tabindex="-1"></a> <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb1-37"><a href="#cb1-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-38"><a href="#cb1-38" aria-hidden="true" tabindex="-1"></a> <span class="co"># Feature extraction</span></span>
<span id="cb1-39"><a href="#cb1-39" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.feat <span class="op">=</span> PointNetFeat()</span>
<span id="cb1-40"><a href="#cb1-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-41"><a href="#cb1-41" aria-hidden="true" tabindex="-1"></a> <span class="co"># Classification network</span></span>
<span id="cb1-42"><a href="#cb1-42" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.fc1 <span class="op">=</span> torch.nn.Linear(<span class="dv">1024</span>, <span class="dv">512</span>)</span>
<span id="cb1-43"><a href="#cb1-43" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.fc2 <span class="op">=</span> torch.nn.Linear(<span class="dv">512</span>, <span class="dv">256</span>)</span>
<span id="cb1-44"><a href="#cb1-44" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.fc3 <span class="op">=</span> torch.nn.Linear(<span class="dv">256</span>, num_classes)</span>
<span id="cb1-45"><a href="#cb1-45" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.bn1 <span class="op">=</span> torch.nn.BatchNorm1d(<span class="dv">512</span>)</span>
<span id="cb1-46"><a href="#cb1-46" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.bn2 <span class="op">=</span> torch.nn.BatchNorm1d(<span class="dv">256</span>)</span>
<span id="cb1-47"><a href="#cb1-47" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-48"><a href="#cb1-48" aria-hidden="true" tabindex="-1"></a> <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb1-49"><a href="#cb1-49" aria-hidden="true" tabindex="-1"></a> <span class="co"># `x` is of size [B, N, 3]</span></span>
<span id="cb1-50"><a href="#cb1-50" aria-hidden="true" tabindex="-1"></a> <span class="co"># `x` is of size [B, 1024]</span></span>
<span id="cb1-51"><a href="#cb1-51" aria-hidden="true" tabindex="-1"></a> x <span class="op">=</span> <span class="va">self</span>.feat(x)</span>
<span id="cb1-52"><a href="#cb1-52" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-53"><a href="#cb1-53" aria-hidden="true" tabindex="-1"></a> <span class="co"># `x` is of size [B, `num_classes`]</span></span>
<span id="cb1-54"><a href="#cb1-54" aria-hidden="true" tabindex="-1"></a> x <span class="op">=</span> F.relu(<span class="va">self</span>.bn1(<span class="va">self</span>.fc1(x)))</span>
<span id="cb1-55"><a href="#cb1-55" aria-hidden="true" tabindex="-1"></a> x <span class="op">=</span> F.relu(<span class="va">self</span>.bn2(<span class="va">self</span>.fc2(x)))</span>
<span id="cb1-56"><a href="#cb1-56" aria-hidden="true" tabindex="-1"></a> x <span class="op">=</span> <span class="va">self</span>.fc3(x)</span>
<span id="cb1-57"><a href="#cb1-57" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-58"><a href="#cb1-58" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> x</span></code></pre></div>
<p>さて、このモデルをそのまま実装する場合、次のような問題があります。
特徴抽出部分 (図のFeature extraction)に注目します。
図中の灰色の四角に示すように、<span
class="math inline">\(N\)</span>個全ての点に対する中間結果や、ローカルな特徴量<span
class="math inline">\(\boldsymbol{\Psi}\)</span>を、どこかに保持しておく必要があります。
大容量のメモリを搭載したGPUであれば、これでも問題ありませんが、FPGA内部のオンチップメモリ
(BlockRAM)
は非常に容量が少ないので、全ての点に対する中間結果を保持しようとすると、オンチップメモリがあっという間に枯渇するでしょう。
言い換えると、搭載されているオンチップメモリの容量によって、点の個数<span
class="math inline">\(N\)</span>が制限されてしまいます。
これは避けたいものです。
オンチップメモリの代わりに、容量の大きなDRAM上に置くこともできますが、データへのアクセス時間は長くなります。
全ての層の中間結果をDRAMに置くと、データ転送のオーバーヘッドが増加して、性能に悪影響を及ぼします。
層の中間結果は、オンチップバッファに置きたいものです。</p>
<p>そこで、全ての点<span
class="math inline">\(\mathcal{P}\)</span>に対して、ローカルな特徴量<span
class="math inline">\(\boldsymbol{\Psi}\)</span>を一気に計算するのではなく、1つずつの点<span
class="math inline">\(\boldsymbol{p}\)</span>に対して順にローカルな特徴量<span
class="math inline">\(\boldsymbol{\psi}\)</span>を計算しましょう。
一気に計算するのと比べて計算効率は落ちますが、1つの点に対する中間結果やローカルな特徴量だけを保持すればよいので、オンチップバッファの消費を大きく削減できます。</p>
<p>以前は
(PyTorchなどのフレームワークを使う場合は)、特徴抽出は次のように行われていました。</p>
<ol type="1">
<li>全ての点<span
class="math inline">\(\mathcal{P}\)</span>に対して、ローカルな特徴量を<span
class="math inline">\(\boldsymbol{\Psi}\)</span>をまとめて計算する
(<span class="math inline">\((N, 64)\)</span>や<span
class="math inline">\((N, 1024)\)</span>のバッファが必要)。</li>
<li>Maxプーリング層により、ローカルな特徴量<span
class="math inline">\(\boldsymbol{\Psi}\)</span>を集約して、グローバルな特徴量<span
class="math inline">\(\boldsymbol{\phi}\)</span>を得る (<span
class="math inline">\(\boldsymbol{\phi} \gets \max(\boldsymbol{\psi}_1,
\ldots, \boldsymbol{\psi}_N)\)</span>)。</li>
<li>グローバルな特徴量<span
class="math inline">\(\boldsymbol{\phi}\)</span>をMLPに入力し、各クラスに対するロジット(<span
class="math inline">\(K\)</span>次元のベクトル)を得る。</li>
</ol>
<p>これを、次のように変更します(<strong>最適化その2:
計算順序の変更</strong>)。</p>
<ol type="1">
<li>グローバルな特徴量<span
class="math inline">\(\boldsymbol{\phi}\)</span>を、<span
class="math inline">\(\boldsymbol{0}\)</span>で初期化する。</li>
<li>各点<span class="math inline">\(\boldsymbol{p}_i \ (i = 1, \ldots,
N)\)</span>に対して、以下の処理を行う。
<ol type="1">
<li>MLPの順伝播により、ローカルな特徴量<span
class="math inline">\(\boldsymbol{\psi}_i\)</span>を得る (<span
class="math inline">\((1, 64)\)</span>や<span class="math inline">\((1,
1024)\)</span>のバッファがあればよい)。</li>
<li><span class="math inline">\(\boldsymbol{\phi}\)</span>と<span
class="math inline">\(\boldsymbol{\psi}_i\)</span>との、要素ごとの<span
class="math inline">\(\max\)</span>をとることで、<span
class="math inline">\(\boldsymbol{\phi}\)</span>を更新する (<span
class="math inline">\(\boldsymbol{\phi} \gets \max(\boldsymbol{\phi},
\boldsymbol{\psi}_i)\)</span>)。</li>
</ol></li>
<li>グローバルな特徴量<span
class="math inline">\(\boldsymbol{\phi}\)</span>をMLPに入力し、各クラスに対するロジット(<span
class="math inline">\(K\)</span>次元のベクトル)を得る。</li>
</ol>
<p>全ての点に対するローカルな特徴量<span
class="math inline">\(\boldsymbol{\Psi}\)</span>を集約するのではなく、各点<span
class="math inline">\(\boldsymbol{p}_i\)</span>に対するローカルな特徴量<span
class="math inline">\(\boldsymbol{\psi}_i\)</span>を使って、グローバルな特徴量<span
class="math inline">\(\boldsymbol{\phi}\)</span>を逐次的に更新していきます。
これは近似ではないので、全く同じ結果となります。</p>
<p>最終的に、今回FPGA上に実装するPointNetは、以下のようになります。</p>
<p><a
href="point-cloud-classification-images/pointnet-layers3.svg"><img src="point-cloud-classification-images/pointnet-layers3.svg" width="80%" /></a></p>
<h1 id="高位合成による実装">高位合成による実装</h1>
<p>今回は、高位合成 (HLS: High-Level
Synthesis)を用いて、上記に示すPointNetの専用回路
(<strong>IPコア</strong>) を記述します。
ニューラルネットの推論を実現する別の手段として、行列演算や畳み込み演算用の、巨大かつ汎用的な演算回路をFPGA上に実装し、それに繰り返しデータを与えることも考えられます。</p>
<p>高位合成は、C/C++による動作レベル (Behavior Level)
の回路記述を、Verilog HDLやSystemVerilogによるレジスタ転送レベル (RTL:
Register Transfer Level) の回路記述に変換するための技術です。 Verilog
HDLを直接記述するのに比べて、遥かに楽で、ストレスが少なく、生産性が向上します。
但し、C/C++で記述するとはいっても、通常のソフトウェア開発とは全く様相が異なります。
<code>malloc()</code>や<code>new</code>はもちろんのこと、これらに依存する<code>std::vector</code>などの便利なデータ型も使えないので、固定長の配列に置き換えてどうにかします。
ニューラルネットはサイズが固定で、一般には決まった動作をするので、FPGA上に実装しやすいです。</p>
<p>高位合成用のツールとして、Xilinx社のVitis HLS 2022.1を利用します。
また実装対象のFPGAとして、Xilinx ZCU104 Evaluation Board
(XCZU7EV-2FFVC1156)を使います。 Xilinx
ZCU104には、FPGAのほかに、クアッドコア ARM Cortex-A53 CPU
(1.2GHz)と2GBのDRAMも搭載されており、Linuxが動作します。</p>
<p>早速、PointNetのIPコアを示します
(適宜GitHubのリポジトリをご覧ください)。
高位合成ツールのバックエンドがGCC
6.2ですので、C++14やC++17の一部機能が利用できます。
但し、ツールのバグを踏むかもしれないので、あまり凝った機能は使わないようにしています。</p>
<div class="sourceCode" id="cb2"><pre
class="sourceCode c++"><code class="sourceCode cpp"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co">// Size of the PointNet classification network</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="co">// Refer to net/model.py for details</span></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="co">// Size of the feature extraction network</span></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="kw">constexpr</span> <span class="at">const</span> <span class="dt">int</span> kFeatDims0 <span class="op">=</span> <span class="dv">3</span><span class="op">;</span></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="kw">constexpr</span> <span class="at">const</span> <span class="dt">int</span> kFeatDims1 <span class="op">=</span> <span class="dv">64</span><span class="op">;</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a><span class="kw">constexpr</span> <span class="at">const</span> <span class="dt">int</span> kFeatDims2 <span class="op">=</span> <span class="dv">64</span><span class="op">;</span></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a><span class="kw">constexpr</span> <span class="at">const</span> <span class="dt">int</span> kFeatDims3 <span class="op">=</span> <span class="dv">64</span><span class="op">;</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a><span class="kw">constexpr</span> <span class="at">const</span> <span class="dt">int</span> kFeatDims4 <span class="op">=</span> <span class="dv">128</span><span class="op">;</span></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a><span class="kw">constexpr</span> <span class="at">const</span> <span class="dt">int</span> kFeatDims5 <span class="op">=</span> <span class="dv">1024</span><span class="op">;</span></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a><span class="co">// Size of the classification network</span></span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a><span class="co">// ModelNet40 has 40 object classes</span></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a><span class="kw">constexpr</span> <span class="at">const</span> <span class="dt">int</span> kClsDims0 <span class="op">=</span> kFeatDims5<span class="op">;</span></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a><span class="kw">constexpr</span> <span class="at">const</span> <span class="dt">int</span> kClsDims1 <span class="op">=</span> <span class="dv">512</span><span class="op">;</span></span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a><span class="kw">constexpr</span> <span class="at">const</span> <span class="dt">int</span> kClsDims2 <span class="op">=</span> <span class="dv">256</span><span class="op">;</span></span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a><span class="kw">constexpr</span> <span class="at">const</span> <span class="dt">int</span> kClsDims3 <span class="op">=</span> <span class="dv">40</span><span class="op">;</span></span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="co">// Top function</span></span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a><span class="dt">void</span> PointNetClsTop<span class="op">(</span><span class="at">const</span> <span class="dt">int</span> op_mode<span class="op">,</span></span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> point_cloud<span class="op">,</span></span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">int</span> num_points<span class="op">,</span></span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a> <span class="dt">float</span><span class="op">*</span> out_logits<span class="op">,</span></span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> feat_params1<span class="op">,</span></span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> feat_params2<span class="op">,</span></span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> feat_params3<span class="op">,</span></span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> feat_params4<span class="op">,</span></span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> feat_params5<span class="op">,</span></span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> cls_params1<span class="op">,</span></span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> cls_params2<span class="op">,</span></span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> cls_params3<span class="op">)</span></span>
<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a><span class="op">{</span></span>
<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE m_axi port=point_cloud offset=slave bundle=gmem0</span></span>
<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE m_axi port=out_logits offset=slave bundle=gmem0</span></span>
<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE m_axi port=feat_params1 offset=slave bundle=gmem0</span></span>
<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE m_axi port=feat_params2 offset=slave bundle=gmem0</span></span>
<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE m_axi port=feat_params3 offset=slave bundle=gmem0</span></span>
<span id="cb2-38"><a href="#cb2-38" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE m_axi port=feat_params4 offset=slave bundle=gmem0</span></span>
<span id="cb2-39"><a href="#cb2-39" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE m_axi port=feat_params5 offset=slave bundle=gmem0</span></span>
<span id="cb2-40"><a href="#cb2-40" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE m_axi port=cls_params1 offset=slave bundle=gmem0</span></span>
<span id="cb2-41"><a href="#cb2-41" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE m_axi port=cls_params2 offset=slave bundle=gmem0</span></span>
<span id="cb2-42"><a href="#cb2-42" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE m_axi port=cls_params3 offset=slave bundle=gmem0</span></span>
<span id="cb2-43"><a href="#cb2-43" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-44"><a href="#cb2-44" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE s_axilite port=op_mode bundle=control</span></span>
<span id="cb2-45"><a href="#cb2-45" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE s_axilite port=point_cloud bundle=control</span></span>
<span id="cb2-46"><a href="#cb2-46" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE s_axilite port=num_points bundle=control</span></span>
<span id="cb2-47"><a href="#cb2-47" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE s_axilite port=out_logits bundle=control</span></span>
<span id="cb2-48"><a href="#cb2-48" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE s_axilite port=feat_params1 bundle=control</span></span>
<span id="cb2-49"><a href="#cb2-49" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE s_axilite port=feat_params2 bundle=control</span></span>
<span id="cb2-50"><a href="#cb2-50" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE s_axilite port=feat_params3 bundle=control</span></span>
<span id="cb2-51"><a href="#cb2-51" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE s_axilite port=feat_params4 bundle=control</span></span>
<span id="cb2-52"><a href="#cb2-52" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE s_axilite port=feat_params5 bundle=control</span></span>
<span id="cb2-53"><a href="#cb2-53" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE s_axilite port=cls_params1 bundle=control</span></span>
<span id="cb2-54"><a href="#cb2-54" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE s_axilite port=cls_params2 bundle=control</span></span>
<span id="cb2-55"><a href="#cb2-55" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE s_axilite port=cls_params3 bundle=control</span></span>
<span id="cb2-56"><a href="#cb2-56" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INTERFACE s_axilite port=return bundle=control</span></span>
<span id="cb2-57"><a href="#cb2-57" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-58"><a href="#cb2-58" aria-hidden="true" tabindex="-1"></a> <span class="co">// Parameters for feature extraction</span></span>
<span id="cb2-59"><a href="#cb2-59" aria-hidden="true" tabindex="-1"></a> LinearParams<span class="op"><</span><span class="dt">param_t</span><span class="op">,</span> kFeatDims0<span class="op">,</span> kFeatDims1<span class="op">></span> feat_conv1<span class="op">;</span></span>
<span id="cb2-60"><a href="#cb2-60" aria-hidden="true" tabindex="-1"></a> LinearParams<span class="op"><</span><span class="dt">param_t</span><span class="op">,</span> kFeatDims1<span class="op">,</span> kFeatDims2<span class="op">></span> feat_conv2<span class="op">;</span></span>
<span id="cb2-61"><a href="#cb2-61" aria-hidden="true" tabindex="-1"></a> LinearParams<span class="op"><</span><span class="dt">param_t</span><span class="op">,</span> kFeatDims2<span class="op">,</span> kFeatDims3<span class="op">></span> feat_conv3<span class="op">;</span></span>
<span id="cb2-62"><a href="#cb2-62" aria-hidden="true" tabindex="-1"></a> LinearParams<span class="op"><</span><span class="dt">param_t</span><span class="op">,</span> kFeatDims3<span class="op">,</span> kFeatDims4<span class="op">></span> feat_conv4<span class="op">;</span></span>
<span id="cb2-63"><a href="#cb2-63" aria-hidden="true" tabindex="-1"></a> LinearParams<span class="op"><</span><span class="dt">param_t</span><span class="op">,</span> kFeatDims4<span class="op">,</span> kFeatDims5<span class="op">></span> feat_conv5<span class="op">;</span></span>
<span id="cb2-64"><a href="#cb2-64" aria-hidden="true" tabindex="-1"></a> BatchNorm1dParams<span class="op"><</span><span class="dt">param_t</span><span class="op">,</span> kFeatDims1<span class="op">></span> feat_bn1<span class="op">;</span></span>
<span id="cb2-65"><a href="#cb2-65" aria-hidden="true" tabindex="-1"></a> BatchNorm1dParams<span class="op"><</span><span class="dt">param_t</span><span class="op">,</span> kFeatDims2<span class="op">></span> feat_bn2<span class="op">;</span></span>
<span id="cb2-66"><a href="#cb2-66" aria-hidden="true" tabindex="-1"></a> BatchNorm1dParams<span class="op"><</span><span class="dt">param_t</span><span class="op">,</span> kFeatDims3<span class="op">></span> feat_bn3<span class="op">;</span></span>
<span id="cb2-67"><a href="#cb2-67" aria-hidden="true" tabindex="-1"></a> BatchNorm1dParams<span class="op"><</span><span class="dt">param_t</span><span class="op">,</span> kFeatDims4<span class="op">></span> feat_bn4<span class="op">;</span></span>
<span id="cb2-68"><a href="#cb2-68" aria-hidden="true" tabindex="-1"></a> BatchNorm1dParams<span class="op"><</span><span class="dt">param_t</span><span class="op">,</span> kFeatDims5<span class="op">></span> feat_bn5<span class="op">;</span></span>
<span id="cb2-69"><a href="#cb2-69" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-70"><a href="#cb2-70" aria-hidden="true" tabindex="-1"></a> <span class="co">// Parameters for classification network</span></span>
<span id="cb2-71"><a href="#cb2-71" aria-hidden="true" tabindex="-1"></a> <span class="co">// LinearParams<param_t, kClsDims0, kClsDims1> cls_fc1;</span></span>
<span id="cb2-72"><a href="#cb2-72" aria-hidden="true" tabindex="-1"></a> <span class="co">// LinearParams<param_t, kClsDims1, kClsDims2> cls_fc2;</span></span>
<span id="cb2-73"><a href="#cb2-73" aria-hidden="true" tabindex="-1"></a> LinearParams<span class="op"><</span><span class="dt">param_t</span><span class="op">,</span> kClsDims2<span class="op">,</span> kClsDims3<span class="op">></span> cls_fc3<span class="op">;</span></span>
<span id="cb2-74"><a href="#cb2-74" aria-hidden="true" tabindex="-1"></a> BatchNorm1dParams<span class="op"><</span><span class="dt">param_t</span><span class="op">,</span> kClsDims1<span class="op">></span> cls_bn1<span class="op">;</span></span>
<span id="cb2-75"><a href="#cb2-75" aria-hidden="true" tabindex="-1"></a> BatchNorm1dParams<span class="op"><</span><span class="dt">param_t</span><span class="op">,</span> kClsDims2<span class="op">></span> cls_bn2<span class="op">;</span></span>
<span id="cb2-76"><a href="#cb2-76" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-77"><a href="#cb2-77" aria-hidden="true" tabindex="-1"></a> <span class="co">// Extracted feature</span></span>
<span id="cb2-78"><a href="#cb2-78" aria-hidden="true" tabindex="-1"></a> <span class="dt">value_t</span> feature<span class="op">[</span>kFeatDims5<span class="op">];</span></span>
<span id="cb2-79"><a href="#cb2-79" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-80"><a href="#cb2-80" aria-hidden="true" tabindex="-1"></a> <span class="cf">if</span> <span class="op">(</span>op_mode <span class="op">==</span> kModeInitWeights<span class="op">)</span> <span class="op">{</span></span>
<span id="cb2-81"><a href="#cb2-81" aria-hidden="true" tabindex="-1"></a> <span class="co">// Initialize the PointNet feature extraction network</span></span>
<span id="cb2-82"><a href="#cb2-82" aria-hidden="true" tabindex="-1"></a> InitializeFeatNaive<span class="op"><</span><span class="dt">param_t</span><span class="op">>(</span></span>
<span id="cb2-83"><a href="#cb2-83" aria-hidden="true" tabindex="-1"></a> <span class="op">&</span>feat_conv1<span class="op">,</span> <span class="op">&</span>feat_conv2<span class="op">,</span> <span class="op">&</span>feat_conv3<span class="op">,</span> <span class="op">&</span>feat_conv4<span class="op">,</span> <span class="op">&</span>feat_conv5<span class="op">,</span></span>
<span id="cb2-84"><a href="#cb2-84" aria-hidden="true" tabindex="-1"></a> <span class="op">&</span>feat_bn1<span class="op">,</span> <span class="op">&</span>feat_bn2<span class="op">,</span> <span class="op">&</span>feat_bn3<span class="op">,</span> <span class="op">&</span>feat_bn4<span class="op">,</span> <span class="op">&</span>feat_bn5<span class="op">,</span></span>
<span id="cb2-85"><a href="#cb2-85" aria-hidden="true" tabindex="-1"></a> feat_params1<span class="op">,</span> feat_params2<span class="op">,</span> feat_params3<span class="op">,</span> feat_params4<span class="op">,</span> feat_params5<span class="op">);</span></span>
<span id="cb2-86"><a href="#cb2-86" aria-hidden="true" tabindex="-1"></a> <span class="co">// Initialize the classification network</span></span>
<span id="cb2-87"><a href="#cb2-87" aria-hidden="true" tabindex="-1"></a> InitializeClsNaive<span class="op"><</span><span class="dt">param_t</span><span class="op">>(</span></span>
<span id="cb2-88"><a href="#cb2-88" aria-hidden="true" tabindex="-1"></a> <span class="op">&</span>cls_fc3<span class="op">,</span> <span class="op">&</span>cls_bn1<span class="op">,</span> <span class="op">&</span>cls_bn2<span class="op">,</span></span>
<span id="cb2-89"><a href="#cb2-89" aria-hidden="true" tabindex="-1"></a> cls_params1<span class="op">,</span> cls_params2<span class="op">,</span> cls_params3<span class="op">);</span></span>
<span id="cb2-90"><a href="#cb2-90" aria-hidden="true" tabindex="-1"></a> <span class="op">}</span> <span class="cf">else</span> <span class="cf">if</span> <span class="op">(</span>op_mode <span class="op">==</span> kModeInference<span class="op">)</span> <span class="op">{</span></span>
<span id="cb2-91"><a href="#cb2-91" aria-hidden="true" tabindex="-1"></a> <span class="co">// Run the PointNet feature extraction</span></span>
<span id="cb2-92"><a href="#cb2-92" aria-hidden="true" tabindex="-1"></a> InferenceFeatNaive<span class="op"><</span><span class="dt">value_t</span><span class="op">,</span> <span class="dt">param_t</span><span class="op">,</span> <span class="dv">1024</span><span class="op">>(</span></span>
<span id="cb2-93"><a href="#cb2-93" aria-hidden="true" tabindex="-1"></a> point_cloud<span class="op">,</span> num_points<span class="op">,</span> feature<span class="op">,</span></span>
<span id="cb2-94"><a href="#cb2-94" aria-hidden="true" tabindex="-1"></a> <span class="op">&</span>feat_conv1<span class="op">,</span> <span class="op">&</span>feat_conv2<span class="op">,</span> <span class="op">&</span>feat_conv3<span class="op">,</span> <span class="op">&</span>feat_conv4<span class="op">,</span> <span class="op">&</span>feat_conv5<span class="op">,</span></span>
<span id="cb2-95"><a href="#cb2-95" aria-hidden="true" tabindex="-1"></a> <span class="op">&</span>feat_bn1<span class="op">,</span> <span class="op">&</span>feat_bn2<span class="op">,</span> <span class="op">&</span>feat_bn3<span class="op">,</span> <span class="op">&</span>feat_bn4<span class="op">,</span> <span class="op">&</span>feat_bn5<span class="op">);</span></span>
<span id="cb2-96"><a href="#cb2-96" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-97"><a href="#cb2-97" aria-hidden="true" tabindex="-1"></a> <span class="co">// Run the classification</span></span>
<span id="cb2-98"><a href="#cb2-98" aria-hidden="true" tabindex="-1"></a> InferenceClsNaive<span class="op"><</span><span class="dt">value_t</span><span class="op">,</span> <span class="dt">param_t</span><span class="op">>(</span></span>
<span id="cb2-99"><a href="#cb2-99" aria-hidden="true" tabindex="-1"></a> feature<span class="op">,</span> out_logits<span class="op">,</span></span>
<span id="cb2-100"><a href="#cb2-100" aria-hidden="true" tabindex="-1"></a> <span class="op">&</span>cls_fc3<span class="op">,</span> <span class="op">&</span>cls_bn1<span class="op">,</span> <span class="op">&</span>cls_bn2<span class="op">,</span></span>
<span id="cb2-101"><a href="#cb2-101" aria-hidden="true" tabindex="-1"></a> cls_params1<span class="op">,</span> cls_params2<span class="op">,</span> cls_params3<span class="op">);</span></span>
<span id="cb2-102"><a href="#cb2-102" aria-hidden="true" tabindex="-1"></a> <span class="op">}</span></span>
<span id="cb2-103"><a href="#cb2-103" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div>
<p>上記を高位合成すると、次のようなIPコアが作られます。</p>
<p><a
href="point-cloud-classification-images/pointnet-ip-core.svg"><img src="point-cloud-classification-images/pointnet-ip-core.svg" width="50%" /></a></p>
<p>このIPコアを別のIPコアと組み合わせることで
(後述)、次のようなブロックデザインができます。</p>
<p><a
href="point-cloud-classification-images/board-design.svg"><img src="point-cloud-classification-images/board-design.svg" width="100%" /></a></p>
<p>このブロックデザインに対して、論理合成および配置配線することで、回路情報を表すビットストリーム
(Bitstream) を生成します。
ビットストリームをFPGAにロードすることで、PointNetの専用回路が使えるようになります。</p>
<h2 id="入出力ポート">入出力ポート</h2>
<p><code>PointNetClsTop</code>が、IPコアを表す最上位の関数です。
トップ関数 (Top function) とよばれます。
関数の引数は、IPコアの入出力ポートとなり、別のIPコアに接続されます
(上のブロックデザインをご覧ください)。 HLSでは、関数そのものが回路
(Verilog HDLにおけるモジュール) になります。
関数の再帰呼び出しはできません。</p>
<p>特徴抽出用のネットワークには5つのMLP、またクラス分類用のネットワークには3つのMLPが含まれます。
これらのパラメータは、ソフトウェア側から操作できるように、DRAM上のバッファに置かれます。
また、点群<span
class="math inline">\(\mathcal{P}\)</span>や、モデルの出力(ロジット)も同様に、DRAMバッファに置かれます。</p>
<p><code>feat_params1</code>から<code>feat_params5</code>までと、<code>cls_params1</code>から<code>cls_params3</code>までの8つのポートは、DRAMバッファ上のパラメータを、IPコア側から読み取るために使います。
<code>point_cloud</code>は点群の読み出し、<code>out_logits</code>はロジットの書き込みのために使います。
<code>op_mode</code>は回路の動作モード、<code>num_points</code>は点の個数<span
class="math inline">\(N\)</span>を設定するための制御レジスタです。</p>
<p><code>#pragma HLS</code>から始まる行は、高位合成ツールに対して、C/C++からRTLに変換する際のヒントを与えます
(必ずしも守ってくれるとは限りません)。
パイプライン化、データフロー最適化などはC/C++では記述できませんが、このような<strong>HLSプラグマ</strong>を適切な場所に置くことで、高位合成ツールが自動的にこれらの最適化を施してくれます。</p>
<p><code>#pragma HLS INLINE off</code>とすると、その関数がインライン展開されなくなります
(必ず、1つのモジュールとして作られる)。
大きな関数であれば、自動的にインライン展開されることはありませんが、念のため付与しています。
以下のような状況では、関数<code>B</code>をインライン展開しない方がいいと思います。
同時に使われないのにも関わらず、関数<code>A</code>の内部に<code>B</code>のコピーが3つ作られて、リソースの無駄遣いとなります。
関数<code>B</code>のインライン化を抑制して、<code>B</code>を1つだけ作り、それを使い回した方がいいでしょう。</p>
<div class="sourceCode" id="cb3"><pre
class="sourceCode c++"><code class="sourceCode cpp"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="dt">void</span> B<span class="op">(</span><span class="at">const</span> <span class="dt">float</span> x_in<span class="op">[</span><span class="dv">10</span><span class="op">],</span> <span class="dt">float</span> y_out<span class="op">[</span><span class="dv">10</span><span class="op">])</span></span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="op">{</span></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INLINE</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a> <span class="co">// 何らかの処理</span></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="dt">void</span> A<span class="op">(</span><span class="at">const</span> <span class="dt">float</span> x_in<span class="op">[</span><span class="dv">10</span><span class="op">],</span> <span class="dt">float</span> y_out<span class="op">[</span><span class="dv">10</span><span class="op">])</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a><span class="op">{</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a> <span class="dt">float</span> x0<span class="op">[</span><span class="dv">10</span><span class="op">];</span></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a> <span class="dt">float</span> x1<span class="op">[</span><span class="dv">10</span><span class="op">];</span></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a> B<span class="op">(</span>x_in<span class="op">,</span> x0<span class="op">);</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a> B<span class="op">(</span>x0<span class="op">,</span> x1<span class="op">);</span></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a> B<span class="op">(</span>x1<span class="op">,</span> y_out<span class="op">);</span></span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div>
<p><code>#pragma HLS INTERFACE m_axi</code>と、<code>#pragma HLS INTERFACE s_axilite</code>の記述が目立ちますが、入出力ポート
(例えば<code>feat_params1</code>)
に対してこの2つのHLSプラグマを記述すると、IPコア側からDRAMバッファを読み書きできるようになります。
読み書きの際には、AXIとよばれるプロトコルを使用しますが、<code>#pragma HLS INTERFACE m_axi</code>によってそれを指定できます
(IPコア側がマスターになります)。</p>
<p>ソフトウェア側からは、各ポートに対して、バッファの物理アドレスを割り当てて、ポートとバッファを紐づけます。
各ポートには、物理アドレスを設定するための制御レジスタを作成する必要があり、<code>#pragma HLS INTERFACE s_axilite</code>によってそれを実現できます
(IPコア側からみるとスレーブです)。
<code>op_mode</code>、<code>num_points</code>に対してもレジスタを作成します。
<code>port=return</code>としている行は、IPコア用の制御レジスタを作成し、CPU側からIPコアの動作を開始したり、状態
(アイドル状態なのか動作中か) を読み取ったりするために必要です。
これらのレジスタは、ソフトウェア側から、メモリマップトI/OおよびAXI-Liteプロトコルによって読み書きされます。</p>
<p>各入出力ポートからは、PyTorchのモデルで定義した、各層のパラメータが読み出されます
(一次元の配列として、全てのパラメータが連結されます)。</p>
<ul>
<li><code>feat_params1</code>: <code>PointNetFeat::conv1</code> +
<code>PointNetFeat::bn1</code>のパラメータ</li>
<li><code>feat_params2</code>: <code>PointNetFeat::conv2</code> +
<code>PointNetFeat::bn2</code>のパラメータ</li>
<li><code>feat_params3</code>: <code>PointNetFeat::conv3</code> +
<code>PointNetFeat::bn3</code>のパラメータ</li>
<li><code>feat_params4</code>: <code>PointNetFeat::conv4</code> +
<code>PointNetFeat::bn4</code>のパラメータ</li>
<li><code>feat_params5</code>: <code>PointNetFeat::conv5</code> +
<code>PointNetFeat::bn5</code>のパラメータ</li>
<li><code>cls_params1</code>: <code>PointNetCls::fc1</code> +
<code>PointNetCls::bn1</code>のパラメータ</li>
<li><code>cls_params2</code>: <code>PointNetCls::fc2</code> +
<code>PointNetCls::bn2</code>のパラメータ</li>
<li><code>cls_params3</code>:
<code>PointNetCls::fc3</code>のパラメータ</li>
</ul>
<div class="sourceCode" id="cb4"><pre
class="sourceCode c++"><code class="sourceCode cpp"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="dt">void</span> PointNetClsTop<span class="op">(</span><span class="at">const</span> <span class="dt">int</span> op_mode<span class="op">,</span></span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> point_cloud<span class="op">,</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">int</span> num_points<span class="op">,</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a> <span class="dt">float</span><span class="op">*</span> out_logits<span class="op">,</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> feat_params1<span class="op">,</span></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> feat_params2<span class="op">,</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> feat_params3<span class="op">,</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> feat_params4<span class="op">,</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> feat_params5<span class="op">,</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> cls_params1<span class="op">,</span></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> cls_params2<span class="op">,</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> cls_params3<span class="op">)</span></span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a><span class="op">{</span></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a> <span class="co">// ...</span></span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div>
<h2 id="各層のパラメータと処理">各層のパラメータと処理</h2>
<p><code>torch.nn.Conv1d</code>および<code>torch.nn.Linear</code>のパラメータとしては、重みとバイアスが挙げられます。
<code>Conv1d</code>とありますが、カーネルサイズは1なので、<code>Linear</code>と動作が同じになります。
以後、<code>Conv1d</code>と<code>Linear</code>を同一視します。
入力と出力の次元数を<span
class="math inline">\(\mathrm{InDims}\)</span>、<span
class="math inline">\(\mathrm{OutDims}\)</span>とすると、重みとバイアスのサイズは<span
class="math inline">\((\mathrm{OutDims},
\mathrm{InDims})\)</span>、<span
class="math inline">\((\mathrm{OutDims})\)</span>となります。 入力<span
class="math inline">\(\boldsymbol{x} \in
\mathbb{R}^{\mathrm{InDims}}\)</span>、重み<span
class="math inline">\(\boldsymbol{W} \in \mathbb{R}^{\mathrm{OutDims}
\times \mathrm{InDims}}\)</span>、バイアス<span
class="math inline">\(\boldsymbol{b} \in
\mathbb{R}^{\mathrm{OutDims}}\)</span>があるとき、出力<span
class="math inline">\(\boldsymbol{y} \in
\mathbb{R}^{\mathrm{OutDims}}\)</span>は次のように計算されます。 <span
class="math display">\[
\boldsymbol{y} = \boldsymbol{W} \boldsymbol{x} + \boldsymbol{b}
\]</span></p>
<p><code>torch.nn.BatchNorm1d</code>のパラメータとしては、平均、標準偏差、重み、バイアスの4つが挙げられます。
入出力の次元を<span
class="math inline">\(\mathrm{Dims}\)</span>とすると、これら4つのパラメータのサイズは<span
class="math inline">\((\mathrm{Dims})\)</span>です。
平均、標準偏差、重み、バイアス<span
class="math inline">\(\boldsymbol{\mu}, \boldsymbol{\sigma},
\boldsymbol{w}, \boldsymbol{b} \in
\mathbb{R}^{\mathrm{Dims}}\)</span>があるとき、入力<span
class="math inline">\(\boldsymbol{x} \in
\mathbb{R}^{\mathrm{Dims}}\)</span>に対して出力<span
class="math inline">\(\boldsymbol{y} \in
\mathbb{R}^{\mathrm{Dims}}\)</span>は次のように計算されます。 <span
class="math display">\[
y_i = \frac{x_i - \mu_i}{\sqrt{\sigma_i^2 + \varepsilon}} \cdot w_i +
b_i \quad (i = 1, \ldots, \mathrm{Dims})
\]</span> <span
class="math inline">\(\varepsilon\)</span>は、ゼロ除算を防ぐための小さな正の値です。
<span class="math inline">\(x_i\)</span>は、<span
class="math inline">\(\boldsymbol{x}\)</span>の第<span
class="math inline">\(i\)</span>要素です (他も同様)。
上記をみると、<span class="math inline">\(w_i / \sqrt{\sigma_i^2 +
\varepsilon}\)</span>の部分を先に計算できることが分かります。 <span
class="math inline">\(\boldsymbol{w}\)</span>と<span
class="math inline">\(\boldsymbol{\sigma}\)</span>の両方を使う場合と比べて、除算および平方根の計算を省略できます。
また、オンチップバッファの使用量を削減できます。
細かい話にみえますが、リソース制約の大きなFPGA上に実装する場合は重要です。
バッチ正規化の計算は以下のようにします。 <span class="math display">\[
y_i = \left( x_i - \mu_i \right) \cdot s_i + b_i \quad (i = 1, \ldots,
\mathrm{Dims})
\]</span> 上記の<span
class="math inline">\(s_i\)</span>を、ここでは<strong>スケール</strong>と呼ぶことにします。
パラメータは、平均<span
class="math inline">\(\boldsymbol{\mu}\)</span>、バイアス<span
class="math inline">\(\boldsymbol{b}\)</span>、スケール<span
class="math inline">\(\boldsymbol{s} \in
\mathbb{R}^{\mathrm{Dims}}\)</span>の3つになります。 <span
class="math inline">\(\boldsymbol{s}\)</span>の計算は、モデルの初期化時にソフトウェア上で行うことにします。</p>
<p>バッチ正規化の後にReLU活性化が計算されます。
各層を別々に実装するよりも、まとめてしまった方が効率がよいので、バッチ正規化とReLU活性化を次のようにまとめます
(<strong>最適化その3: 計算の簡略化</strong>)。 <span
class="math display">\[
y_i = \max \left( 0, \left( x_i - \mu_i \right) \cdot s_i + b_i
\right) \quad (i = 1, \ldots, \mathrm{Dims})
\]</span></p>
<p>最後にMaxプーリング層ですが、先述の通り、各点に対するローカル特徴量<span
class="math inline">\(\boldsymbol{\psi}_i \in
\mathbb{R}^{1024}\)</span>と、現在のグローバル特徴量<span
class="math inline">\(\boldsymbol{\phi} \in
\mathbb{R}^{1024}\)</span>との、要素ごとの<span
class="math inline">\(\max\)</span>に置き換えました。
Maxプーリング層の計算は次のようになります。 <span
class="math display">\[
\phi_i = \max \left( \phi_i, \psi_i \right) \quad (i = 1, \ldots,
1024)
\]</span></p>
<p>さて、ソースコードの<code>LinearParams<T, InDims_, OutDims_></code>構造体と、<code>BatchNorm1dParams<T, Dims_></code>構造体は、全結合層
(<code>Conv1d</code>および<code>Linear</code>) と、バッチ正規化層
(<code>BatchNorm1d</code>) のパラメータをそれぞれまとめたものです。</p>
<div class="sourceCode" id="cb5"><pre
class="sourceCode c++"><code class="sourceCode cpp"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="co">// Parameters for fully-connected layers</span></span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="kw">template</span> <span class="op"><</span><span class="kw">typename</span> T<span class="op">,</span> <span class="dt">int</span> InDims_<span class="op">,</span> <span class="dt">int</span> OutDims_<span class="op">></span></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="kw">struct</span> LinearParams</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="op">{</span></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a> <span class="kw">enum</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a> <span class="op">{</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a> InDims <span class="op">=</span> InDims_<span class="op">,</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a> OutDims <span class="op">=</span> OutDims_<span class="op">,</span></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a> <span class="op">};</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a> T weight<span class="op">[</span>OutDims<span class="op">][</span>InDims<span class="op">];</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a> T bias<span class="op">[</span>OutDims<span class="op">];</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a><span class="op">};</span></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a><span class="co">// Parameters for 1D batch normalization layers</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a><span class="kw">template</span> <span class="op"><</span><span class="kw">typename</span> T<span class="op">,</span> <span class="dt">int</span> Dims_<span class="op">></span></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a><span class="kw">struct</span> BatchNorm1dParams</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a><span class="op">{</span></span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a> <span class="kw">enum</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a> <span class="op">{</span></span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a> Dims <span class="op">=</span> Dims_<span class="op">,</span></span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a> <span class="op">};</span></span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a> <span class="co">// `scale` is obtained by multiplying weights and reciprocal of the</span></span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a> <span class="co">// square root of the standard deviation (to reduce the computational cost)</span></span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a> T scale<span class="op">[</span>Dims<span class="op">];</span></span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a> T bias<span class="op">[</span>Dims<span class="op">];</span></span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a> T mean<span class="op">[</span>Dims<span class="op">];</span></span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a><span class="op">};</span></span></code></pre></div>
<p><code>PointNetClsTop</code>内では、PyTorchで定義されたモデルの各層に対応して、以下のようなパラメータが宣言されます。</p>
<ul>
<li><code>feat_conv1</code>:
<code>PointNetFeat::conv1</code>の重み、バイアス</li>
<li><code>feat_conv2</code>:
<code>PointNetFeat::conv2</code>の重み、バイアス</li>
<li><code>feat_conv3</code>:
<code>PointNetFeat::conv3</code>の重み、バイアス</li>
<li><code>feat_conv4</code>:
<code>PointNetFeat::conv4</code>の重み、バイアス</li>
<li><code>feat_conv5</code>:
<code>PointNetFeat::conv5</code>の重み、バイアス</li>
<li><code>feat_bn1</code>:
<code>PointNetFeat::bn1</code>の平均、バイアス、スケール</li>
<li><code>feat_bn2</code>:
<code>PointNetFeat::bn2</code>の平均、バイアス、スケール</li>
<li><code>feat_bn3</code>:
<code>PointNetFeat::bn3</code>の平均、バイアス、スケール</li>
<li><code>feat_bn4</code>:
<code>PointNetFeat::bn4</code>の平均、バイアス、スケール</li>
<li><code>feat_bn5</code>:
<code>PointNetFeat::bn5</code>の平均、バイアス、スケール</li>
<li><code>cls_fc3</code>:
<code>PointNetCls::fc3</code>の重み、バイアス</li>
<li><code>cls_bn1</code>:
<code>PointNetCls::bn1</code>の平均、バイアス、スケール</li>
<li><code>cls_bn2</code>:
<code>PointNetCls::bn2</code>の平均、バイアス、スケール</li>
</ul>
<p>特徴抽出ネットワークの全ての層のパラメータは、推論を開始する前に予め、オンチップメモリ上に置いておきます。
一方、分類ネットワークの全結合層2つ
(<code>PointNetCls::fc1</code>、<code>PointNetCls::fc2</code>)
のパラメータは、オンチップメモリ上には置かないようにします。
パラメータサイズが大きく、オンチップメモリが不足するためです。
これらの層については、推論時にDRAMバッファから読み出します。
言い換えると、パラメータの一部をDRAMバッファから取り出して、出力の一部を計算することを繰り返します。
一部のパラメータを保持するために、小さなオンチップバッファを用意すればよくなります。</p>
<p>特徴抽出ネットワークについては、<span
class="math inline">\(N\)</span>個全ての点に対して特徴抽出を行うために、<span
class="math inline">\(N\)</span>回の順伝播が起こります。
推論時間のなかで占める割合が大きいので、1回の順伝播に要する計算時間をうまく短縮できれば、全体の推論時間の大幅な短縮につながります
(<strong>アムダールの法則</strong>)。
一方、分類ネットワークの順伝播は1度だけで、推論時間のなかではそれほど重要ではありません。
パラメータをオンチップメモリに事前に格納するのと比べて、推論時にDRAMバッファから読み出すと、層の計算時間は伸びてしまいますが、推論時間に与える影響はそれほど大きくありません。</p>
<h2 id="データ型">データ型</h2>
<p>Vitis
HLSでは、任意精度の<strong>固定</strong>小数点数型<code>ap_fixed</code>が用意されています。
単精度浮動小数点数<code>float</code>や、半精度浮動小数点数<code>half</code>も利用できます。
ここではリソース消費を抑えるために、固定小数点数を使います。</p>
<p>デフォルトのオーバーフローモード (<code>ap_o_mode::AP_WRAP</code>)
では、値がオーバーフローしたときに折り返します。
これだと、最大値から急に最小値になったりして危なっかしいので、最大値あるいは最小値に留まり続けるように、飽和モード
(<code>ap_o_mode::AP_SAT</code>) に変更しています。
飽和モードを使う固定小数点数型を、<code>ap_fixed_sat</code>として定義しました。</p>
<p>ニューラルネットの入出力とパラメータとでビット幅を変えるために、入出力用とパラメータ用に別々の型を用意しました
(<code>param_t</code>および<code>value_t</code>)。
パラメータの値域に合わせて、ビット幅を削減できるかもしれません。
ビット幅の削減や量子化、小数点型のフォーマットなどは、それ自体が立派な研究分野となっています。</p>
<div class="sourceCode" id="cb6"><pre
class="sourceCode c++"><code class="sourceCode cpp"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="co">// Value types</span></span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a><span class="kw">template</span> <span class="op"><</span><span class="dt">int</span> _AP_W<span class="op">,</span> <span class="dt">int</span> _AP_I<span class="op">></span></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="kw">using</span> ap_fixed_sat <span class="op">=</span> ap_fixed<span class="op"><</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a> _AP_W<span class="op">,</span> _AP_I<span class="op">,</span> ap_q_mode<span class="op">::</span>AP_TRN<span class="op">,</span> ap_o_mode<span class="op">::</span>AP_SAT<span class="op">,</span> <span class="dv">0</span><span class="op">>;</span></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a><span class="co">// Data type for values (layer inputs, outputs, and intermediate results)</span></span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a><span class="kw">using</span> <span class="dt">value_t</span> <span class="op">=</span> ap_fixed_sat<span class="op"><</span>kValueBitWidth<span class="op">,</span> kValueIntWidth<span class="op">>;</span></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a><span class="co">// Data type for network parameters</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a><span class="kw">using</span> <span class="dt">param_t</span> <span class="op">=</span> ap_fixed_sat<span class="op"><</span>kParamBitWidth<span class="op">,</span> kParamIntWidth<span class="op">>;</span></span></code></pre></div>
<h2 id="動作モード">動作モード</h2>
<p>さて、ここで示すIPコアには、2つの<strong>動作モード</strong>
(Operation mode) が用意されています。</p>
<ul>
<li>重み初期化モード (<code>kModeInitWeights</code>):
重みをDRAMバッファから読み取って、オンチップバッファに格納する。</li>
<li>推論モード (<code>kModeInference</code>):
入力点群から、各クラスのロジットを計算する。</li>
</ul>
<p>これらを順に説明します。</p>
<h3 id="重み初期化モード">重み初期化モード</h3>
<p>特徴抽出ネットワークの全パラメータと、分類ネットワークのパラメータの一部を、DRAMバッファから読み取って、オンチップバッファに格納します。
以下に示す、<code>InitializeFeatNaive</code>および<code>InitializeClsNaive</code>を利用します。
それぞれ、特徴抽出ネットワークと、分類ネットワークのための関数です。</p>
<div class="sourceCode" id="cb7"><pre
class="sourceCode c++"><code class="sourceCode cpp"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="co">// Naive implementation of the parameter initialization</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="co">// `T` is the type for parameters</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="kw">template</span> <span class="op"><</span><span class="kw">typename</span> T<span class="op">></span></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="dt">void</span> InitializeFeatNaive<span class="op">(</span>LinearParams<span class="op"><</span>T<span class="op">,</span> kFeatDims0<span class="op">,</span> kFeatDims1<span class="op">>*</span> conv1<span class="op">,</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a> LinearParams<span class="op"><</span>T<span class="op">,</span> kFeatDims1<span class="op">,</span> kFeatDims2<span class="op">>*</span> conv2<span class="op">,</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a> LinearParams<span class="op"><</span>T<span class="op">,</span> kFeatDims2<span class="op">,</span> kFeatDims3<span class="op">>*</span> conv3<span class="op">,</span></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a> LinearParams<span class="op"><</span>T<span class="op">,</span> kFeatDims3<span class="op">,</span> kFeatDims4<span class="op">>*</span> conv4<span class="op">,</span></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a> LinearParams<span class="op"><</span>T<span class="op">,</span> kFeatDims4<span class="op">,</span> kFeatDims5<span class="op">>*</span> conv5<span class="op">,</span></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a> BatchNorm1dParams<span class="op"><</span>T<span class="op">,</span> kFeatDims1<span class="op">>*</span> bn1<span class="op">,</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a> BatchNorm1dParams<span class="op"><</span>T<span class="op">,</span> kFeatDims2<span class="op">>*</span> bn2<span class="op">,</span></span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a> BatchNorm1dParams<span class="op"><</span>T<span class="op">,</span> kFeatDims3<span class="op">>*</span> bn3<span class="op">,</span></span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a> BatchNorm1dParams<span class="op"><</span>T<span class="op">,</span> kFeatDims4<span class="op">>*</span> bn4<span class="op">,</span></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a> BatchNorm1dParams<span class="op"><</span>T<span class="op">,</span> kFeatDims5<span class="op">>*</span> bn5<span class="op">,</span></span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> params1<span class="op">,</span></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> params2<span class="op">,</span></span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> params3<span class="op">,</span></span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> params4<span class="op">,</span></span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> params5<span class="op">)</span></span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a><span class="op">{</span></span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INLINE off</span></span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a> ReadBlockParamsNaive<span class="op"><</span>T<span class="op">,</span> kFeatDims0<span class="op">,</span> kFeatDims1<span class="op">>(</span>conv1<span class="op">,</span> bn1<span class="op">,</span> params1<span class="op">);</span></span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a> ReadBlockParamsNaive<span class="op"><</span>T<span class="op">,</span> kFeatDims1<span class="op">,</span> kFeatDims2<span class="op">>(</span>conv2<span class="op">,</span> bn2<span class="op">,</span> params2<span class="op">);</span></span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a> ReadBlockParamsNaive<span class="op"><</span>T<span class="op">,</span> kFeatDims2<span class="op">,</span> kFeatDims3<span class="op">>(</span>conv3<span class="op">,</span> bn3<span class="op">,</span> params3<span class="op">);</span></span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a> ReadBlockParamsNaive<span class="op"><</span>T<span class="op">,</span> kFeatDims3<span class="op">,</span> kFeatDims4<span class="op">>(</span>conv4<span class="op">,</span> bn4<span class="op">,</span> params4<span class="op">);</span></span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a> ReadBlockParamsNaive<span class="op"><</span>T<span class="op">,</span> kFeatDims4<span class="op">,</span> kFeatDims5<span class="op">>(</span>conv5<span class="op">,</span> bn5<span class="op">,</span> params5<span class="op">);</span></span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a><span class="co">// Naive implementation of the parameter initialization</span></span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a><span class="co">// `T` is the type for parameters</span></span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a><span class="kw">template</span> <span class="op"><</span><span class="kw">typename</span> T<span class="op">></span></span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a><span class="dt">void</span> InitializeClsNaive<span class="op">(</span>LinearParams<span class="op"><</span>T<span class="op">,</span> kClsDims2<span class="op">,</span> kClsDims3<span class="op">>*</span> fc3<span class="op">,</span></span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a> BatchNorm1dParams<span class="op"><</span>T<span class="op">,</span> kClsDims1<span class="op">>*</span> bn1<span class="op">,</span></span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a> BatchNorm1dParams<span class="op"><</span>T<span class="op">,</span> kClsDims2<span class="op">>*</span> bn2<span class="op">,</span></span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> params1<span class="op">,</span></span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> params2<span class="op">,</span></span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> params3<span class="op">)</span></span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a><span class="op">{</span></span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INLINE off</span></span>
<span id="cb7-40"><a href="#cb7-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-41"><a href="#cb7-41" aria-hidden="true" tabindex="-1"></a> ReadBatchNorm1dParamsNaive<span class="op"><</span>T<span class="op">,</span> kClsDims1<span class="op">>(</span></span>
<span id="cb7-42"><a href="#cb7-42" aria-hidden="true" tabindex="-1"></a> bn1<span class="op">,</span> params1<span class="op">,</span> kClsDims0 <span class="op">*</span> kClsDims1 <span class="op">+</span> kClsDims1<span class="op">);</span></span>
<span id="cb7-43"><a href="#cb7-43" aria-hidden="true" tabindex="-1"></a> ReadBatchNorm1dParamsNaive<span class="op"><</span>T<span class="op">,</span> kClsDims2<span class="op">>(</span></span>
<span id="cb7-44"><a href="#cb7-44" aria-hidden="true" tabindex="-1"></a> bn2<span class="op">,</span> params2<span class="op">,</span> kClsDims1 <span class="op">*</span> kClsDims2 <span class="op">+</span> kClsDims2<span class="op">);</span></span>
<span id="cb7-45"><a href="#cb7-45" aria-hidden="true" tabindex="-1"></a> ReadLinearParamsNaive<span class="op"><</span>T<span class="op">,</span> kClsDims2<span class="op">,</span> kClsDims3<span class="op">>(</span></span>
<span id="cb7-46"><a href="#cb7-46" aria-hidden="true" tabindex="-1"></a> fc3<span class="op">,</span> params3<span class="op">,</span> <span class="dv">0</span><span class="op">);</span></span>
<span id="cb7-47"><a href="#cb7-47" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div>
<p>これらの関数のなかでは、<code>ReadBlockParamsNaive</code>、<code>ReadLinearParamsNaive</code>、そして<code>ReadBatchNorm1dParamsNaive</code>の3つの関数を呼び出しています。
各関数は次のような動作です (詳細はソースコードをご参照ください)。
DRAMバッファ上には<code>float</code>型で置かれていますが、これを固定小数点数型に直す処理も含まれます。</p>
<ul>
<li><code>ReadLinearParamsNaive<T, InDims, OutDims></code>:
DRAMバッファから、全結合層
(<code>Conv1d</code>および<code>Linear</code>)
の重みとバイアスを読み取る。
重みのサイズは<code>(OutDims, InDims)</code>、バイアスのサイズは<code>(OutDims)</code>である。
2つのパラメータは、1次元の配列として連結されているとする
(配列のサイズは<code>OutDims * InDims + OutDims</code>)。</li>
<li><code>ReadBatchNorm1dParamsNaive<T, Dims></code>:
DRAMバッファから、バッチ正規化層 (<code>BatchNorm1d</code>)
のスケール、バイアス、平均を読み取る。
パラメータのサイズは<code>(Dims)</code>である。
3つのパラメータは、1次元の配列として連結されているとする
(配列のサイズは<code>3 * Dims</code>)。</li>
<li><code>ReadBlockParamsNaive<T, InDims, OutDims</code>:
DRAMバッファから、全結合層およびバッチ正規化層のパラメータ5つを読み取る。
5つのパラメータは、1次元の配列として連結されているとする
(配列のサイズは<code>OutDims * InDims + 4 * OutDims</code>)。</li>
</ul>
<h3 id="推論モード">推論モード</h3>
<p>入力点群から、各クラスのロジットを計算します。
以下に示す、<code>InferenceFeatNaive</code>および<code>InferenceClsNaive</code>を利用します。
それぞれ、特徴抽出ネットワークと、分類ネットワークの処理です。</p>
<div class="sourceCode" id="cb8"><pre
class="sourceCode c++"><code class="sourceCode cpp"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co">// Naive implementation of the PointNet feature extraction</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="co">// `T` is the type for layer input, output, and intermediate results</span></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="co">// `U` is the type for parameters</span></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="co">// `N` is the expected number of input points (e.g., 1024)</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="kw">template</span> <span class="op"><</span><span class="kw">typename</span> T<span class="op">,</span> <span class="kw">typename</span> U<span class="op">,</span> <span class="dt">int</span> N<span class="op">></span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="dt">void</span> InferenceFeatNaive<span class="op">(</span><span class="at">const</span> <span class="dt">float</span><span class="op">*</span> point_cloud<span class="op">,</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">int</span> num_points<span class="op">,</span></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a> T feature<span class="op">[</span>kFeatDims5<span class="op">],</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> LinearParams<span class="op"><</span>U<span class="op">,</span> kFeatDims0<span class="op">,</span> kFeatDims1<span class="op">>*</span> conv1<span class="op">,</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> LinearParams<span class="op"><</span>U<span class="op">,</span> kFeatDims1<span class="op">,</span> kFeatDims2<span class="op">>*</span> conv2<span class="op">,</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> LinearParams<span class="op"><</span>U<span class="op">,</span> kFeatDims2<span class="op">,</span> kFeatDims3<span class="op">>*</span> conv3<span class="op">,</span></span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> LinearParams<span class="op"><</span>U<span class="op">,</span> kFeatDims3<span class="op">,</span> kFeatDims4<span class="op">>*</span> conv4<span class="op">,</span></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> LinearParams<span class="op"><</span>U<span class="op">,</span> kFeatDims4<span class="op">,</span> kFeatDims5<span class="op">>*</span> conv5<span class="op">,</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> BatchNorm1dParams<span class="op"><</span>U<span class="op">,</span> kFeatDims1<span class="op">>*</span> bn1<span class="op">,</span></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> BatchNorm1dParams<span class="op"><</span>U<span class="op">,</span> kFeatDims2<span class="op">>*</span> bn2<span class="op">,</span></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> BatchNorm1dParams<span class="op"><</span>U<span class="op">,</span> kFeatDims3<span class="op">>*</span> bn3<span class="op">,</span></span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> BatchNorm1dParams<span class="op"><</span>U<span class="op">,</span> kFeatDims4<span class="op">>*</span> bn4<span class="op">,</span></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> BatchNorm1dParams<span class="op"><</span>U<span class="op">,</span> kFeatDims5<span class="op">>*</span> bn5<span class="op">)</span></span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a><span class="op">{</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INLINE off</span></span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a> <span class="co">// Zero-initialize the output feature</span></span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a> VectorNdSetZero<span class="op"><</span>T<span class="op">,</span> kFeatDims5<span class="op">>(</span>feature<span class="op">);</span></span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a> <span class="co">// Compute the feature</span></span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> <span class="op">(</span><span class="dt">int</span> i <span class="op">=</span> <span class="dv">0</span><span class="op">;</span> i <span class="op"><</span> num_points<span class="op">;</span> <span class="op">++</span>i<span class="op">)</span> <span class="op">{</span></span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS LOOP_TRIPCOUNT min=N max=N avg=N</span></span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS LOOP_FLATTEN off</span></span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a> <span class="co">// Input, output, and intermediate results</span></span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a> T x0<span class="op">[</span>kFeatDims0<span class="op">];</span></span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a> T x1<span class="op">[</span>kFeatDims1<span class="op">];</span></span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a> T x2<span class="op">[</span>kFeatDims1<span class="op">];</span></span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a> T x3<span class="op">[</span>kFeatDims2<span class="op">];</span></span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a> T x4<span class="op">[</span>kFeatDims2<span class="op">];</span></span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a> T x5<span class="op">[</span>kFeatDims3<span class="op">];</span></span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a> T x6<span class="op">[</span>kFeatDims3<span class="op">];</span></span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a> T x7<span class="op">[</span>kFeatDims4<span class="op">];</span></span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a> T x8<span class="op">[</span>kFeatDims4<span class="op">];</span></span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a> T x9<span class="op">[</span>kFeatDims5<span class="op">];</span></span>
<span id="cb8-41"><a href="#cb8-41" aria-hidden="true" tabindex="-1"></a> T x10<span class="op">[</span>kFeatDims5<span class="op">];</span></span>
<span id="cb8-42"><a href="#cb8-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-43"><a href="#cb8-43" aria-hidden="true" tabindex="-1"></a> <span class="co">// Read a point from a DDR memory</span></span>
<span id="cb8-44"><a href="#cb8-44" aria-hidden="true" tabindex="-1"></a> ReadPointNaive<span class="op"><</span>T<span class="op">>(</span>point_cloud<span class="op">,</span> i<span class="op">,</span> x0<span class="op">);</span></span>
<span id="cb8-45"><a href="#cb8-45" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-46"><a href="#cb8-46" aria-hidden="true" tabindex="-1"></a> <span class="co">// Compute a point feature</span></span>
<span id="cb8-47"><a href="#cb8-47" aria-hidden="true" tabindex="-1"></a> LinearNaive<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kFeatDims0<span class="op">,</span> kFeatDims1<span class="op">,</span> <span class="kw">false</span><span class="op">>(</span></span>
<span id="cb8-48"><a href="#cb8-48" aria-hidden="true" tabindex="-1"></a> x0<span class="op">,</span> x1<span class="op">,</span> conv1<span class="op">-></span>weight<span class="op">,</span> conv1<span class="op">-></span>bias<span class="op">);</span></span>
<span id="cb8-49"><a href="#cb8-49" aria-hidden="true" tabindex="-1"></a> BatchNorm1dReLUNaive<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kFeatDims1<span class="op">>(</span></span>
<span id="cb8-50"><a href="#cb8-50" aria-hidden="true" tabindex="-1"></a> x1<span class="op">,</span> x2<span class="op">,</span> bn1<span class="op">-></span>scale<span class="op">,</span> bn1<span class="op">-></span>bias<span class="op">,</span> bn1<span class="op">-></span>mean<span class="op">);</span></span>
<span id="cb8-51"><a href="#cb8-51" aria-hidden="true" tabindex="-1"></a> LinearNaive<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kFeatDims1<span class="op">,</span> kFeatDims2<span class="op">,</span> <span class="kw">false</span><span class="op">>(</span></span>
<span id="cb8-52"><a href="#cb8-52" aria-hidden="true" tabindex="-1"></a> x2<span class="op">,</span> x3<span class="op">,</span> conv2<span class="op">-></span>weight<span class="op">,</span> conv2<span class="op">-></span>bias<span class="op">);</span></span>
<span id="cb8-53"><a href="#cb8-53" aria-hidden="true" tabindex="-1"></a> BatchNorm1dReLUNaive<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kFeatDims2<span class="op">>(</span></span>
<span id="cb8-54"><a href="#cb8-54" aria-hidden="true" tabindex="-1"></a> x3<span class="op">,</span> x4<span class="op">,</span> bn2<span class="op">-></span>scale<span class="op">,</span> bn2<span class="op">-></span>bias<span class="op">,</span> bn2<span class="op">-></span>mean<span class="op">);</span></span>
<span id="cb8-55"><a href="#cb8-55" aria-hidden="true" tabindex="-1"></a> LinearNaive<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kFeatDims2<span class="op">,</span> kFeatDims3<span class="op">,</span> <span class="kw">false</span><span class="op">>(</span></span>
<span id="cb8-56"><a href="#cb8-56" aria-hidden="true" tabindex="-1"></a> x4<span class="op">,</span> x5<span class="op">,</span> conv3<span class="op">-></span>weight<span class="op">,</span> conv3<span class="op">-></span>bias<span class="op">);</span></span>
<span id="cb8-57"><a href="#cb8-57" aria-hidden="true" tabindex="-1"></a> BatchNorm1dReLUNaive<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kFeatDims3<span class="op">>(</span></span>
<span id="cb8-58"><a href="#cb8-58" aria-hidden="true" tabindex="-1"></a> x5<span class="op">,</span> x6<span class="op">,</span> bn3<span class="op">-></span>scale<span class="op">,</span> bn3<span class="op">-></span>bias<span class="op">,</span> bn3<span class="op">-></span>mean<span class="op">);</span></span>
<span id="cb8-59"><a href="#cb8-59" aria-hidden="true" tabindex="-1"></a> LinearNaive<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kFeatDims3<span class="op">,</span> kFeatDims4<span class="op">,</span> <span class="kw">false</span><span class="op">>(</span></span>
<span id="cb8-60"><a href="#cb8-60" aria-hidden="true" tabindex="-1"></a> x6<span class="op">,</span> x7<span class="op">,</span> conv4<span class="op">-></span>weight<span class="op">,</span> conv4<span class="op">-></span>bias<span class="op">);</span></span>
<span id="cb8-61"><a href="#cb8-61" aria-hidden="true" tabindex="-1"></a> BatchNorm1dReLUNaive<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kFeatDims4<span class="op">>(</span></span>
<span id="cb8-62"><a href="#cb8-62" aria-hidden="true" tabindex="-1"></a> x7<span class="op">,</span> x8<span class="op">,</span> bn4<span class="op">-></span>scale<span class="op">,</span> bn4<span class="op">-></span>bias<span class="op">,</span> bn4<span class="op">-></span>mean<span class="op">);</span></span>
<span id="cb8-63"><a href="#cb8-63" aria-hidden="true" tabindex="-1"></a> LinearNaive<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kFeatDims4<span class="op">,</span> kFeatDims5<span class="op">,</span> <span class="kw">false</span><span class="op">>(</span></span>
<span id="cb8-64"><a href="#cb8-64" aria-hidden="true" tabindex="-1"></a> x8<span class="op">,</span> x9<span class="op">,</span> conv5<span class="op">-></span>weight<span class="op">,</span> conv5<span class="op">-></span>bias<span class="op">);</span></span>
<span id="cb8-65"><a href="#cb8-65" aria-hidden="true" tabindex="-1"></a> BatchNorm1dReLUNaive<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kFeatDims5<span class="op">>(</span></span>
<span id="cb8-66"><a href="#cb8-66" aria-hidden="true" tabindex="-1"></a> x9<span class="op">,</span> x10<span class="op">,</span> bn5<span class="op">-></span>scale<span class="op">,</span> bn5<span class="op">-></span>bias<span class="op">,</span> bn5<span class="op">-></span>mean<span class="op">);</span></span>
<span id="cb8-67"><a href="#cb8-67" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-68"><a href="#cb8-68" aria-hidden="true" tabindex="-1"></a> <span class="co">// Update the output feature</span></span>
<span id="cb8-69"><a href="#cb8-69" aria-hidden="true" tabindex="-1"></a> MaxPool1dNaive<span class="op"><</span>T<span class="op">,</span> kFeatDims5<span class="op">>(</span>x10<span class="op">,</span> feature<span class="op">);</span></span>
<span id="cb8-70"><a href="#cb8-70" aria-hidden="true" tabindex="-1"></a> <span class="op">}</span></span>
<span id="cb8-71"><a href="#cb8-71" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb8-72"><a href="#cb8-72" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-73"><a href="#cb8-73" aria-hidden="true" tabindex="-1"></a><span class="co">// Naive implementation of the classification network</span></span>
<span id="cb8-74"><a href="#cb8-74" aria-hidden="true" tabindex="-1"></a><span class="co">// `T` is the type for layer input, output, and intermediate results</span></span>
<span id="cb8-75"><a href="#cb8-75" aria-hidden="true" tabindex="-1"></a><span class="co">// `U` is the type for parameters</span></span>
<span id="cb8-76"><a href="#cb8-76" aria-hidden="true" tabindex="-1"></a><span class="kw">template</span> <span class="op"><</span><span class="kw">typename</span> T<span class="op">,</span> <span class="kw">typename</span> U<span class="op">></span></span>
<span id="cb8-77"><a href="#cb8-77" aria-hidden="true" tabindex="-1"></a><span class="dt">void</span> InferenceClsNaive<span class="op">(</span><span class="at">const</span> T feature<span class="op">[</span>kFeatDims5<span class="op">],</span></span>
<span id="cb8-78"><a href="#cb8-78" aria-hidden="true" tabindex="-1"></a> <span class="dt">float</span><span class="op">*</span> out_logits<span class="op">,</span></span>
<span id="cb8-79"><a href="#cb8-79" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> LinearParams<span class="op"><</span>U<span class="op">,</span> kClsDims2<span class="op">,</span> kClsDims3<span class="op">>*</span> fc3<span class="op">,</span></span>
<span id="cb8-80"><a href="#cb8-80" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> BatchNorm1dParams<span class="op"><</span>U<span class="op">,</span> kClsDims1<span class="op">>*</span> bn1<span class="op">,</span></span>
<span id="cb8-81"><a href="#cb8-81" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> BatchNorm1dParams<span class="op"><</span>U<span class="op">,</span> kClsDims2<span class="op">>*</span> bn2<span class="op">,</span></span>
<span id="cb8-82"><a href="#cb8-82" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> params1<span class="op">,</span></span>
<span id="cb8-83"><a href="#cb8-83" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> params2<span class="op">,</span></span>
<span id="cb8-84"><a href="#cb8-84" aria-hidden="true" tabindex="-1"></a> <span class="at">const</span> <span class="dt">float</span><span class="op">*</span> params3<span class="op">)</span></span>
<span id="cb8-85"><a href="#cb8-85" aria-hidden="true" tabindex="-1"></a><span class="op">{</span></span>
<span id="cb8-86"><a href="#cb8-86" aria-hidden="true" tabindex="-1"></a><span class="pp">#pragma HLS INLINE off</span></span>
<span id="cb8-87"><a href="#cb8-87" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-88"><a href="#cb8-88" aria-hidden="true" tabindex="-1"></a> <span class="kw">static_assert</span><span class="op">(</span>kFeatDims5 <span class="op">==</span> kClsDims0<span class="op">,</span></span>
<span id="cb8-89"><a href="#cb8-89" aria-hidden="true" tabindex="-1"></a> <span class="st">"Feature dimension should be equal to the input dimension"</span><span class="op">);</span></span>
<span id="cb8-90"><a href="#cb8-90" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-91"><a href="#cb8-91" aria-hidden="true" tabindex="-1"></a> <span class="co">// Input, output, and intermediate results</span></span>
<span id="cb8-92"><a href="#cb8-92" aria-hidden="true" tabindex="-1"></a> T x0<span class="op">[</span>kClsDims1<span class="op">];</span></span>
<span id="cb8-93"><a href="#cb8-93" aria-hidden="true" tabindex="-1"></a> T x1<span class="op">[</span>kClsDims1<span class="op">];</span></span>
<span id="cb8-94"><a href="#cb8-94" aria-hidden="true" tabindex="-1"></a> T x2<span class="op">[</span>kClsDims2<span class="op">];</span></span>
<span id="cb8-95"><a href="#cb8-95" aria-hidden="true" tabindex="-1"></a> T x3<span class="op">[</span>kClsDims2<span class="op">];</span></span>
<span id="cb8-96"><a href="#cb8-96" aria-hidden="true" tabindex="-1"></a> T x4<span class="op">[</span>kClsDims3<span class="op">];</span></span>
<span id="cb8-97"><a href="#cb8-97" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-98"><a href="#cb8-98" aria-hidden="true" tabindex="-1"></a> <span class="co">// Compute logits</span></span>
<span id="cb8-99"><a href="#cb8-99" aria-hidden="true" tabindex="-1"></a> LinearNaiveDDR<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kClsDims0<span class="op">,</span> kClsDims1<span class="op">,</span> <span class="kw">false</span><span class="op">>(</span></span>
<span id="cb8-100"><a href="#cb8-100" aria-hidden="true" tabindex="-1"></a> feature<span class="op">,</span> x0<span class="op">,</span> params1<span class="op">,</span> <span class="dv">0</span><span class="op">);</span></span>
<span id="cb8-101"><a href="#cb8-101" aria-hidden="true" tabindex="-1"></a> BatchNorm1dReLUNaive<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kClsDims1<span class="op">>(</span></span>
<span id="cb8-102"><a href="#cb8-102" aria-hidden="true" tabindex="-1"></a> x0<span class="op">,</span> x1<span class="op">,</span> bn1<span class="op">-></span>scale<span class="op">,</span> bn1<span class="op">-></span>bias<span class="op">,</span> bn1<span class="op">-></span>mean<span class="op">);</span></span>
<span id="cb8-103"><a href="#cb8-103" aria-hidden="true" tabindex="-1"></a> LinearNaiveDDR<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kClsDims1<span class="op">,</span> kClsDims2<span class="op">,</span> <span class="kw">false</span><span class="op">>(</span></span>
<span id="cb8-104"><a href="#cb8-104" aria-hidden="true" tabindex="-1"></a> x1<span class="op">,</span> x2<span class="op">,</span> params2<span class="op">,</span> <span class="dv">0</span><span class="op">);</span></span>
<span id="cb8-105"><a href="#cb8-105" aria-hidden="true" tabindex="-1"></a> BatchNorm1dReLUNaive<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kClsDims2<span class="op">>(</span></span>
<span id="cb8-106"><a href="#cb8-106" aria-hidden="true" tabindex="-1"></a> x2<span class="op">,</span> x3<span class="op">,</span> bn2<span class="op">-></span>scale<span class="op">,</span> bn2<span class="op">-></span>bias<span class="op">,</span> bn2<span class="op">-></span>mean<span class="op">);</span></span>
<span id="cb8-107"><a href="#cb8-107" aria-hidden="true" tabindex="-1"></a> LinearNaive<span class="op"><</span>T<span class="op">,</span> U<span class="op">,</span> kClsDims2<span class="op">,</span> kClsDims3<span class="op">,</span> <span class="kw">false</span><span class="op">>(</span></span>
<span id="cb8-108"><a href="#cb8-108" aria-hidden="true" tabindex="-1"></a> x3<span class="op">,</span> x4<span class="op">,</span> fc3<span class="op">-></span>weight<span class="op">,</span> fc3<span class="op">-></span>bias<span class="op">);</span></span>
<span id="cb8-109"><a href="#cb8-109" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-110"><a href="#cb8-110" aria-hidden="true" tabindex="-1"></a> <span class="co">// Write the result</span></span>
<span id="cb8-111"><a href="#cb8-111" aria-hidden="true" tabindex="-1"></a> WriteTensor1dNaive<span class="op"><</span>T<span class="op">,</span> kClsDims3<span class="op">>(</span>out_logits<span class="op">,</span> x4<span class="op">,</span> <span class="dv">0</span><span class="op">);</span></span>
<span id="cb8-112"><a href="#cb8-112" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div>
<p><code>InferenceFeatNaive</code>では、DRAMに置かれた点群データ
(<code>point_cloud</code>) から、1つずつ点を読み取ります。 各点
(<code>x0</code>) に対してローカルな特徴量 (<code>x10</code>)
を計算し、現在のグローバル特徴量 (<code>feature</code>)
を更新する処理を、点の個数 (<code>num_points</code>) だけ繰り返します。
<code>InferenceClsNaive</code>は、点群全体を表すグローバル特徴量
(<code>feature</code>) を受け取って、各クラスに対するロジット
(<code>x4</code>) を計算し、それをDRAMバッファ (<code>out_logits</code>)
に書き戻します。</p>
<p><code>ReadPointNaive</code>は、<span
class="math inline">\(i\)</span>番目の点<span
class="math inline">\(\boldsymbol{p}_i\)</span>を、DRAMバッファから読み取るものです。
<code>LinearNaive</code>、<code>BatchNorm1dReLUNaive</code>、<code>MaxPool1dNaive</code>は、名前の通り、全結合層
(<code>Conv1d</code>)、バッチ正規化層とReLU活性化、Maxプーリング層に対応します