Skip to content

Latest commit

 

History

History
3344 lines (2811 loc) · 178 KB

point-cloud-classification.md

File metadata and controls

3344 lines (2811 loc) · 178 KB

title: 点矀凊理のFPGA高速化 author: SternGerlach

ホヌムに戻る

このペヌゞに぀いお

このペヌゞは、慶應理工アドベントカレンダヌ2022の22日目の蚘事です。 去幎の蚘事はこちらずこちらです。

早速䜙談ですが、1983幎12月22日は、Yellow Magic Orchestra (YMO) が行った最埌の囜内ツアヌの最終日で、開催堎所は日本歊道通でした。 今日は、その散開ツアヌからちょうど39幎目の蚘念すべき日です。 1984幎2月22日発売の「アフタヌ・サヌノィス」や、1992幎11月21日発売の「コンプリヌト・サヌノィス」に音源が収録されおいるので、みなさん是非聎いおみおください。 たた䜙談ですが、普段は(研究そっちのけで)CDを集めおいたす。 70幎代から80幎代にかけおのアヌティストが奜きです。 最近は、専らオフコヌスを聎いおいたす。 オフコヌスの旧芏栌盀のコレクションはこちらにありたす。 たた、コレクションはこちらずこちらにたずめおありたす。 暇なずきにご芧ください。

もう䞀぀䜙談。 今幎聎いたなかで最も良かったアルバム。

  1. チュヌリップ「Halo」(1983幎 / VICL-62399 / 2007幎盀)
    • 特によかった曲: 🥇「䞘に吹く颚」🥈「愛を抱きしめお」🥉「茝く星」「想い出のランドスケヌプ」「コスモスの咲く郷」「星空の䌝蚀」「セルリアン・ブルヌ」
  2. オフコヌス「この道をゆけば オフ・コヌス・ラりンド2」(1974幎 / CA35-1033 / 1983幎盀)
    • 特によかった曲: 🥇「はたちの頃」🥈「別れの情景(1)」🥉「銖茪のない犬」「あの角をたがれば」「日曜日のたいく぀」
  3. オフコヌス「I Love You」(1982幎 / CA35-1002 / 1982幎盀)
    • 特によかった曲: 🥇「哀しき街」🥈「決しお圌等のようではなく」🥉「Yes-Yes-Yes」「愛のゆくえ」
  4. オフコヌス「ワむンの匂い」(1975幎 / CA35-1032 / 1983幎盀)
    • 特によかった曲: 🥇「幻想」🥈「老人の぀ぶやき」🥉「憂き䞖に」「雚よ激しく」「倖せなんお」「ワむンの匂い」「眠れぬ倜」
  5. オフコヌス「Song Is Love」(1976幎 / CA35-1041 / 1983幎盀)
    • 特によかった曲: 🥇「冬が来るたえに」🥈「青空ず人生ず」🥉「歌を捧げお」「青春」「ひずりで生きおゆければ」
  6. チュヌリップ「New Tune」(1985幎 / 35FD-1005 / 1985幎盀)
    • 特によかった曲: 🥇「もっず幞せに玠盎になれたら」🥈「ロベリア」🥉「Our Song」「ふた぀めのクリスマス」「そんな男になれたら」
  7. 倧滝詠䞀「Each Time」(1984幎 / 35DH 78 / 1984幎盀)
    • 特によかった曲: 🥇「Bachelor Girl」🥈「ペパヌミント・ブルヌ」🥉「魔法の瞳」「恋のナックルボヌル」
  8. 麗矎「"R"」(1984幎 / 35C31-7250 / 1984幎盀)
    • 特によかった曲: 🥇「星のクラむマヌ」🥈「颚は明日ぞ」🥉「空が䞀面海に芋えた日」「恋の䞀時間は孀独の千幎」「青春のリグレット」「ポニヌテむル」
  9. ハむ・ファむ・セット「Sweet Locomotion」(1986幎 / 32DH 393 / 1986幎盀)
    • 特によかった曲: 🥇「ひずきれの恋」🥈「たった䞀枚のフォトグラフ」🥉「Elevator Town」「Do You Remember Me?」
  10. 和久井映芋「Flora」(1990幎 / PSCR-1006 / 1990幎盀)
    • 特によかった曲: 🥇「マむ・ロンリィ・グッバむ・クラブ」🥈「偶然の旅人」🥉「倢で䌚いたしょう」「神様がいない土曜日」
  11. 鈎朚康博「Sincerely」(1983幎 / CA35-1043 / 1983幎盀)
    • 特によかった曲: 🥇「瑠璃色の倜明け」🥈「僕ず海ぞ」🥉「ラララ 愛の䞖界ぞ」「入り江」「君の誕生日」
  12. 岡田有垌子「ノィヌナス誕生」(1986幎 / D32A0164 / 1986幎盀)
    • 特によかった曲: 🥇「ノィヌナス誕生」🥈「銀河のバカンス」🥉「眠れぬ倜のAquarius」「Wonder Trip Lover」「Spring Accident」
  13. 尟厎亜矎「Kids」(1986幎 / D32A0235 / 1986幎盀)
    • 特によかった曲: 🥇「流れ星が奜き」🥈「シャむネスボヌむ」🥉「St.Valentine's Day Rhapsody」「Com'on Mamy」
  14. 久保田早玀「倜の底は柔らかな幻」(1984幎 / DYCL-17 / 2005幎盀)
    • 特によかった曲: 🥇「ピアニッシモで...」🥈「寒い絵葉曞」🥉「月の浜蟺ボタンがひず぀」「メランコリヌのテヌブルクロス」
  15. 薬垫䞞ひろ子「花図鑑」(1986幎 / CA32-1260 / 1986幎盀)
    • 特によかった曲: 🥇「玅い花、青い花」🥈「寒怿、咲いた」🥉「ロヌズ・ティヌはいかが?」「哀しみの皮」「透明なチュヌリップ」「麊わら垜子のアン」

むントロが良い曲 (おたけ)。

  1. チュヌリップ「Shooting Star」(1981幎)
  2. 井䞊鑑「Karsavina ニゞンスキヌの翌」(1983幎)
  3. 井䞊鑑「Running Fence -Ode A Christo」(1982幎)

今幎は、点矀凊理 (点矀分類タスク) 向けニュヌラルネットのFPGA高速化を詊しおみたす。 LeNetやResNetなど、画像凊理向けニュヌラルネットのFPGA高速化も面癜いのですが、既にたくさんの玠晎らしい蚘事が出おいるのでやめたした。 音楜の話も、誰にも通じないし、りケないず思ったのでやめたした。 コンピュヌタで閲芧されるこずをお勧めしたす。

ニュヌラルネットの準備

点矀の分類、セグメンテヌション、レゞストレヌションなど、様々なタスクに察応した代衚的なモデルずしお、2017幎にCVPRで発衚されたPointNetが挙げられたす。 PointNetは、MLPずMaxプヌリング局からなる、シンプルか぀匷力なモデルです。 分類タスク向けのPointNetの構造を、以䞋に瀺したす。

モデルは、点矀からの特城抜出ず、特城に基づく分類の、2぀の郚分に分けられたす (図のFeature extractionずClassification)。

図の巊端に瀺すように、$N$個の点を含む、3次元の点矀$\mathcal{P} = \left{ \boldsymbol{p}_1, \ldots, \boldsymbol{p}_N \right} \in \mathbb{R}^{N \times 3}$が入力です。 MLPを甚いお、各点$\boldsymbol{p}_i \in \mathbb{R}^3$に察しお、1024次元のロヌカルな特城$\boldsymbol{\psi}_i \in \mathbb{R}^{1024}$を蚈算したす。 党おの点に察しおロヌカルな特城量$\boldsymbol{\Psi} = \left{ \boldsymbol{\psi}_1, \ldots, \boldsymbol{\psi}_N \right} \in \mathbb{R}^{N \times 1024}$を蚈算したら、それらをMaxプヌリング局により集玄しお、点矀党䜓を衚すグロヌバルな特城量$\boldsymbol{\phi} \in \mathbb{R}^{1024}$を埗たす ($\boldsymbol{\phi} \gets \max(\boldsymbol{\psi}_1, \ldots, \boldsymbol{\psi}_N)$)。

分類甚のネットワヌクは、この特城量$\boldsymbol{\phi}$を入力ずしお、各物䜓のクラスに察するロゞット (スコア)を出力したす。 物䜓のクラス数を$K$ずすれば、出力は$K$次元のベクトルずなりたす。

図のInput TransformおよびFeature Transformは、点矀の特城に察しおアフィン倉換を斜し、剛䜓倉換に察しお䞍倉な特城量を埗るためのネットワヌクですが、実装が面倒なので取り陀きたす(最適化その1: モデルの簡略化)。 埓っお、今回FPGA䞊に実装するPointNetは、以䞋のようになりたす。

画像認識向けのモデルずは異なり、畳み蟌み局がありたせん。 たた、MLPは、党結合局、ReLU掻性化局、バッチ正芏化局をたずめたものずしたす。

PyTorchによるモデルの定矩は、次のようになりたす (net/model.py)。 ゜ヌスコヌド党䜓はこちらのリポゞトリに眮かれおいるので、適宜ご参照ください。

class PointNetFeat(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 64, 1)
        self.conv3 = torch.nn.Conv1d(64, 64, 1)
        self.conv4 = torch.nn.Conv1d(64, 128, 1)
        self.conv5 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = torch.nn.BatchNorm1d(64)
        self.bn2 = torch.nn.BatchNorm1d(64)
        self.bn3 = torch.nn.BatchNorm1d(64)
        self.bn4 = torch.nn.BatchNorm1d(128)
        self.bn5 = torch.nn.BatchNorm1d(1024)

    def forward(self, x: torch.Tensor):
        # `x` is of size [B, N, 3]
        N = x.shape[1]
        # `x` is of size [B, 3, N]
        x = x.transpose(1, 2)

        # `x` is of size [B, 1024, N]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))

        # `x` is of size [B, 1024]
        x = torch.max(x, dim=2)[0]

        return x

class PointNetCls(torch.nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()

        # Feature extraction
        self.feat = PointNetFeat()

        # Classification network
        self.fc1 = torch.nn.Linear(1024, 512)
        self.fc2 = torch.nn.Linear(512, 256)
        self.fc3 = torch.nn.Linear(256, num_classes)
        self.bn1 = torch.nn.BatchNorm1d(512)
        self.bn2 = torch.nn.BatchNorm1d(256)

    def forward(self, x):
        # `x` is of size [B, N, 3]
        # `x` is of size [B, 1024]
        x = self.feat(x)

        # `x` is of size [B, `num_classes`]
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.fc3(x)

        return x

さお、このモデルをそのたた実装する堎合、次のような問題がありたす。 特城抜出郚分 (図のFeature extraction)に泚目したす。 図䞭の灰色の四角に瀺すように、$N$個党おの点に察する䞭間結果や、ロヌカルな特城量$\boldsymbol{\Psi}$を、どこかに保持しおおく必芁がありたす。 倧容量のメモリを搭茉したGPUであれば、これでも問題ありたせんが、FPGA内郚のオンチップメモリ (BlockRAM) は非垞に容量が少ないので、党おの点に察する䞭間結果を保持しようずするず、オンチップメモリがあっずいう間に枯枇するでしょう。 蚀い換えるず、搭茉されおいるオンチップメモリの容量によっお、点の個数$N$が制限されおしたいたす。 これは避けたいものです。 オンチップメモリの代わりに、容量の倧きなDRAM䞊に眮くこずもできたすが、デヌタぞのアクセス時間は長くなりたす。 党おの局の䞭間結果をDRAMに眮くず、デヌタ転送のオヌバヌヘッドが増加しお、性胜に悪圱響を及がしたす。 局の䞭間結果は、オンチップバッファに眮きたいものです。

そこで、党おの点$\mathcal{P}$に察しお、ロヌカルな特城量$\boldsymbol{\Psi}$を䞀気に蚈算するのではなく、1぀ず぀の点$\boldsymbol{p}$に察しお順にロヌカルな特城量$\boldsymbol{\psi}$を蚈算したしょう。 䞀気に蚈算するのず比べお蚈算効率は萜ちたすが、1぀の点に察する䞭間結果やロヌカルな特城量だけを保持すればよいので、オンチップバッファの消費を倧きく削枛できたす。

以前は (PyTorchなどのフレヌムワヌクを䜿う堎合は)、特城抜出は次のように行われおいたした。

  1. 党おの点$\mathcal{P}$に察しお、ロヌカルな特城量を$\boldsymbol{\Psi}$をたずめお蚈算する ($(N, 64)$や$(N, 1024)$のバッファが必芁)。
  2. Maxプヌリング局により、ロヌカルな特城量$\boldsymbol{\Psi}$を集玄しお、グロヌバルな特城量$\boldsymbol{\phi}$を埗る ($\boldsymbol{\phi} \gets \max(\boldsymbol{\psi}_1, \ldots, \boldsymbol{\psi}_N)$)。
  3. グロヌバルな特城量$\boldsymbol{\phi}$をMLPに入力し、各クラスに察するロゞット($K$次元のベクトル)を埗る。

これを、次のように倉曎したす(最適化その2: 蚈算順序の倉曎)。

  1. グロヌバルな特城量$\boldsymbol{\phi}$を、$\boldsymbol{0}$で初期化する。
  2. 各点$\boldsymbol{p}_i \ (i = 1, \ldots, N)$に察しお、以䞋の凊理を行う。
    1. MLPの順䌝播により、ロヌカルな特城量$\boldsymbol{\psi}_i$を埗る ($(1, 64)$や$(1, 1024)$のバッファがあればよい)。
    2. $\boldsymbol{\phi}$ず$\boldsymbol{\psi}_i$ずの、芁玠ごずの$\max$をずるこずで、$\boldsymbol{\phi}$を曎新する ($\boldsymbol{\phi} \gets \max(\boldsymbol{\phi}, \boldsymbol{\psi}_i)$)。
  3. グロヌバルな特城量$\boldsymbol{\phi}$をMLPに入力し、各クラスに察するロゞット($K$次元のベクトル)を埗る。

党おの点に察するロヌカルな特城量$\boldsymbol{\Psi}$を集玄するのではなく、各点$\boldsymbol{p}_i$に察するロヌカルな特城量$\boldsymbol{\psi}_i$を䜿っお、グロヌバルな特城量$\boldsymbol{\phi}$を逐次的に曎新しおいきたす。 これは近䌌ではないので、党く同じ結果ずなりたす。

最終的に、今回FPGA䞊に実装するPointNetは、以䞋のようになりたす。

高䜍合成による実装

今回は、高䜍合成 (HLS: High-Level Synthesis)を甚いお、䞊蚘に瀺すPointNetの専甚回路 (IPコア) を蚘述したす。 ニュヌラルネットの掚論を実珟する別の手段ずしお、行列挔算や畳み蟌み挔算甚の、巚倧か぀汎甚的な挔算回路をFPGA䞊に実装し、それに繰り返しデヌタを䞎えるこずも考えられたす。

高䜍合成は、C/C++による動䜜レベル (Behavior Level) の回路蚘述を、Verilog HDLやSystemVerilogによるレゞスタ転送レベル (RTL: Register Transfer Level) の回路蚘述に倉換するための技術です。 Verilog HDLを盎接蚘述するのに比べお、遥かに楜で、ストレスが少なく、生産性が向䞊したす。 䜆し、C/C++で蚘述するずはいっおも、通垞の゜フトりェア開発ずは党く様盞が異なりたす。 malloc()やnewはもちろんのこず、これらに䟝存するstd::vectorなどの䟿利なデヌタ型も䜿えないので、固定長の配列に眮き換えおどうにかしたす。 ニュヌラルネットはサむズが固定で、䞀般には決たった動䜜をするので、FPGA䞊に実装しやすいです。

高䜍合成甚のツヌルずしお、Xilinx瀟のVitis HLS 2022.1を利甚したす。 たた実装察象のFPGAずしお、Xilinx ZCU104 Evaluation Board (XCZU7EV-2FFVC1156)を䜿いたす。 Xilinx ZCU104には、FPGAのほかに、クアッドコア ARM Cortex-A53 CPU (1.2GHz)ず2GBのDRAMも搭茉されおおり、Linuxが動䜜したす。

早速、PointNetのIPコアを瀺したす (適宜GitHubのリポゞトリをご芧ください)。 高䜍合成ツヌルのバック゚ンドがGCC 6.2ですので、C++14やC++17の䞀郚機胜が利甚できたす。 䜆し、ツヌルのバグを螏むかもしれないので、あたり凝った機胜は䜿わないようにしおいたす。

// Size of the PointNet classification network
// Refer to net/model.py for details

// Size of the feature extraction network
constexpr const int kFeatDims0 = 3;
constexpr const int kFeatDims1 = 64;
constexpr const int kFeatDims2 = 64;
constexpr const int kFeatDims3 = 64;
constexpr const int kFeatDims4 = 128;
constexpr const int kFeatDims5 = 1024;

// Size of the classification network
// ModelNet40 has 40 object classes
constexpr const int kClsDims0 = kFeatDims5;
constexpr const int kClsDims1 = 512;
constexpr const int kClsDims2 = 256;
constexpr const int kClsDims3 = 40;

// Top function
void PointNetClsTop(const int op_mode,
                    const float* point_cloud,
                    const int num_points,
                    float* out_logits,
                    const float* feat_params1,
                    const float* feat_params2,
                    const float* feat_params3,
                    const float* feat_params4,
                    const float* feat_params5,
                    const float* cls_params1,
                    const float* cls_params2,
                    const float* cls_params3)
{
#pragma HLS INTERFACE m_axi port=point_cloud offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=out_logits offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=feat_params1 offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=feat_params2 offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=feat_params3 offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=feat_params4 offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=feat_params5 offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=cls_params1 offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=cls_params2 offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=cls_params3 offset=slave bundle=gmem0

#pragma HLS INTERFACE s_axilite port=op_mode bundle=control
#pragma HLS INTERFACE s_axilite port=point_cloud bundle=control
#pragma HLS INTERFACE s_axilite port=num_points bundle=control
#pragma HLS INTERFACE s_axilite port=out_logits bundle=control
#pragma HLS INTERFACE s_axilite port=feat_params1 bundle=control
#pragma HLS INTERFACE s_axilite port=feat_params2 bundle=control
#pragma HLS INTERFACE s_axilite port=feat_params3 bundle=control
#pragma HLS INTERFACE s_axilite port=feat_params4 bundle=control
#pragma HLS INTERFACE s_axilite port=feat_params5 bundle=control
#pragma HLS INTERFACE s_axilite port=cls_params1 bundle=control
#pragma HLS INTERFACE s_axilite port=cls_params2 bundle=control
#pragma HLS INTERFACE s_axilite port=cls_params3 bundle=control
#pragma HLS INTERFACE s_axilite port=return bundle=control

  // Parameters for feature extraction
  LinearParams<param_t, kFeatDims0, kFeatDims1> feat_conv1;
  LinearParams<param_t, kFeatDims1, kFeatDims2> feat_conv2;
  LinearParams<param_t, kFeatDims2, kFeatDims3> feat_conv3;
  LinearParams<param_t, kFeatDims3, kFeatDims4> feat_conv4;
  LinearParams<param_t, kFeatDims4, kFeatDims5> feat_conv5;
  BatchNorm1dParams<param_t, kFeatDims1> feat_bn1;
  BatchNorm1dParams<param_t, kFeatDims2> feat_bn2;
  BatchNorm1dParams<param_t, kFeatDims3> feat_bn3;
  BatchNorm1dParams<param_t, kFeatDims4> feat_bn4;
  BatchNorm1dParams<param_t, kFeatDims5> feat_bn5;

  // Parameters for classification network
  // LinearParams<param_t, kClsDims0, kClsDims1> cls_fc1;
  // LinearParams<param_t, kClsDims1, kClsDims2> cls_fc2;
  LinearParams<param_t, kClsDims2, kClsDims3> cls_fc3;
  BatchNorm1dParams<param_t, kClsDims1> cls_bn1;
  BatchNorm1dParams<param_t, kClsDims2> cls_bn2;

  // Extracted feature
  value_t feature[kFeatDims5];

  if (op_mode == kModeInitWeights) {
    // Initialize the PointNet feature extraction network
    InitializeFeatNaive<param_t>(
      &feat_conv1, &feat_conv2, &feat_conv3, &feat_conv4, &feat_conv5,
      &feat_bn1, &feat_bn2, &feat_bn3, &feat_bn4, &feat_bn5,
      feat_params1, feat_params2, feat_params3, feat_params4, feat_params5);
    // Initialize the classification network
    InitializeClsNaive<param_t>(
      &cls_fc3, &cls_bn1, &cls_bn2,
      cls_params1, cls_params2, cls_params3);
  } else if (op_mode == kModeInference) {
    // Run the PointNet feature extraction
    InferenceFeatNaive<value_t, param_t, 1024>(
      point_cloud, num_points, feature,
      &feat_conv1, &feat_conv2, &feat_conv3, &feat_conv4, &feat_conv5,
      &feat_bn1, &feat_bn2, &feat_bn3, &feat_bn4, &feat_bn5);

    // Run the classification
    InferenceClsNaive<value_t, param_t>(
      feature, out_logits,
      &cls_fc3, &cls_bn1, &cls_bn2,
      cls_params1, cls_params2, cls_params3);
  }
}

䞊蚘を高䜍合成するず、次のようなIPコアが䜜られたす。

このIPコアを別のIPコアず組み合わせるこずで (埌述)、次のようなブロックデザむンができたす。

このブロックデザむンに察しお、論理合成および配眮配線するこずで、回路情報を衚すビットストリヌム (Bitstream) を生成したす。 ビットストリヌムをFPGAにロヌドするこずで、PointNetの専甚回路が䜿えるようになりたす。

入出力ポヌト

PointNetClsTopが、IPコアを衚す最䞊䜍の関数です。 トップ関数 (Top function) ずよばれたす。 関数の匕数は、IPコアの入出力ポヌトずなり、別のIPコアに接続されたす (䞊のブロックデザむンをご芧ください)。 HLSでは、関数そのものが回路 (Verilog HDLにおけるモゞュヌル) になりたす。 関数の再垰呌び出しはできたせん。

特城抜出甚のネットワヌクには5぀のMLP、たたクラス分類甚のネットワヌクには3぀のMLPが含たれたす。 これらのパラメヌタは、゜フトりェア偎から操䜜できるように、DRAM䞊のバッファに眮かれたす。 たた、点矀$\mathcal{P}$や、モデルの出力(ロゞット)も同様に、DRAMバッファに眮かれたす。

feat_params1からfeat_params5たでず、cls_params1からcls_params3たでの8぀のポヌトは、DRAMバッファ䞊のパラメヌタを、IPコア偎から読み取るために䜿いたす。 point_cloudは点矀の読み出し、out_logitsはロゞットの曞き蟌みのために䜿いたす。 op_modeは回路の動䜜モヌド、num_pointsは点の個数$N$を蚭定するための制埡レゞスタです。

#pragma HLSから始たる行は、高䜍合成ツヌルに察しお、C/C++からRTLに倉換する際のヒントを䞎えたす (必ずしも守っおくれるずは限りたせん)。 パむプラむン化、デヌタフロヌ最適化などはC/C++では蚘述できたせんが、このようなHLSプラグマを適切な堎所に眮くこずで、高䜍合成ツヌルが自動的にこれらの最適化を斜しおくれたす。

#pragma HLS INLINE offずするず、その関数がむンラむン展開されなくなりたす (必ず、1぀のモゞュヌルずしお䜜られる)。 倧きな関数であれば、自動的にむンラむン展開されるこずはありたせんが、念のため付䞎しおいたす。 以䞋のような状況では、関数Bをむンラむン展開しない方がいいず思いたす。 同時に䜿われないのにも関わらず、関数Aの内郚にBのコピヌが3぀䜜られお、リ゜ヌスの無駄遣いずなりたす。 関数Bのむンラむン化を抑制しお、Bを1぀だけ䜜り、それを䜿い回した方がいいでしょう。

void B(const float x_in[10], float y_out[10])
{
#pragma HLS INLINE

  // 䜕らかの凊理
}

void A(const float x_in[10], float y_out[10])
{
  float x0[10];
  float x1[10];
  B(x_in, x0);
  B(x0, x1);
  B(x1, y_out);
}

#pragma HLS INTERFACE m_axiず、#pragma HLS INTERFACE s_axiliteの蚘述が目立ちたすが、入出力ポヌト (䟋えばfeat_params1) に察しおこの2぀のHLSプラグマを蚘述するず、IPコア偎からDRAMバッファを読み曞きできるようになりたす。 読み曞きの際には、AXIずよばれるプロトコルを䜿甚したすが、#pragma HLS INTERFACE m_axiによっおそれを指定できたす (IPコア偎がマスタヌになりたす)。

゜フトりェア偎からは、各ポヌトに察しお、バッファの物理アドレスを割り圓おお、ポヌトずバッファを玐づけたす。 各ポヌトには、物理アドレスを蚭定するための制埡レゞスタを䜜成する必芁があり、#pragma HLS INTERFACE s_axiliteによっおそれを実珟できたす (IPコア偎からみるずスレヌブです)。 op_mode、num_pointsに察しおもレゞスタを䜜成したす。 port=returnずしおいる行は、IPコア甚の制埡レゞスタを䜜成し、CPU偎からIPコアの動䜜を開始したり、状態 (アむドル状態なのか動䜜䞭か) を読み取ったりするために必芁です。 これらのレゞスタは、゜フトりェア偎から、メモリマップトI/OおよびAXI-Liteプロトコルによっお読み曞きされたす。

各入出力ポヌトからは、PyTorchのモデルで定矩した、各局のパラメヌタが読み出されたす (䞀次元の配列ずしお、党おのパラメヌタが連結されたす)。

  • feat_params1: PointNetFeat::conv1 + PointNetFeat::bn1のパラメヌタ
  • feat_params2: PointNetFeat::conv2 + PointNetFeat::bn2のパラメヌタ
  • feat_params3: PointNetFeat::conv3 + PointNetFeat::bn3のパラメヌタ
  • feat_params4: PointNetFeat::conv4 + PointNetFeat::bn4のパラメヌタ
  • feat_params5: PointNetFeat::conv5 + PointNetFeat::bn5のパラメヌタ
  • cls_params1: PointNetCls::fc1 + PointNetCls::bn1のパラメヌタ
  • cls_params2: PointNetCls::fc2 + PointNetCls::bn2のパラメヌタ
  • cls_params3: PointNetCls::fc3のパラメヌタ
void PointNetClsTop(const int op_mode,
                    const float* point_cloud,
                    const int num_points,
                    float* out_logits,
                    const float* feat_params1,
                    const float* feat_params2,
                    const float* feat_params3,
                    const float* feat_params4,
                    const float* feat_params5,
                    const float* cls_params1,
                    const float* cls_params2,
                    const float* cls_params3)
{
  // ...
}

各局のパラメヌタず凊理

torch.nn.Conv1dおよびtorch.nn.Linearのパラメヌタずしおは、重みずバむアスが挙げられたす。 Conv1dずありたすが、カヌネルサむズは1なので、Linearず動䜜が同じになりたす。 以埌、Conv1dずLinearを同䞀芖したす。 入力ず出力の次元数を$\mathrm{InDims}$、$\mathrm{OutDims}$ずするず、重みずバむアスのサむズは$(\mathrm{OutDims}, \mathrm{InDims})$、$(\mathrm{OutDims})$ずなりたす。 入力$\boldsymbol{x} \in \mathbb{R}^{\mathrm{InDims}}$、重み$\boldsymbol{W} \in \mathbb{R}^{\mathrm{OutDims} \times \mathrm{InDims}}$、バむアス$\boldsymbol{b} \in \mathbb{R}^{\mathrm{OutDims}}$があるずき、出力$\boldsymbol{y} \in \mathbb{R}^{\mathrm{OutDims}}$は次のように蚈算されたす。 $$ \boldsymbol{y} = \boldsymbol{W} \boldsymbol{x} + \boldsymbol{b} $$

torch.nn.BatchNorm1dのパラメヌタずしおは、平均、暙準偏差、重み、バむアスの4぀が挙げられたす。 入出力の次元を$\mathrm{Dims}$ずするず、これら4぀のパラメヌタのサむズは$(\mathrm{Dims})$です。 平均、暙準偏差、重み、バむアス$\boldsymbol{\mu}, \boldsymbol{\sigma}, \boldsymbol{w}, \boldsymbol{b} \in \mathbb{R}^{\mathrm{Dims}}$があるずき、入力$\boldsymbol{x} \in \mathbb{R}^{\mathrm{Dims}}$に察しお出力$\boldsymbol{y} \in \mathbb{R}^{\mathrm{Dims}}$は次のように蚈算されたす。 $$ y_i = \frac{x_i - \mu_i}{\sqrt{\sigma_i^2 + \varepsilon}} \cdot w_i + b_i \quad (i = 1, \ldots, \mathrm{Dims}) $$ $\varepsilon$は、れロ陀算を防ぐための小さな正の倀です。 $x_i$は、$\boldsymbol{x}$の第$i$芁玠です (他も同様)。 䞊蚘をみるず、$w_i / \sqrt{\sigma_i^2 + \varepsilon}$の郚分を先に蚈算できるこずが分かりたす。 $\boldsymbol{w}$ず$\boldsymbol{\sigma}$の䞡方を䜿う堎合ず比べお、陀算および平方根の蚈算を省略できたす。 たた、オンチップバッファの䜿甚量を削枛できたす。 现かい話にみえたすが、リ゜ヌス制玄の倧きなFPGA䞊に実装する堎合は重芁です。 バッチ正芏化の蚈算は以䞋のようにしたす。 $$ y_i = \left( x_i - \mu_i \right) \cdot s_i + b_i \quad (i = 1, \ldots, \mathrm{Dims}) $$ 䞊蚘の$s_i$を、ここではスケヌルず呌ぶこずにしたす。 パラメヌタは、平均$\boldsymbol{\mu}$、バむアス$\boldsymbol{b}$、スケヌル$\boldsymbol{s} \in \mathbb{R}^{\mathrm{Dims}}$の3぀になりたす。 $\boldsymbol{s}$の蚈算は、モデルの初期化時に゜フトりェア䞊で行うこずにしたす。

バッチ正芏化の埌にReLU掻性化が蚈算されたす。 各局を別々に実装するよりも、たずめおしたった方が効率がよいので、バッチ正芏化ずReLU掻性化を次のようにたずめたす (最適化その3: 蚈算の簡略化)。 $$ y_i = \max \left( 0, \left( x_i - \mu_i \right) \cdot s_i + b_i \right) \quad (i = 1, \ldots, \mathrm{Dims}) $$

最埌にMaxプヌリング局ですが、先述の通り、各点に察するロヌカル特城量$\boldsymbol{\psi}_i \in \mathbb{R}^{1024}$ず、珟圚のグロヌバル特城量$\boldsymbol{\phi} \in \mathbb{R}^{1024}$ずの、芁玠ごずの$\max$に眮き換えたした。 Maxプヌリング局の蚈算は次のようになりたす。 $$ \phi_i = \max \left( \phi_i, \psi_i \right) \quad (i = 1, \ldots, 1024) $$

さお、゜ヌスコヌドのLinearParams<T, InDims_, OutDims_>構造䜓ず、BatchNorm1dParams<T, Dims_>構造䜓は、党結合局 (Conv1dおよびLinear) ず、バッチ正芏化局 (BatchNorm1d) のパラメヌタをそれぞれたずめたものです。

// Parameters for fully-connected layers
template <typename T, int InDims_, int OutDims_>
struct LinearParams
{
  enum
  {
    InDims = InDims_,
    OutDims = OutDims_,
  };

  T weight[OutDims][InDims];
  T bias[OutDims];
};

// Parameters for 1D batch normalization layers
template <typename T, int Dims_>
struct BatchNorm1dParams
{
  enum
  {
    Dims = Dims_,
  };

  // `scale` is obtained by multiplying weights and reciprocal of the
  // square root of the standard deviation (to reduce the computational cost)
  T scale[Dims];
  T bias[Dims];
  T mean[Dims];
};

PointNetClsTop内では、PyTorchで定矩されたモデルの各局に察応しお、以䞋のようなパラメヌタが宣蚀されたす。

  • feat_conv1: PointNetFeat::conv1の重み、バむアス
  • feat_conv2: PointNetFeat::conv2の重み、バむアス
  • feat_conv3: PointNetFeat::conv3の重み、バむアス
  • feat_conv4: PointNetFeat::conv4の重み、バむアス
  • feat_conv5: PointNetFeat::conv5の重み、バむアス
  • feat_bn1: PointNetFeat::bn1の平均、バむアス、スケヌル
  • feat_bn2: PointNetFeat::bn2の平均、バむアス、スケヌル
  • feat_bn3: PointNetFeat::bn3の平均、バむアス、スケヌル
  • feat_bn4: PointNetFeat::bn4の平均、バむアス、スケヌル
  • feat_bn5: PointNetFeat::bn5の平均、バむアス、スケヌル
  • cls_fc3: PointNetCls::fc3の重み、バむアス
  • cls_bn1: PointNetCls::bn1の平均、バむアス、スケヌル
  • cls_bn2: PointNetCls::bn2の平均、バむアス、スケヌル

特城抜出ネットワヌクの党おの局のパラメヌタは、掚論を開始する前に予め、オンチップメモリ䞊に眮いおおきたす。 䞀方、分類ネットワヌクの党結合局2぀ (PointNetCls::fc1、PointNetCls::fc2) のパラメヌタは、オンチップメモリ䞊には眮かないようにしたす。 パラメヌタサむズが倧きく、オンチップメモリが䞍足するためです。 これらの局に぀いおは、掚論時にDRAMバッファから読み出したす。 蚀い換えるず、パラメヌタの䞀郚をDRAMバッファから取り出しお、出力の䞀郚を蚈算するこずを繰り返したす。 䞀郚のパラメヌタを保持するために、小さなオンチップバッファを甚意すればよくなりたす。

特城抜出ネットワヌクに぀いおは、$N$個党おの点に察しお特城抜出を行うために、$N$回の順䌝播が起こりたす。 掚論時間のなかで占める割合が倧きいので、1回の順䌝播に芁する蚈算時間をうたく短瞮できれば、党䜓の掚論時間の倧幅な短瞮に぀ながりたす (アムダヌルの法則)。 䞀方、分類ネットワヌクの順䌝播は1床だけで、掚論時間のなかではそれほど重芁ではありたせん。 パラメヌタをオンチップメモリに事前に栌玍するのず比べお、掚論時にDRAMバッファから読み出すず、局の蚈算時間は䌞びおしたいたすが、掚論時間に䞎える圱響はそれほど倧きくありたせん。

デヌタ型

Vitis HLSでは、任意粟床の固定小数点数型ap_fixedが甚意されおいたす。 単粟床浮動小数点数floatや、半粟床浮動小数点数halfも利甚できたす。 ここではリ゜ヌス消費を抑えるために、固定小数点数を䜿いたす。

デフォルトのオヌバヌフロヌモヌド (ap_o_mode::AP_WRAP) では、倀がオヌバヌフロヌしたずきに折り返したす。 これだず、最倧倀から急に最小倀になったりしお危なっかしいので、最倧倀あるいは最小倀に留たり続けるように、飜和モヌド (ap_o_mode::AP_SAT) に倉曎しおいたす。 飜和モヌドを䜿う固定小数点数型を、ap_fixed_satずしお定矩したした。

ニュヌラルネットの入出力ずパラメヌタずでビット幅を倉えるために、入出力甚ずパラメヌタ甚に別々の型を甚意したした (param_tおよびvalue_t)。 パラメヌタの倀域に合わせお、ビット幅を削枛できるかもしれたせん。 ビット幅の削枛や量子化、小数点型のフォヌマットなどは、それ自䜓が立掟な研究分野ずなっおいたす。

// Value types
template <int _AP_W, int _AP_I>
using ap_fixed_sat = ap_fixed<
  _AP_W, _AP_I, ap_q_mode::AP_TRN, ap_o_mode::AP_SAT, 0>;

// Data type for values (layer inputs, outputs, and intermediate results)
using value_t = ap_fixed_sat<kValueBitWidth, kValueIntWidth>;
// Data type for network parameters
using param_t = ap_fixed_sat<kParamBitWidth, kParamIntWidth>;

動䜜モヌド

さお、ここで瀺すIPコアには、2぀の動䜜モヌド (Operation mode) が甚意されおいたす。

  • 重み初期化モヌド (kModeInitWeights): 重みをDRAMバッファから読み取っお、オンチップバッファに栌玍する。
  • 掚論モヌド (kModeInference): 入力点矀から、各クラスのロゞットを蚈算する。

これらを順に説明したす。

重み初期化モヌド

特城抜出ネットワヌクの党パラメヌタず、分類ネットワヌクのパラメヌタの䞀郚を、DRAMバッファから読み取っお、オンチップバッファに栌玍したす。 以䞋に瀺す、InitializeFeatNaiveおよびInitializeClsNaiveを利甚したす。 それぞれ、特城抜出ネットワヌクず、分類ネットワヌクのための関数です。

// Naive implementation of the parameter initialization
// `T` is the type for parameters
template <typename T>
void InitializeFeatNaive(LinearParams<T, kFeatDims0, kFeatDims1>* conv1,
                         LinearParams<T, kFeatDims1, kFeatDims2>* conv2,
                         LinearParams<T, kFeatDims2, kFeatDims3>* conv3,
                         LinearParams<T, kFeatDims3, kFeatDims4>* conv4,
                         LinearParams<T, kFeatDims4, kFeatDims5>* conv5,
                         BatchNorm1dParams<T, kFeatDims1>* bn1,
                         BatchNorm1dParams<T, kFeatDims2>* bn2,
                         BatchNorm1dParams<T, kFeatDims3>* bn3,
                         BatchNorm1dParams<T, kFeatDims4>* bn4,
                         BatchNorm1dParams<T, kFeatDims5>* bn5,
                         const float* params1,
                         const float* params2,
                         const float* params3,
                         const float* params4,
                         const float* params5)
{
#pragma HLS INLINE off

  ReadBlockParamsNaive<T, kFeatDims0, kFeatDims1>(conv1, bn1, params1);
  ReadBlockParamsNaive<T, kFeatDims1, kFeatDims2>(conv2, bn2, params2);
  ReadBlockParamsNaive<T, kFeatDims2, kFeatDims3>(conv3, bn3, params3);
  ReadBlockParamsNaive<T, kFeatDims3, kFeatDims4>(conv4, bn4, params4);
  ReadBlockParamsNaive<T, kFeatDims4, kFeatDims5>(conv5, bn5, params5);
}

// Naive implementation of the parameter initialization
// `T` is the type for parameters
template <typename T>
void InitializeClsNaive(LinearParams<T, kClsDims2, kClsDims3>* fc3,
                        BatchNorm1dParams<T, kClsDims1>* bn1,
                        BatchNorm1dParams<T, kClsDims2>* bn2,
                        const float* params1,
                        const float* params2,
                        const float* params3)
{
#pragma HLS INLINE off

  ReadBatchNorm1dParamsNaive<T, kClsDims1>(
    bn1, params1, kClsDims0 * kClsDims1 + kClsDims1);
  ReadBatchNorm1dParamsNaive<T, kClsDims2>(
    bn2, params2, kClsDims1 * kClsDims2 + kClsDims2);
  ReadLinearParamsNaive<T, kClsDims2, kClsDims3>(
    fc3, params3, 0);
}

これらの関数のなかでは、ReadBlockParamsNaive、ReadLinearParamsNaive、そしおReadBatchNorm1dParamsNaiveの3぀の関数を呌び出しおいたす。 各関数は次のような動䜜です (詳现は゜ヌスコヌドをご参照ください)。 DRAMバッファ䞊にはfloat型で眮かれおいたすが、これを固定小数点数型に盎す凊理も含たれたす。

  • ReadLinearParamsNaive<T, InDims, OutDims>: DRAMバッファから、党結合局 (Conv1dおよびLinear) の重みずバむアスを読み取る。 重みのサむズは(OutDims, InDims)、バむアスのサむズは(OutDims)である。 2぀のパラメヌタは、1次元の配列ずしお連結されおいるずする (配列のサむズはOutDims * InDims + OutDims)。
  • ReadBatchNorm1dParamsNaive<T, Dims>: DRAMバッファから、バッチ正芏化局 (BatchNorm1d) のスケヌル、バむアス、平均を読み取る。 パラメヌタのサむズは(Dims)である。 3぀のパラメヌタは、1次元の配列ずしお連結されおいるずする (配列のサむズは3 * Dims)。
  • ReadBlockParamsNaive<T, InDims, OutDims: DRAMバッファから、党結合局およびバッチ正芏化局のパラメヌタ5぀を読み取る。 5぀のパラメヌタは、1次元の配列ずしお連結されおいるずする (配列のサむズはOutDims * InDims + 4 * OutDims)。

掚論モヌド

入力点矀から、各クラスのロゞットを蚈算したす。 以䞋に瀺す、InferenceFeatNaiveおよびInferenceClsNaiveを利甚したす。 それぞれ、特城抜出ネットワヌクず、分類ネットワヌクの凊理です。

// Naive implementation of the PointNet feature extraction
// `T` is the type for layer input, output, and intermediate results
// `U` is the type for parameters
// `N` is the expected number of input points (e.g., 1024)
template <typename T, typename U, int N>
void InferenceFeatNaive(const float* point_cloud,
                        const int num_points,
                        T feature[kFeatDims5],
                        const LinearParams<U, kFeatDims0, kFeatDims1>* conv1,
                        const LinearParams<U, kFeatDims1, kFeatDims2>* conv2,
                        const LinearParams<U, kFeatDims2, kFeatDims3>* conv3,
                        const LinearParams<U, kFeatDims3, kFeatDims4>* conv4,
                        const LinearParams<U, kFeatDims4, kFeatDims5>* conv5,
                        const BatchNorm1dParams<U, kFeatDims1>* bn1,
                        const BatchNorm1dParams<U, kFeatDims2>* bn2,
                        const BatchNorm1dParams<U, kFeatDims3>* bn3,
                        const BatchNorm1dParams<U, kFeatDims4>* bn4,
                        const BatchNorm1dParams<U, kFeatDims5>* bn5)
{
#pragma HLS INLINE off

  // Zero-initialize the output feature
  VectorNdSetZero<T, kFeatDims5>(feature);

  // Compute the feature
  for (int i = 0; i < num_points; ++i) {
#pragma HLS LOOP_TRIPCOUNT min=N max=N avg=N
#pragma HLS LOOP_FLATTEN off

    // Input, output, and intermediate results
    T x0[kFeatDims0];
    T x1[kFeatDims1];
    T x2[kFeatDims1];
    T x3[kFeatDims2];
    T x4[kFeatDims2];
    T x5[kFeatDims3];
    T x6[kFeatDims3];
    T x7[kFeatDims4];
    T x8[kFeatDims4];
    T x9[kFeatDims5];
    T x10[kFeatDims5];

    // Read a point from a DDR memory
    ReadPointNaive<T>(point_cloud, i, x0);

    // Compute a point feature
    LinearNaive<T, U, kFeatDims0, kFeatDims1, false>(
      x0, x1, conv1->weight, conv1->bias);
    BatchNorm1dReLUNaive<T, U, kFeatDims1>(
      x1, x2, bn1->scale, bn1->bias, bn1->mean);
    LinearNaive<T, U, kFeatDims1, kFeatDims2, false>(
      x2, x3, conv2->weight, conv2->bias);
    BatchNorm1dReLUNaive<T, U, kFeatDims2>(
      x3, x4, bn2->scale, bn2->bias, bn2->mean);
    LinearNaive<T, U, kFeatDims2, kFeatDims3, false>(
      x4, x5, conv3->weight, conv3->bias);
    BatchNorm1dReLUNaive<T, U, kFeatDims3>(
      x5, x6, bn3->scale, bn3->bias, bn3->mean);
    LinearNaive<T, U, kFeatDims3, kFeatDims4, false>(
      x6, x7, conv4->weight, conv4->bias);
    BatchNorm1dReLUNaive<T, U, kFeatDims4>(
      x7, x8, bn4->scale, bn4->bias, bn4->mean);
    LinearNaive<T, U, kFeatDims4, kFeatDims5, false>(
      x8, x9, conv5->weight, conv5->bias);
    BatchNorm1dReLUNaive<T, U, kFeatDims5>(
      x9, x10, bn5->scale, bn5->bias, bn5->mean);

    // Update the output feature
    MaxPool1dNaive<T, kFeatDims5>(x10, feature);
  }
}

// Naive implementation of the classification network
// `T` is the type for layer input, output, and intermediate results
// `U` is the type for parameters
template <typename T, typename U>
void InferenceClsNaive(const T feature[kFeatDims5],
                       float* out_logits,
                       const LinearParams<U, kClsDims2, kClsDims3>* fc3,
                       const BatchNorm1dParams<U, kClsDims1>* bn1,
                       const BatchNorm1dParams<U, kClsDims2>* bn2,
                       const float* params1,
                       const float* params2,
                       const float* params3)
{
#pragma HLS INLINE off

  static_assert(kFeatDims5 == kClsDims0,
                "Feature dimension should be equal to the input dimension");

  // Input, output, and intermediate results
  T x0[kClsDims1];
  T x1[kClsDims1];
  T x2[kClsDims2];
  T x3[kClsDims2];
  T x4[kClsDims3];

  // Compute logits
  LinearNaiveDDR<T, U, kClsDims0, kClsDims1, false>(
    feature, x0, params1, 0);
  BatchNorm1dReLUNaive<T, U, kClsDims1>(
    x0, x1, bn1->scale, bn1->bias, bn1->mean);
  LinearNaiveDDR<T, U, kClsDims1, kClsDims2, false>(
    x1, x2, params2, 0);
  BatchNorm1dReLUNaive<T, U, kClsDims2>(
    x2, x3, bn2->scale, bn2->bias, bn2->mean);
  LinearNaive<T, U, kClsDims2, kClsDims3, false>(
    x3, x4, fc3->weight, fc3->bias);

  // Write the result
  WriteTensor1dNaive<T, kClsDims3>(out_logits, x4, 0);
}

InferenceFeatNaiveでは、DRAMに眮かれた点矀デヌタ (point_cloud) から、1぀ず぀点を読み取りたす。 各点 (x0) に察しおロヌカルな特城量 (x10) を蚈算し、珟圚のグロヌバル特城量 (feature) を曎新する凊理を、点の個数 (num_points) だけ繰り返したす。 InferenceClsNaiveは、点矀党䜓を衚すグロヌバル特城量 (feature) を受け取っお、各クラスに察するロゞット (x4) を蚈算し、それをDRAMバッファ (out_logits) に曞き戻したす。

ReadPointNaiveは、$i$番目の点$\boldsymbol{p}_i$を、DRAMバッファから読み取るものです。 LinearNaive、BatchNorm1dReLUNaive、MaxPool1dNaiveは、名前の通り、党結合局 (Conv1d)、バッチ正芏化局ずReLU掻性化、Maxプヌリング局に察応したす (先皋の蚈算匏を参照)。 オンチップバッファからパラメヌタを読み出しお、局の出力を蚈算したす。 LinearNaiveDDRも党結合局の関数ですが、DRAMバッファからパラメヌタを少しず぀取り出し぀぀、出力を蚈算したす。 これらの関数を以䞋に瀺したす。 HLSプラグマを陀けば、゜フトりェア実装ず倧䜓同じであるこずが分かりたす。 行数は倚いですが、凊理内容は単玔です。

// Naive implementation of the fully-connected layer
// `T` is the type for values
// `TParam` is the type for weight and bias
// `InDims` is the number of input dimensions
// `OutDims` is the number of output dimensions
// `ApplyReLU` is the flag to apply ReLU activation
template <typename T, typename TParam,
          int InDims, int OutDims, bool ApplyReLU>
void LinearNaive(const T x[InDims],
                 T y[OutDims],
                 const TParam weight[OutDims][InDims],
                 const TParam bias[OutDims])
{
#pragma HLS INLINE off

  for (int i = 0; i < OutDims; ++i) {
#pragma HLS PIPELINE off
    T val = bias[i];

    for (int j = 0; j < InDims; ++j) {
#pragma HLS PIPELINE
      val += x[j] * weight[i][j];
    }

    if (ApplyReLU)
      y[i] = val > T(0) ? val : T(0);
    else
      y[i] = val;
  }
}

// Naive implementation of the fully-connected layer
// Weight and bias parameters are stored on the DDR memory
template <typename T, typename TParam,
          int InDims, int OutDims, bool ApplyReLU>
void LinearNaiveDDR(const T x[InDims],
                    T y[OutDims],
                    const float* params,
                    const int offset)
{
  // `params` contains weight parameters of size (`OutDims`, `InDims`) and
  // bias parameters of size (`OutDims`) in a contiguous buffer

#pragma HLS INLINE off

  constexpr const int OffsetToBias = OutDims * InDims;

  TParam bias[OutDims];

  // Copy the bias parameters in advance
  for (int i = 0; i < OutDims; ++i) {
#pragma HLS PIPELINE II=1
    bias[i] = TParam(params[offset + OffsetToBias + i]);
  }

  for (int i = 0; i < OutDims; ++i) {
#pragma HLS PIPELINE off
    T val = bias[i];

    TParam weight[InDims];

    for (int j = 0; j < InDims; ++j) {
#pragma HLS PIPELINE II=1
      weight[j] = TParam(params[offset + i * InDims + j]);
    }

    for (int j = 0; j < InDims; ++j) {
#pragma HLS PIPELINE
      val += x[j] * weight[j];
    }

    if (ApplyReLU)
      y[i] = val > T(0) ? val : T(0);
    else
      y[i] = val;
  }
}

// Naive implementation of the 1D batch normalization and ReLU activation
// `T` is the type for values
// `TParam` is the type for parameters
// `Dims` is the number of input and output dimensions
template <typename T, typename TParam, int Dims>
void BatchNorm1dReLUNaive(const T x[Dims],
                          T y[Dims],
                          const TParam scale[Dims],
                          const TParam bias[Dims],
                          const TParam mean[Dims])
{
#pragma HLS INLINE off

  for (int i = 0; i < Dims; ++i) {
#pragma HLS PIPELINE
    // Batch normalization with the learned parameters
    T val = (x[i] - mean[i]) * scale[i] + bias[i];
    // ReLU activation
    y[i] = val > T(0) ? val : T(0);
  }
}

// Naive implementation of the 1D max-pooling layer
// `T` is the type for values
// `Dims` is the number of input and output dimensions
// `y` must be properly initialized
template <typename T, int Dims>
void MaxPool1dNaive(const T x[Dims], T y[Dims])
{
  // `x` is of size (1, `Dims`)
  // `y` is of size (1, `Dims`)

#pragma HLS INLINE off

  for (int i = 0; i < Dims; ++i) {
#pragma HLS PIPELINE
    y[i] = x[i] > y[i] ? x[i] : y[i];
  }
}

LinearNaiveDDRでは、党結合局のバむアス項 biasず、出力1芁玠分の蚈算に必芁な重み weightだけをオンチップメモリ䞊に保持したす。 入出力の次元を$\mathrm{InDims}, \mathrm{OutDims}$ずすれば、biasのサむズは$\mathrm{OutDims}$、weightのサむズは$\mathrm{InDims}$ずなりたす。

䞊蚘の関数のルヌプには#pragma HLS PIPELINEが付加されおおり、ルヌプ内郚の凊理が自動的にパむプラむン化されたす (最適化その4: ルヌプのパむプラむン化)。 #pragma HLS PIPELINE offずするず、このパむプラむン化が抑制されたす。 パむプラむン化による効果を、以䞋の図に瀺したす。

ルヌプをパむプラむン化しない堎合は、ルヌプの各むテレヌションを順に実行したす (図の䞊郚)。 䞀方、パむプラむン化では、ルヌプ内郚の凊理を分割 (図の堎合は4分割) し、それぞれの凊理を時間的にオヌバヌラップさせたす (図の䞋郚)。 耇数のむテレヌションを同時に実行するので、ルヌプの実行時間を短瞮できたす。 ルヌプの実行時間は、最も時間の掛かる凊理 (図の堎合は凊理3) によっお決たりたす。 むテレヌションの凊理を、なるべく均等に分割するこずで、パむプラむン化の効果が増したす。 䞊蚘の゜ヌスコヌドのように、最内ルヌプにパむプラむン化を適甚するず、凊理時間を倧きく削枛できたす。 2重ルヌプのうち倖偎のルヌプにパむプラむン化を適甚するず、内偎のルヌプは党お展開されお、1重ルヌプに盎されるので、リ゜ヌス消費が倧幅に増えおしたいたす。 倖偎のルヌプには、パむプラむン化を適甚しない方がいいず思いたす。

䞊蚘のIPコアは、hls/src/top_naive.cppにありたす。

䞊列化 (デヌタ䞊列性の掻甚)

このIPコアも正しく動䜜するのですが、明らかにナむヌブな (党く工倫しおいない玠朎な) 実装です。 デヌタ䞊列性 (Data parallelism) を掻かしお、各局の蚈算を䞊列化しおみたしょう (最適化その5: デヌタ䞊列性)。

党結合局の蚈算をもう䞀床みおみたす。 $$ \boldsymbol{y} = \boldsymbol{W} \boldsymbol{x} + \boldsymbol{b} $$ 出力$\boldsymbol{y}$の各芁玠$y_i$は次のように蚈算されたす。 $$ y_i = \sum_j W_{i, j} x_j + b_i $$ $B$個の出力芁玠$y_i, y_{i + 1}, \ldots, y_{i + B - 1}$の間には䟝存がないので (それぞれの芁玠は互いに䟝存せず独立に蚈算できるので)、䞊列に蚈算しおみたしょう。 $$ \begin{eqnarray} y_i &=& \sum_j W_{i, j} x_j + b_i \ y_{i + 1} &=& \sum_j W_{i + 1, j} x_j + b_{i + 1} \ &\vdots& \ y_{i + B - 1} &=& \sum_j W_{i + B - 1, j} x_j + b_{i + B - 1} \end{eqnarray} $$ $W_{i, j} x_j, W_{i + 1, j} x_j, \ldots, W_{i + B - 1, j} x_j$の$B$個の積を䞊列化するわけです。 蚀い換えるず、$j$ (入力次元) に関するルヌプはそのたたにしお、$i$ (出力次元) に関するルヌプを䞊列化するこずになりたす。 $B$個の出力を䞊列に蚈算するので、$B$倍の高速化が期埅できたす (リ゜ヌス消費も$B$倍になりたす)。

バッチ正芏化ずReLU掻性化に぀いおも同様に、耇数の出力芁玠$y_i, y_{i + 1}, \ldots, y_{i + B - 1}$を䞊列に蚈算したす。 $$ \begin{eqnarray} y_i &=& \max \left( 0, \left( x_i - \mu_i \right) \cdot s_i + b_i \right) \ y_{i + 1} &=& \max \left( 0, \left( x_{i + 1} - \mu_{i + 1} \right) \cdot s_{i + 1} + b_{i + 1} \right) \ &\vdots& \ y_{i + B - 1} &=& \max \left( 0, \left( x_{i + B - 1} - \mu_{i + B - 1} \right) \cdot s_{i + B - 1} + b_{i + B - 1} \right) \end{eqnarray} $$

Maxプヌリングに぀いおも党く同じで、耇数の出力芁玠$\phi_i, \phi_{i + 1}, \ldots, \phi_{i + B - 1}$を䞊列に蚈算したす。 $$ \begin{eqnarray} \phi_i &=& \max \left( \phi_i, \psi_i \right) \ \phi_{i + 1} &=& \max \left( \phi_{i + 1}, \psi_{i + 1} \right) \ &\vdots& \ \phi_{i + B - 1} &=& \max \left( \phi_{i + B - 1}, \psi_{i + B - 1} \right) \end{eqnarray} $$

LinearNaive、LinearNaiveDDR、BatchNorm1dReLUNaive、MaxPool1dNaiveが、各局のナむヌブな実装でした。 䞊列化したバヌゞョン LinearOpt1、LinearOpt1DDR、BatchNorm1dReLUOpt1、MaxPool1dOpt1に眮き換えたす (名前をNaiveからOpt1にしたす)。 テンプレヌト匕数ずしおBが远加されおいたす (B䞊列)。

// Parallel implementation of the fully-connected layer
// Matrix-vector multiplication is parallelized along the output dimension
// `T` is the type for values
// `TParam` is the type for weight and bias
// `InDims` is the number of input dimensions
// `OutDims` is the number of output dimensions
// `ApplyReLU` is the flag to apply ReLU activation
// `B` is the block size for the output dimension
template <typename T, typename TParam,
          int InDims, int OutDims, bool ApplyReLU, int B>
void LinearOpt1(const T x[InDims],
                T y[OutDims],
                const TParam weight[OutDims][InDims],
                const TParam bias[OutDims])
{
#pragma HLS INLINE off

  // `OutDims` must be a multiple of `B`
  static_assert(OutDims % B == 0, "`OutDims` must be a multiple of `B`");

  for (int i0 = 0; i0 < OutDims; i0 += B) {
#pragma HLS PIPELINE off
    T vals[B];
#pragma HLS ARRAY_PARTITION variable=vals type=complete dim=1

    for (int j = 0; j < InDims; ++j) {
#pragma HLS PIPELINE
      for (int i1 = 0; i1 < B; ++i1) {
#pragma HLS UNROLL
        int i = i0 + i1;
        T last = (j == 0) ? T(bias[i]) : vals[i1];
        vals[i1] = last + x[j] * weight[i][j];
      }
    }

    for (int i1 = 0; i1 < B; ++i1) {
#pragma HLS UNROLL
      int i = i0 + i1;
      if (ApplyReLU)
        y[i] = vals[i1] > T(0) ? vals[i1] : T(0);
      else
        y[i] = vals[i1];
    }
  }
}

// Parallel implementation of the fully-connected layer
// Weight and bias parameters are stored on the DDR memory
// Matrix-vector multiplication is parallelized along the output dimension
template <typename T, typename TParam,
          int InDims, int OutDims, bool ApplyReLU, int B>
void LinearOpt1DDR(const T x[InDims],
                   T y[OutDims],
                   const float* params,
                   const int offset)
{
  // `params` contains weight parameters of size (`OutDims`, `InDims`) and
  // bias parameters of size (`OutDims`) in a contiguous buffer

#pragma HLS INLINE off

  // `OutDims` must be a multiple of `B`
  static_assert(OutDims % B == 0, "`OutDims` must be a multiple of `B`");
  // `B` must be larger than 1
  static_assert(B > 1, "`B` must be larger than 1");

  constexpr const int BHalf = B / 2;
  constexpr const int OffsetToBias = OutDims * InDims;

  TParam bias[OutDims];
#pragma HLS ARRAY_PARTITION variable=bias type=cyclic factor=BHalf dim=1

  // Copy the bias parameters in advance
  for (int i = 0; i < OutDims; ++i) {
#pragma HLS PIPELINE II=1
    bias[i] = TParam(params[offset + OffsetToBias + i]);
  }

  for (int i0 = 0; i0 < OutDims; i0 += B) {
#pragma HLS PIPELINE off
    T vals[B];
#pragma HLS ARRAY_PARTITION variable=vals type=complete dim=1
    TParam weight[B][InDims];
#pragma HLS ARRAY_PARTITION variable=weight type=cyclic factor=BHalf dim=1

    // Copy the weight parameters for `B` outputs
    const int offset0 = offset + i0 * InDims;
    for (int i1 = 0; i1 < B; ++i1) {
      for (int j = 0; j < InDims; ++j) {
#pragma HLS PIPELINE II=1
        weight[i1][j] = TParam(params[offset0 + i1 * InDims + j]);
      }
    }

    for (int j = 0; j < InDims; ++j) {
#pragma HLS PIPELINE
      for (int i1 = 0; i1 < B; ++i1) {
#pragma HLS UNROLL
        int i = i0 + i1;
        if (i < OutDims) {
          T last = (j == 0) ? T(bias[i]) : vals[i1];
          vals[i1] = last + x[j] * weight[i1][j];
        }
      }
    }

    for (int i1 = 0; i1 < B; ++i1) {
#pragma HLS UNROLL
      int i = i0 + i1;
      if (i < OutDims) {
        if (ApplyReLU)
          y[i] = vals[i1] > T(0) ? vals[i1] : T(0);
        else
          y[i] = vals[i1];
      }
    }
  }
}

// Parallel implementation of the 1D batch normalization and ReLU activation
// `T` is the type for values
// `TParam` is the type for parameters
// `Dims` is the number of input and output dimensions
// `B` is the block size for the output dimension
template <typename T, typename TParam, int Dims, int B>
void BatchNorm1dReLUOpt1(const T x[Dims],
                         T y[Dims],
                         const TParam scale[Dims],
                         const TParam bias[Dims],
                         const TParam mean[Dims])
{
  // `scale` is the multiplication of the weight and reciprocal of the
  // standard deviation (to reduce the on-chip memory consumption)

#pragma HLS INLINE off

  static_assert(Dims % B == 0, "`Dims` must be a multiple of `B`");

  for (int i0 = 0; i0 < Dims; i0 += B) {
#pragma HLS PIPELINE
    for (int i1 = 0; i1 < B; ++i1) {
#pragma HLS UNROLL
      int i = i0 + i1;
      // Batch normalization with the learned parameters
      T val = (x[i] - mean[i]) * scale[i] + bias[i];
      // ReLU activation
      y[i] = val > T(0) ? val : T(0);
    }
  }
}

// Parallel implementation of the 1D max-pooling layer
// `T` is the type for values
// `Dims` is the number of input and output dimensions
// `B` is the block size for the output dimension
// `y` must be properly initialized
template <typename T, int Dims, int B>
void MaxPool1dOpt1(const T x[Dims], T y[Dims])
{
#pragma HLS INLINE off

  static_assert(Dims % B == 0, "`Dims` must be a multiple of `B`");

  for (int i0 = 0; i0 < Dims; i0 += B) {
#pragma HLS PIPELINE
    for (int i1 = 0; i1 < B; ++i1) {
#pragma HLS UNROLL
      int i = i0 + i1;
      y[i] = x[i] > y[i] ? x[i] : y[i];
    }
  }
}

LinearOpt1ずLinearNaiveを比べおみるず、j (入力次元) のルヌプはそのたたで、i (出力次元) に関するルヌプが、i0ずi1の2぀に分割されおいたす。 i0はB刻み、i1はi0からi0 + B - 1たで1぀ず぀増えおゆきたす。 i1に関するルヌプはアンロヌリング (#pragma HLS UNROLL) されおいるので、ルヌプの䞭身が完党に展開されたす。 i1のルヌプ自䜓は無くなっお、i0からi0 + B - 1たでの凊理が䞊列に実行されたす。 最初のルヌプに泚目しおみたしょう。

    for (int j = 0; j < InDims; ++j) {
#pragma HLS PIPELINE
      for (int i1 = 0; i1 < B; ++i1) {
#pragma HLS UNROLL
        int i = i0 + i1;
        T last = (j == 0) ? T(bias[i]) : vals[i1];
        vals[i1] = last + x[j] * weight[i][j];
      }
    }
    for (int j = 0; j < InDims; ++j) {
  #pragma HLS PIPELINE
      T last0 = (j == 0) ? T(bias[i0 + 0]) : vals[0];
      T last1 = (j == 0) ? T(bias[i0 + 1]) : vals[1];
      // ...
      T lastB1 = (j == 0) ? T(bias[i0 + B - 1]) : vals[B - 1];

      vals[0] = last0 + x[j] * weight[i0 + 0][j];
      vals[1] = last1 + x[j] * weight[i0 + 1][j];
      // ...
      vals[B - 1] = lastB1 + x[j] * weight[i0 + B - 1][j];
    }

䞊列凊理のために、valsずいう、サむズBの䞀時配列を新たに甚意しおいたす。 この配列には、出力y[i0]からy[i0 + B - 1]たでの蚈算結果を保持したす。 valsの各芁玠は、バむアス項bias[i0]からbias[i0 + B - 1]で初期化されたす。 その埌、jのルヌプによっお、x[j] * weight[i0][j]からx[j] * weight[i0 + B - 1][j]が、valsの各芁玠に順に加算されたす。 䞊蚘の蚈算匏ず察応しおいるこずが分かりたす。

ルヌプを展開するず、vals[0]からvals[B - 1]たでの党芁玠、それからbias[i0]からbias[i0 + B - 1]たで、そしおweight[i0][j]からweight[i0 + B - 1][j]たでのB個の芁玠に、1サむクルでアクセスする必芁がありたす。 これを実珟するためには、配列bias、vals、weightのポヌト数をB以䞊にする必芁がありたす。

valsに぀いおは、#pragma HLS ARRAY_PARTITION type=completeを䜿っお、配列を個々の芁玠に完党に分解しおいたす。 分割しない堎合はポヌトが2぀しかないので、同時に2぀の芁玠を読み出す (あるいは1芁玠を読み出しお、別の1芁玠ぞ曞き蟌む) こずしかできたせん。 完党に分割するず、配列の党おの芁玠を同時に読み曞きできるようになりたす。 なお、完党に分割するず、オンチップメモリ (BlockRAM) ではなく、フリップフロップ (FF) を䜿っお配列が実装されたす。

B個の芁玠をも぀配列valsを、完党に分割するず、次のようになりたす。

LinearOpt1内には蚘述されおいたせんが、weightずbiasに぀いおは、別の堎所で、valsず同様のHLSプラグマを指定する必芁がありたす。 weightずbiasから、1サむクルでB個の連続した芁玠 (bias[i0]からbias[i0 + B - 1]たで、そしおweight[i0][j]からweight[i0 + B - 1][j]たで) を読み出すためには、次のようにサむクリック分割したす。 weightは2次元配列ですが、最初の次元に察しお分割したいので、dim=1を指定したす。 オンチップメモリ (BlockRAM) 1぀に぀きポヌトが2぀付いおおり、1サむクルで2芁玠の読み出し (あるいは1぀の曞き出しず1぀の読み出し) ができたす。 B個の芁玠を1サむクルで読み出すためには、配列をBHalf = B / 2個に分割すればよいです。

  constexpr const int BHalf = B / 2;
  TParam weight[OutDims][InDims];
#pragma HLS ARRAY_PARTITION variable=weight type=cyclic factor=BHalf dim=1
  TParam bias[OutDims];
#pragma HLS ARRAY_PARTITION variable=bias type=cyclic factor=BHalf dim=1

簡単な䟋ずしお、2次元配列w[8][4]を、最初の次元で4぀にサむクリック分割 (factor=4 dim=1) すれば、次のようになりたす。 4分割するずポヌト数が8぀に増えるので、8぀の連続した芁玠 (䟋えばw[0][j]からw[7][j]たで) をたずめお読み出せるようになりたす。

サむクリック分割では、分割されたそれぞれの配列に察しお順に、先頭の芁玠から (w[0][0]、w[1][0]、w[2][0]の順に) 詰めおいきたす。 党おの配列に芁玠が入ったら、たた最初の配列に戻っお、芁玠を順に詰めおいきたす。 これを繰り返すず図のような配眮になりたす。 連続する芁玠 (w[0][0]、w[1][0]、w[2][0]、w[3][0]など) が別々の配列に栌玍されるので、これらを䞀床に取り出すこずができたす。 ルヌプアンロヌリングず、配列のサむクリック分割を組み合わせるこずで、配列の連続する芁玠に察する䞊列凊理を、容易に実珟できたす。 このこずから、#pragma HLS UNROLLず#pragma HLS ARRAY_PARTITIONは、セットで䜿う堎面が倚いず思いたす。 アンロヌリング係数ず、配列の分割数は揃える必芁がありたす。 係数Bでアンロヌリングしたら、配列はB / 2個 (B個でもよい) にサむクリック分割しないず、B䞊列になりたせん。 たた、ルヌプをアンロヌリングしたのに、配列を䞀切分割しなければ、䞊列凊理になりたせん。

最初の次元で2぀にサむクリック分割 (factor=2 dim=1) すれば、次のようになりたす。 2分割するずポヌト数が4぀に増えるので、4぀の連続した芁玠 (䟋えばw[0][j]からw[3][j]、あるいはw[4][j]からw[7][j]たで) をたずめお読み出せたす。

2番目の次元で2぀にサむクリック分割 (factor=2 dim=2) すれば、次のようになりたす。 今床は、2番目の次元に぀いお、4぀の連続した芁玠 (䟋えばw[i][0]からw[i][3]たで) に1サむクルでアクセスできたす。

これらを考えるず、weightずbiasに぀いおは䞊蚘のプラグマを䜿えばよいず分かりたす。

さお、2぀目のルヌプに泚目しおみたしょう。 1぀目のルヌプで蚈算されたB個の芁玠を、出力yに曞き蟌む郚分です。

    for (int i1 = 0; i1 < B; ++i1) {
#pragma HLS UNROLL
      int i = i0 + i1;
      if (ApplyReLU)
        y[i] = vals[i1] > T(0) ? vals[i1] : T(0);
      else
        y[i] = vals[i1];
    }

このルヌプもアンロヌリングされお、次のようになりたす。

    if (ApplyReLU) {
      y[i0 + 0] = vals[0] > T(0) ? vals[0] : T(0);
      y[i0 + 1] = vals[1] > T(0) ? vals[1] : T(0);
      // ...
      y[i0 + B - 1] = vals[B - 1] > T(0) ? vals[B - 1] : T(0);
    } else {
      y[i0 + 0] = vals[0];
      y[i0 + 1] = vals[1];
      // ...
      y[i0 + B - 1] = vals[B - 1];
    }

出力y[i0]からy[i0 + B - 1]たでの、連続するB個の芁玠に1サむクルでアクセスする必芁がありたす。 LinearOpt1内には蚘茉されたせんが、配列yも、次のようにサむクリック分割すればよいです。

  constexpr const int BHalf = B / 2;
  T y[OutDims];
#pragma HLS ARRAY_PARTITION variable=y type=cyclic factor=BHalf dim=1

なお、入力xに぀いおは、ルヌプの各むテレヌションで1぀の芁玠にしかアクセスしないため、分割する必芁はありたせん。 LinearOpt1を䜿っお、党結合局の凊理をB䞊列で実行するには、匕数である重みweight、バむアスbias、出力yを、出力の次元でB / 2個に分割しなければなりたせん (Bが2であれば分割の必芁はない)。

以䞊がLinearOpt1の䞻な倉曎点です。 LinearOpt1DDRに぀いおも、B個の出力を䞊列に蚈算するために、同様の倉曎がなされおいたす。 党結合局のバむアス項biasず、出力のB芁玠分を蚈算するために必芁な重みweightを、DRAMバッファからオンチップバッファ䞊に転送しおいたす。 LinearNaiveDDRずは異なり、重みを保持するバッファweightは、2次元配列ずなっおいたす。 B個の必芁な芁玠を取り出すために、biasずweightはBHalf = B / 2個に分割されおいたす。

BatchNorm1dReLUOpt1ずMaxPool1dOpt1に぀いおも、i (出力次元) に関するルヌプが、i0ずi1の2぀に分割されおいたす。 i1のルヌプはアンロヌリングされ、B個の出力が䞊列に蚈算されたす。 BatchNorm1dReLUOpt1を䜿っお、バッチ正芏化ずReLU掻性化をB䞊列で実行するには、関数の入力x、出力yず、バッチ正芏化局のパラメヌタ (スケヌルscale、バむアスbias、平均mean) をB / 2個に分割したす。 MaxPool1dOpt1に぀いおも同様で、B䞊列でMaxプヌリングを行うために、関数の入力xずyをB / 2個に分割したす (xは各点に察するロヌカル特城量で、yは点矀党䜓を衚すグロヌバルな特城量)。

各局をB䞊列で動䜜させるための、配列の分割のルヌルを次にたずめたす。 2䞊列の堎合は、分割の必芁がないこずが分かりたす。

  • LinearOpt1: 重みweight、バむアスbias、出力yを、出力の次元でB / 2個に分割 (入力xは分割の必芁なし)
  • LinearOpt1DDR: 出力yをB / 2個に分割 (入力xは分割の必芁なし)
  • BatchNorm1dReLUOpt1: 入力xず出力y、パラメヌタ (スケヌルscale、バむアスbias、平均mean) を、B / 2個に分割
  • MaxPool1dOpt1: 入力xず出力yを、B / 2個に分割

これらの䞊列化されたバヌゞョンを䜿っお、特城抜出ネットワヌクず、分類ネットワヌクの掚論凊理を次のように曞き換えたす。 InferenceFeatNaiveずInferenceClsNaiveから、それぞれInferenceFeatOpt1ずInferenceClsOpt1になりたす。 関数の匕数は倉曎したせん。 なお、InitializeFeatNaiveずInitializeClsNaive (重みの初期化関数) は、そのたた䜿うこずにしたす (関数名だけ、InitializeFeatOpt1、InitializeClsOpt1ずしたした)。

// Parallel implementation of the PointNet feature extraction
// `T` is the type for layer input, output, and intermediate results
// `U` is the type for parameters
// `N` is the expected number of input points (e.g., 1024)
template <typename T, typename U, int N>
void InferenceFeatOpt1(const float* point_cloud,
                       const int num_points,
                       T feature[kFeatDims5],
                       const LinearParams<U, kFeatDims0, kFeatDims1>* conv1,
                       const LinearParams<U, kFeatDims1, kFeatDims2>* conv2,
                       const LinearParams<U, kFeatDims2, kFeatDims3>* conv3,
                       const LinearParams<U, kFeatDims3, kFeatDims4>* conv4,
                       const LinearParams<U, kFeatDims4, kFeatDims5>* conv5,
                       const BatchNorm1dParams<U, kFeatDims1>* bn1,
                       const BatchNorm1dParams<U, kFeatDims2>* bn2,
                       const BatchNorm1dParams<U, kFeatDims3>* bn3,
                       const BatchNorm1dParams<U, kFeatDims4>* bn4,
                       const BatchNorm1dParams<U, kFeatDims5>* bn5)
{
#pragma HLS INLINE off

  // Zero-initialize the output feature
  VectorNdSetZero<T, kFeatDims5>(feature);

  // Compute the feature
  for (int i = 0; i < num_points; ++i) {
#pragma HLS LOOP_TRIPCOUNT min=N max=N avg=N
#pragma HLS LOOP_FLATTEN off

    // Input, output, and intermediate results
    T x0[kFeatDims0];
    T x1[kFeatDims1];
    T x2[kFeatDims1];
    T x3[kFeatDims2];
    T x4[kFeatDims2];
    T x5[kFeatDims3];
    T x6[kFeatDims3];
    T x7[kFeatDims4];
    T x8[kFeatDims4];
    T x9[kFeatDims5];
    T x10[kFeatDims5];

#pragma HLS ARRAY_PARTITION variable=x3 type=cyclic factor=4 dim=1
#pragma HLS ARRAY_PARTITION variable=x5 type=cyclic factor=4 dim=1
#pragma HLS ARRAY_PARTITION variable=x7 type=cyclic factor=8 dim=1
#pragma HLS ARRAY_PARTITION variable=x9 type=cyclic factor=64 dim=1

    // Read a point from a DDR memory
    ReadPointNaive<T>(point_cloud, i, x0);

    // Compute a point feature
    LinearOpt1<T, U, kFeatDims0, kFeatDims1, false, 2>(
      x0, x1, conv1->weight, conv1->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims1, 2>(
      x1, x2, bn1->scale, bn1->bias, bn1->mean);
    LinearOpt1<T, U, kFeatDims1, kFeatDims2, false, 8>(
      x2, x3, conv2->weight, conv2->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims2, 2>(
      x3, x4, bn2->scale, bn2->bias, bn2->mean);
    LinearOpt1<T, U, kFeatDims2, kFeatDims3, false, 8>(
      x4, x5, conv3->weight, conv3->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims3, 2>(
      x5, x6, bn3->scale, bn3->bias, bn3->mean);
    LinearOpt1<T, U, kFeatDims3, kFeatDims4, false, 16>(
      x6, x7, conv4->weight, conv4->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims4, 2>(
      x7, x8, bn4->scale, bn4->bias, bn4->mean);
    LinearOpt1<T, U, kFeatDims4, kFeatDims5, false, 128>(
      x8, x9, conv5->weight, conv5->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims5, 2>(
      x9, x10, bn5->scale, bn5->bias, bn5->mean);

    // Update the output feature
    MaxPool1dOpt1<T, kFeatDims5, 2>(x10, feature);
  }
}

// Parallel implementation of the classification network
// `T` is the type for layer input, output, and intermediate results
// `U` is the type for parameters
template <typename T, typename U>
void InferenceClsOpt1(const T feature[kFeatDims5],
                      float* out_logits,
                      const LinearParams<U, kClsDims2, kClsDims3>* fc3,
                      const BatchNorm1dParams<U, kClsDims1>* bn1,
                      const BatchNorm1dParams<U, kClsDims2>* bn2,
                      const float* params1,
                      const float* params2,
                      const float* params3)
{
#pragma HLS INLINE off

  static_assert(kFeatDims5 == kClsDims0,
                "Feature dimension should be equal to the input dimension");

  // Input, output, and intermediate results
  T x0[kClsDims1];
  T x1[kClsDims1];
  T x2[kClsDims2];
  T x3[kClsDims2];
  T x4[kClsDims3];

#pragma HLS ARRAY_PARTITION variable=x0 type=cyclic factor=8 dim=1
#pragma HLS ARRAY_PARTITION variable=x2 type=cyclic factor=4 dim=1

  // Compute logits
  LinearOpt1DDR<T, U, kClsDims0, kClsDims1, false, 16>(
    feature, x0, params1, 0);
  BatchNorm1dReLUOpt1<T, U, kClsDims1, 2>(
    x0, x1, bn1->scale, bn1->bias, bn1->mean);
  LinearOpt1DDR<T, U, kClsDims1, kClsDims2, false, 8>(
    x1, x2, params2, 0);
  BatchNorm1dReLUOpt1<T, U, kClsDims2, 2>(
    x2, x3, bn2->scale, bn2->bias, bn2->mean);
  LinearOpt1<T, U, kClsDims2, kClsDims3, false, 2>(
    x3, x4, fc3->weight, fc3->bias);

  // Write the result
  WriteTensor1dNaive<T, kClsDims3>(out_logits, x4, 0);
}

各局の関数を呌び出す際に、テンプレヌト匕数に䞊列化床も指定しおいたす。 䟋えば、特城抜出ネットワヌクの4番目の党結合局 (PyTorchのモデルにおけるPointNetFeat::conv4) は16䞊列、最埌の党結合局 (PointNetFeat::conv5) は128䞊列で実行されたす。 䞀方、バッチ正芏化局ずMaxプヌリングは、2䞊列で実行されおいたす。 各局の䞊列床をどのように決定したのかに぀いおは、埌述したす。

続いお、IPコアの最䞊䜍関数PointNetClsTopを以䞋に瀺したす。

void PointNetClsTop(const int op_mode,
                    const float* point_cloud,
                    const int num_points,
                    float* out_logits,
                    const float* feat_params1,
                    const float* feat_params2,
                    const float* feat_params3,
                    const float* feat_params4,
                    const float* feat_params5,
                    const float* cls_params1,
                    const float* cls_params2,
                    const float* cls_params3)
{
#pragma HLS INTERFACE m_axi port=point_cloud offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=out_logits offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=feat_params1 offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=feat_params2 offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=feat_params3 offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=feat_params4 offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=feat_params5 offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=cls_params1 offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=cls_params2 offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=cls_params3 offset=slave bundle=gmem0

#pragma HLS INTERFACE s_axilite port=op_mode bundle=control
#pragma HLS INTERFACE s_axilite port=point_cloud bundle=control
#pragma HLS INTERFACE s_axilite port=num_points bundle=control
#pragma HLS INTERFACE s_axilite port=out_logits bundle=control
#pragma HLS INTERFACE s_axilite port=feat_params1 bundle=control
#pragma HLS INTERFACE s_axilite port=feat_params2 bundle=control
#pragma HLS INTERFACE s_axilite port=feat_params3 bundle=control
#pragma HLS INTERFACE s_axilite port=feat_params4 bundle=control
#pragma HLS INTERFACE s_axilite port=feat_params5 bundle=control
#pragma HLS INTERFACE s_axilite port=cls_params1 bundle=control
#pragma HLS INTERFACE s_axilite port=cls_params2 bundle=control
#pragma HLS INTERFACE s_axilite port=cls_params3 bundle=control
#pragma HLS INTERFACE s_axilite port=return bundle=control

  // Parameters for feature extraction
  LinearParams<param_t, kFeatDims0, kFeatDims1> feat_conv1;
  LinearParams<param_t, kFeatDims1, kFeatDims2> feat_conv2;
  LinearParams<param_t, kFeatDims2, kFeatDims3> feat_conv3;
  LinearParams<param_t, kFeatDims3, kFeatDims4> feat_conv4;
  LinearParams<param_t, kFeatDims4, kFeatDims5> feat_conv5;
  BatchNorm1dParams<param_t, kFeatDims1> feat_bn1;
  BatchNorm1dParams<param_t, kFeatDims2> feat_bn2;
  BatchNorm1dParams<param_t, kFeatDims3> feat_bn3;
  BatchNorm1dParams<param_t, kFeatDims4> feat_bn4;
  BatchNorm1dParams<param_t, kFeatDims5> feat_bn5;

#pragma HLS ARRAY_PARTITION variable=feat_conv2.weight type=cyclic factor=4 dim=1
#pragma HLS ARRAY_PARTITION variable=feat_conv2.bias type=cyclic factor=4 dim=1
#pragma HLS ARRAY_PARTITION variable=feat_conv3.weight type=cyclic factor=4 dim=1
#pragma HLS ARRAY_PARTITION variable=feat_conv3.bias type=cyclic factor=4 dim=1
#pragma HLS ARRAY_PARTITION variable=feat_conv4.weight type=cyclic factor=8 dim=1
#pragma HLS ARRAY_PARTITION variable=feat_conv4.bias type=cyclic factor=8 dim=1
#pragma HLS ARRAY_PARTITION variable=feat_conv5.weight type=cyclic factor=64 dim=1
#pragma HLS ARRAY_PARTITION variable=feat_conv5.bias type=cyclic factor=64 dim=1

  // Parameters for classification network
  // LinearParams<param_t, kClsDims0, kClsDims1> cls_fc1;
  // LinearParams<param_t, kClsDims1, kClsDims2> cls_fc2;
  LinearParams<param_t, kClsDims2, kClsDims3> cls_fc3;
  BatchNorm1dParams<param_t, kClsDims1> cls_bn1;
  BatchNorm1dParams<param_t, kClsDims2> cls_bn2;

  // Extracted feature
  value_t feature[kFeatDims5];

  if (op_mode == kModeInitWeights) {
    // Initialize the PointNet feature extraction network
    InitializeFeatOpt1<param_t>(
      &feat_conv1, &feat_conv2, &feat_conv3, &feat_conv4, &feat_conv5,
      &feat_bn1, &feat_bn2, &feat_bn3, &feat_bn4, &feat_bn5,
      feat_params1, feat_params2, feat_params3, feat_params4, feat_params5);
    // Initialize the classification network
    InitializeClsOpt1<param_t>(
      &cls_fc3, &cls_bn1, &cls_bn2,
      cls_params1, cls_params2, cls_params3);
  } else if (op_mode == kModeInference) {
    // Run the PointNet feature extraction
    InferenceFeatOpt1<value_t, param_t, 1024>(
      point_cloud, num_points, feature,
      &feat_conv1, &feat_conv2, &feat_conv3, &feat_conv4, &feat_conv5,
      &feat_bn1, &feat_bn2, &feat_bn3, &feat_bn4, &feat_bn5);

    // Run the classification
    InferenceClsOpt1<value_t, param_t>(
      feature, out_logits,
      &cls_fc3, &cls_bn1, &cls_bn2,
      cls_params1, cls_params2, cls_params3);
  }
}

関数の入出力ポヌトに぀いおは党く同䞀です。 以前のバヌゞョンず比范するず、局の入出力やパラメヌタを保持するバッファ (feat_conv5.weight、feat_conv5.bias、x3、x5など) を分割するために、#pragma HLS ARRAY_PARTITIONが远加されおいるこずが分かりたす。 配列の分割数 (factor) に぀いおは、䞊述のルヌルに則っおいたす。 䟋えば、InferenceFeatOpt1ずPointNetClsTopをみるず、特城抜出ネットワヌクの最埌の党結合局を128䞊列で実行したいので、出力甚のバッファx10ず、党結合局の2぀のパラメヌタfeat_conv5.weight、feat_conv5.biasを64分割しおいたす (蚘述する堎所が散らばっおいるのが難点です)。 同様に、InferenceClsOpt1ずPointNetClsTopをみるず、分類ネットワヌクの最初の党結合局は16䞊列で実行されるので、出力甚のバッファx0は8分割しおいたす。 バッチ正芏化局ずMaxプヌリングは2䞊列なので、配列を分割する必芁はありたせん。

先述のように、配列を分割するずポヌト数が増えお、䞀床に倚くの芁玠を読み出せるようになりたすが、貎重なオンチップメモリの消費も増えたす。 オンチップメモリの消費を抑え぀぀、なるべく䞊列床を䞊げる必芁がありたす。 掚論時間の短瞮に最も効果がある郚分 (䟋えば特城抜出ネットワヌクの最埌の党結合局) の䞊列床を䞊げお、効果があたりない郚分 (䟋えばバッチ正芏化局) の䞊列床は䞋げおいたす。

ここで、各局の実行サむクル数を比范しおみたす (動䜜呚波数は150MHz)。 特城抜出ネットワヌクに぀いおは次のようになりたした。

å±€ InferenceFeatNaive InferenceFeatOpt1
党結合局1 (PointNetFeat::conv1) 577 (3.843us) 321 (2.138us)
バッチ正芏化局 + ReLU (PointNetFeat::bn1) 68 (0.453us) 36 (0.240us)
党結合局2 (PointNetFeat::conv2) 4,481 (29.84us) 569 (3.790us)
バッチ正芏化局 + ReLU (PointNetFeat::bn2) 68 (0.453us) 36 (0.240us)
党結合局3 (PointNetFeat::conv3) 4,481 (29.84us) 569 (3.790us)
バッチ正芏化局 + ReLU (PointNetFeat::bn3) 68 (0.453us) 36 (0.240us)
党結合局4 (PointNetFeat::conv4) 8,961 (59.68us) 569 (3.790us)
バッチ正芏化局 + ReLU (PointNetFeat::bn4) 132 (0.879us) 68 (0.453us)
党結合局5 (PointNetFeat::conv5) 137,217 (914.0us) 1,081 (7.199us)
バッチ正芏化局 + ReLU (PointNetFeat::bn5) 1,028 (6.846us) 516 (3.437us)
Maxプヌリング局 1,026 (6.833us) 514 (3.423us)
党䜓 (1回分) 158,149 (1.053ms) 4,357 (29.02us)
党䜓 (1024回分) 161,945,604 (1.079s) 4,462,596 (29.72ms)

特城抜出ネットワヌクに関しおは、やはり最埌の党結合局がボトルネックずなっおいたす。 128䞊列にするこずで、実行時間を126.9倍 (137,217サむクルから1,081サむクル) 削枛できおいたす。 4぀目の党結合局に぀いおも、16䞊列にするこずで、実行時間が15.75倍 (8,961サむクルから569サむクル) 短くなりたした。 党結合局やバッチ正芏化局、Maxプヌリング局にみられるデヌタ䞊列性を掻かしお、掚論時間を短瞮できたした。 たた分類ネットワヌクに぀いおは次のようになりたした。

å±€ InferenceClsNaive InferenceClsOpt1
党結合局1 (PointNetCls::fc1) 1,056,279 (7.035ms) 558,071 (3.717ms)
バッチ正芏化局 + ReLU (PointNetCls::bn1) 516 (3.437us) 260 (1.732us)
党結合局2 (PointNetCls::fc2) 266,007 (1.772ms) 148,183 (987.0us)
バッチ正芏化局 + ReLU (PointNetCls::bn2) 260 (1.732us) 132 (0.879us)
党結合局3 (PointNetCls::fc3) 10,481 (69.80us) 5,261 (35.04us)
党䜓 1,333,605 (8.882ms) 711,969 (4.742ms)

最初の党結合局は16䞊列で実行するようにしたしたが、実行時間は1.89倍 (1,056,279サむクルから558,071サむクル) しか短くなっおいたせん。 前述のように、分類ネットワヌクの最初の党結合局2぀では、パラメヌタをオンチップバッファに眮くのではなく、DRAMバッファから必芁な郚分だけを転送しおいたす。 行列の積や加算は16䞊列で実行されるのですが、デヌタ転送郚分の実行時間は短瞮されないので、このような結果になっおいたす。 2぀目の党結合局に関しおも同様に、8䞊列を指定したのですが、実行時間は1.80倍 (266,007サむクルから148,183サむクル) の削枛に留たっおいたす。

珟圚の実装では、入出力ポヌトの幅は32ビットで、1サむクルに぀きfloatのデヌタを1぀ず぀転送しおいたす。 入出力ポヌトの幅を広げお、1サむクルで耇数のデヌタを転送すれば、デヌタ転送の実行時間を短瞮できたす。 埌ほど、ポヌト幅を32ビットから64ビットに広げお、1サむクルでfloatのデヌタを2぀ず぀転送するように、改善したす。

IPコアの動䜜モヌドには2぀ありたすが、このうち重みの初期化モヌドに぀いおは、党く手を加えおいたせん。 重みの初期化は、IPコアの利甚開始前に䞀床だけ行われ、ネットワヌクの掚論時間ずは党く関係ないためです。

以䞊で掚論の䞊列化が枈みたした。 詳しくはhls/src/top_opt1.cppをご参照ください。

䞊列化その2 (タスク䞊列性の掻甚)

各局の蚈算は䞊列化できたしたが、特城抜出ネットワヌクの郚分には、ただ高速化の䜙地が残されおいたす。 特城抜出ネットワヌクの掚論凊理を、もう䞀床みおみたしょう。

  // Compute the feature
  for (int i = 0; i < num_points; ++i) {
#pragma HLS LOOP_TRIPCOUNT min=N max=N avg=N
#pragma HLS LOOP_FLATTEN off

    // ...

    // Read a point from a DDR memory
    ReadPointNaive<T>(point_cloud, i, x0);

    // Compute a point feature
    LinearOpt1<T, U, kFeatDims0, kFeatDims1, false, 2>(
      x0, x1, conv1->weight, conv1->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims1, 2>(
      x1, x2, bn1->scale, bn1->bias, bn1->mean);
    LinearOpt1<T, U, kFeatDims1, kFeatDims2, false, 8>(
      x2, x3, conv2->weight, conv2->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims2, 2>(
      x3, x4, bn2->scale, bn2->bias, bn2->mean);
    LinearOpt1<T, U, kFeatDims2, kFeatDims3, false, 8>(
      x4, x5, conv3->weight, conv3->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims3, 2>(
      x5, x6, bn3->scale, bn3->bias, bn3->mean);
    LinearOpt1<T, U, kFeatDims3, kFeatDims4, false, 16>(
      x6, x7, conv4->weight, conv4->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims4, 2>(
      x7, x8, bn4->scale, bn4->bias, bn4->mean);
    LinearOpt1<T, U, kFeatDims4, kFeatDims5, false, 128>(
      x8, x9, conv5->weight, conv5->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims5, 2>(
      x9, x10, bn5->scale, bn5->bias, bn5->mean);

    // Update the output feature
    MaxPool1dOpt1<T, kFeatDims5, 2>(x10, feature);
  }

ルヌプの内郚をみるず、最初に、DRAMに眮かれた点矀point_cloudからi番目の点を取っおきお、オンチップバッファx0に栌玍しおいたす。 続いお、このx0がバケツリレヌのように、耇数の関数に枡されおいきたす。 䟋えば、最初の党結合局によっおx0からx1、バッチ正芏化局によっおx1からx2、次の党結合局によっおx2からx3が蚈算されおいたす。 ある局の関数 (䟋えばLinearOpt1(x4, x5)) は、その䞀぀前の関数の出力 (x4) を入力ずしお受け取り、出力 (x5) を次の関数に匕き枡したす。 党おの関数が、入出力を介しお、数珠぀なぎのようになっおいたす。 関数の実行の流れを図にするず、次のようになりたす。

先皋のパむプラむン化ず同様に、耇数の点に぀いお凊理を䞊列化できたす。

䟋えば、1぀目の点に察しお、最埌の党結合局を蚈算しおいる間に、2぀目の点に察しお、その䞀぀前のバッチ正芏化局を蚈算するずいうように、耇数の点に察する凊理を時間的にオヌバヌラップさせたす。 以前は、ルヌプ内の凊理をパむプラむン化しお、ルヌプの耇数のむテレヌションを䞊列に実行したした。 そしお、パむプラむンの各ステヌゞは、䞻に乗算や加算でした。 ここでは、各ステヌゞは䞀぀の関数 (タスク) に察応するので、より粗粒床なパむプラむン化ずいえたす。 このようなタスクレベルのパむプラむン化は、Vitis HLSではデヌタフロヌ最適化 (Dataflow optimization) ずよばれおいたす (最適化その6: デヌタフロヌ最適化)。 デヌタフロヌ最適化を適甚するには、いろいろな条件がありたすが、今回の堎合は倧䞈倫です。

以前述べたように、パむプラむンの各ステヌゞの実行サむクル数をなるべく均等に揃えるこずで、パむプラむンの効果が増したす。 各局の蚈算時間を、なるべく均䞀にしたいずいうこずです。 蚈算時間は、䞊の衚にたずめられおいたす。 デヌタ䞊列性を利甚する前は、実行サむクル数 (特に党結合局) には、かなりのばら぀きがありたした。 党結合局5぀だけ抜き出しおみるず、577、4,481、4,481、8,961、137,217ずなっおいたす。 それぞれの局を、2、8、8、16、128䞊列で実行するこずで (InferenceFeatOpt1を参照)、321、569、569、569、1,081サむクルに削枛され、ばら぀きもかなり抑えられたした。 最埌の党結合局を256䞊列にすれば、さらに均等になりたすが、回路が耇雑になり過ぎるのでやめたした。

パむプラむンは最も時間の長いステヌゞによっお性胜が制限されたす。 今回の堎合は、最埌の党結合局 (1,081サむクル) によっお性胜が決たりたす。 他のステヌゞは、1,081サむクル以䞋であれば、䜕サむクルであろうずも性胜に圱響を䞎えたせん。 リ゜ヌス消費を抑えるため、他のステヌゞに関しおは、1,081サむクルを超えない範囲で、なるべく䞊列床を萜ずしたした。

特城抜出ネットワヌクに関しおはこのように、デヌタフロヌ最適化を予め考慮したうえで、各局の䞊列床を指定したした。 分類ネットワヌクの䞊列床は、䜕ずなく決めおいたす。

デヌタフロヌ最適化を斜した実装を、次に瀺したす。 InferenceFeatOpt1から、InferenceFeatOpt2ずしたした。

// Parallel implementation of the PointNet feature extraction
// `T` is the type for layer input, output, and intermediate results
// `U` is the type for parameters
// `N` is the expected number of input points (e.g., 1024)
template <typename T, typename U, int N>
void InferenceFeatOpt2(...)
{
#pragma HLS INLINE off

  // Zero-initialize the output feature
  VectorNdSetZero<T, kFeatDims5>(feature);

  // Compute the feature
  for (int i = 0; i < num_points; ++i) {
#pragma HLS LOOP_TRIPCOUNT min=N max=N avg=N
#pragma HLS LOOP_FLATTEN off
#pragma HLS DATAFLOW

#pragma HLS STABLE variable=point_cloud
#pragma HLS STABLE variable=num_points
#pragma HLS STABLE variable=feature
#pragma HLS STABLE variable=conv1
#pragma HLS STABLE variable=conv2
#pragma HLS STABLE variable=conv3
#pragma HLS STABLE variable=conv4
#pragma HLS STABLE variable=conv5
#pragma HLS STABLE variable=bn1
#pragma HLS STABLE variable=bn2
#pragma HLS STABLE variable=bn3
#pragma HLS STABLE variable=bn4
#pragma HLS STABLE variable=bn5

    // Input, output, and intermediate results
    // ...

    // Read a point from a DDR memory
    ReadPointNaive<T>(point_cloud, i, x0);

    // Compute a point feature
    LinearOpt1<T, U, kFeatDims0, kFeatDims1, false, 2>(
      x0, x1, conv1->weight, conv1->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims1, 2>(
      x1, x2, bn1->scale, bn1->bias, bn1->mean);
    LinearOpt1<T, U, kFeatDims1, kFeatDims2, false, 8>(
      x2, x3, conv2->weight, conv2->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims2, 2>(
      x3, x4, bn2->scale, bn2->bias, bn2->mean);
    LinearOpt1<T, U, kFeatDims2, kFeatDims3, false, 8>(
      x4, x5, conv3->weight, conv3->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims3, 2>(
      x5, x6, bn3->scale, bn3->bias, bn3->mean);
    LinearOpt1<T, U, kFeatDims3, kFeatDims4, false, 16>(
      x6, x7, conv4->weight, conv4->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims4, 2>(
      x7, x8, bn4->scale, bn4->bias, bn4->mean);
    LinearOpt1<T, U, kFeatDims4, kFeatDims5, false, 128>(
      x8, x9, conv5->weight, conv5->bias);
    BatchNorm1dReLUOpt1<T, U, kFeatDims5, 2>(
      x9, x10, bn5->scale, bn5->bias, bn5->mean);

    // Update the output feature
    MaxPool1dOpt1<T, kFeatDims5, 2>(x10, feature);
  }
}

InferenceFeatOpt1ず異なるのはHLSプラグマの郚分だけです。 ルヌプの先頭郚分には#pragma HLS DATAFLOWの蚘述があり、ルヌプの䞭身をデヌタフロヌ最適化するように指瀺したす。 #pragma HLS STABLEの郚分は、ルヌプの各むテレヌションを開始するにあたっお、その倉数に぀いお同期をずる必芁がない、ずいうこずを瀺したす。 各局のパラメヌタや点矀など、ルヌプの実行䞭は倉化しない倉数に付䞎しおいたす。 この蚘述がないず、デヌタフロヌ最適化がうたく機胜したせん。

この2皮類のHLSプラグマを挿入するだけで、デヌタフロヌ最適化をいずも簡単に実珟できたす。 高䜍合成ツヌルは玠晎らしいず思いたす。 PointNetClsTop (トップ関数) や分類ネットワヌクの掚論 (InferenceClsOpt1) に぀いおは以前ず党く同じであるため、ここでは割愛したす。

デヌタフロヌ最適化による効果をみおみたす。 InferenceFeatOpt1では、1぀の点に察する順䌝播に4,357サむクル (29.02us) 芁しおいたしたが、InferenceFeatOpt2でも4,344サむクル (28.93us) で、ほが倉わりたせん。 䞀方、1,024個の点に察する凊理時間をみおみるず、InferenceFeatOpt1では4,462,596サむクル (29.72ms) でしたが、InferenceFeatOpt2では1,112,259サむクル (7.408ms) に削枛されおいたす。 パむプラむン化しおも、各入力デヌタに察する蚈算時間 (レむテンシ) は倉化したせんが、単䜍時間あたりに凊理可胜なデヌタ数 (スルヌプット) は改善するので、それに䌎っお党䜓の性胜も向䞊するずいうこずです。

これでデヌタフロヌ最適化は終わりです。 詳しくはhls/src/top_opt2.cppをご芧ください。

入出力ポヌト幅の拡匵

分類ネットワヌクの党結合局郚分では、積和挔算を䞊列化したにもかかわらず、党䜓の凊理時間はそれほど短瞮されたせんでした。 DRAMからオンチップバッファぞのパラメヌタ転送のサむクル数が、倉化しおいないためです。 そこで最埌の最適化ずしお、入出力ポヌトのビット幅を32から64に広げお、1サむクルに぀き2぀のfloatデヌタを転送できるように、実装を修正しおみたしょう (最適化その7: デヌタ転送)。

最初に、IPコアの最䞊䜍関数PointNetClsTopから修正したす。 修正前は、次のようになっおいたした。

void PointNetClsTop(const int op_mode,
                    const float* point_cloud,
                    const int num_points,
                    float* out_logits,
                    const float* feat_params1,
                    const float* feat_params2,
                    const float* feat_params3,
                    const float* feat_params4,
                    const float* feat_params5,
                    const float* cls_params1,
                    const float* cls_params2,
                    const float* cls_params3)
{
  // ...
}

これを、次のように64ビット幅にしたす。

void PointNetClsTop(const int op_mode,
                    const ap_uint<64>* point_cloud,
                    const int num_points,
                    ap_uint<64>* out_logits,
                    const ap_uint<64>* feat_params1,
                    const ap_uint<64>* feat_params2,
                    const ap_uint<64>* feat_params3,
                    const ap_uint<64>* feat_params4,
                    const ap_uint<64>* feat_params5,
                    const ap_uint<64>* cls_params1,
                    const ap_uint<64>* cls_params2,
                    const ap_uint<64>* cls_params3)
{
  // ...
}

ap_uintは、Vitis HLSで提䟛されおいる、任意ビット長の笊号なし敎数型です。 ここでは64ビットずしおいたす。 1サむクルに぀きデヌタを2぀ず぀読み取らなければいけないので、デヌタ転送に関する郚分を党お修正したす。 DRAMからパラメヌタを取り出しお、オンチップバッファに栌玍する、重み初期化関数InitializeFeatOpt1、InitializeClsOpt1も次のように盎しお、新たにInitializeFeatOpt3、InitializeClsOpt3ずしたす。 単に、関数の匕数をfloat*からap_uint<64>*に倉曎しただけです。

// Parallel implementation of the parameter initialization
// `T` is the type for parameters
template <typename T>
void InitializeFeatOpt3(LinearParams<T, kFeatDims0, kFeatDims1>* conv1,
                        LinearParams<T, kFeatDims1, kFeatDims2>* conv2,
                        LinearParams<T, kFeatDims2, kFeatDims3>* conv3,
                        LinearParams<T, kFeatDims3, kFeatDims4>* conv4,
                        LinearParams<T, kFeatDims4, kFeatDims5>* conv5,
                        BatchNorm1dParams<T, kFeatDims1>* bn1,
                        BatchNorm1dParams<T, kFeatDims2>* bn2,
                        BatchNorm1dParams<T, kFeatDims3>* bn3,
                        BatchNorm1dParams<T, kFeatDims4>* bn4,
                        BatchNorm1dParams<T, kFeatDims5>* bn5,
                        const ap_uint<64>* params1,
                        const ap_uint<64>* params2,
                        const ap_uint<64>* params3,
                        const ap_uint<64>* params4,
                        const ap_uint<64>* params5)
{
#pragma HLS INLINE off

  ReadBlockParamsOpt2<T, kFeatDims0, kFeatDims1>(conv1, bn1, params1);
  ReadBlockParamsOpt1<T, kFeatDims1, kFeatDims2>(conv2, bn2, params2);
  ReadBlockParamsOpt1<T, kFeatDims2, kFeatDims3>(conv3, bn3, params3);
  ReadBlockParamsOpt1<T, kFeatDims3, kFeatDims4>(conv4, bn4, params4);
  ReadBlockParamsOpt1<T, kFeatDims4, kFeatDims5>(conv5, bn5, params5);
}

// Parallel implementation of the parameter initialization
// `T` is the type for parameters
template <typename T>
void InitializeClsOpt3(LinearParams<T, kClsDims2, kClsDims3>* fc3,
                       BatchNorm1dParams<T, kClsDims1>* bn1,
                       BatchNorm1dParams<T, kClsDims2>* bn2,
                       const ap_uint<64>* params1,
                       const ap_uint<64>* params2,
                       const ap_uint<64>* params3)
{
#pragma HLS INLINE off

  ReadBatchNorm1dParamsOpt1<T, kClsDims1>(
    bn1, params1, kClsDims0 * kClsDims1 + kClsDims1);
  ReadBatchNorm1dParamsOpt1<T, kClsDims2>(
    bn2, params2, kClsDims1 * kClsDims2 + kClsDims2);
  ReadLinearParamsOpt1<T, kClsDims2, kClsDims3>(
    fc3, params3, 0);
}

最初の実装ではReadLinearParamsNaive、ReadBatchNorm1dParamsNaive、ReadBlockParamsNaiveを䜿っおいたしたが、ここでは新たにReadLinearParamsOpt1、ReadBatchNorm1dParamsOpt1、ReadBlockParamsOpt1、ReadBlockParamsOpt2の4皮類を䜿っおいたす。 詳しく䞭身をみおみたしょう。

// Parallel implementation of the parameter initialization
// Read the parameters for a linear layer from a DDR memory and
// store them to BRAM buffers
// `T` is the type for parameters
// `InDims` is the number of input dimensions
// `OutDims` is the number of output dimensions
template <typename T, int InDims, int OutDims>
void ReadLinearParamsOpt1(LinearParams<T, InDims, OutDims>* linear,
                          const ap_uint<64>* params,
                          const int offset)
{
#pragma HLS INLINE
  // `params` contains weight parameters of size (`OutDims`, `InDims`) and
  // bias parameters of size (`OutDims`) in a contiguous buffer

  static_assert(InDims % 2 == 0, "`InDims` must be a multiple of 2");
  static_assert(OutDims % 2 == 0, "`OutDims` must be a multiple of 2");
  assert(offset % 2 == 0);

  ReadTensor2dOpt1<T, OutDims, InDims>(linear->weight, params, offset);
  ReadTensor1dOpt1<T, OutDims>(linear->bias, params,
                               offset + InDims * OutDims);
}

// Parallel implementation of the parameter initialization
// Read the parameters for a 1D batch normalization layer from a DDR memory and
// store them to BRAM buffers
// `T` is the type for parameters
// `Dims` is the number of input and output dimensions
template <typename T, int Dims>
void ReadBatchNorm1dParamsOpt1(BatchNorm1dParams<T, Dims>* bn,
                               const ap_uint<64>* params,
                               const int offset)
{
#pragma HLS INLINE
  // `params` contains scale parameters of size (`Dims`),
  // bias of size (`Dims`), and mean of size (`Dims`) in a contiguous buffer

  static_assert(Dims % 2 == 0, "`Dims` must be a multiple of 2");
  assert(offset % 2 == 0);

  ReadTensor1dOpt1<T, Dims>(bn->scale, params, offset);
  ReadTensor1dOpt1<T, Dims>(bn->bias, params, offset + Dims);
  ReadTensor1dOpt1<T, Dims>(bn->mean, params, offset + Dims * 2);
}

// Parallel implementation of the parameter initialization
// Read the parameters for a linear and 1D batch normalization layer
// from a DDR memory and store them to BRAM buffers
// `T` is the type for parameters
// `InDims` is the number of input dimensions
// `OutDims` is the number of output dimensions
template <typename T, int InDims, int OutDims>
void ReadBlockParamsOpt1(LinearParams<T, InDims, OutDims>* linear,
                         BatchNorm1dParams<T, OutDims>* bn,
                         const ap_uint<64>* params)
{
#pragma HLS INLINE

  static_assert(InDims % 2 == 0, "`InDims` must be a multiple of 2");
  static_assert(OutDims % 2 == 0, "`OutDims` must be a multiple of 2");

  ReadTensor2dOpt1<T, OutDims, InDims>(linear->weight, params, 0);
  ReadTensor1dOpt1<T, OutDims>(linear->bias, params, InDims * OutDims);
  ReadTensor1dOpt1<T, OutDims>(bn->scale, params,
                               InDims * OutDims + OutDims);
  ReadTensor1dOpt1<T, OutDims>(bn->bias, params,
                               InDims * OutDims + OutDims * 2);
  ReadTensor1dOpt1<T, OutDims>(bn->mean, params,
                               InDims * OutDims + OutDims * 3);
}

// Parallel implementation of the parameter initialization
// Read the parameters for a linear and 1D batch normalization layer
// from a DDR memory and store them to BRAM buffers
// `T` is the type for parameters
// `InDims` is the number of input dimensions
// `OutDims` is the number of output dimensions
template <typename T, int InDims, int OutDims>
void ReadBlockParamsOpt2(LinearParams<T, InDims, OutDims>* linear,
                         BatchNorm1dParams<T, OutDims>* bn,
                         const ap_uint<64>* params)
{
#pragma HLS INLINE

  static_assert(InDims == 3, "`InDims` must be 3");
  static_assert(OutDims % 2 == 0, "`OutDims` must be a multiple of 2");

  ReadTensor2dOpt2<T, OutDims, InDims>(linear->weight, params, 0);
  ReadTensor1dOpt1<T, OutDims>(linear->bias, params, InDims * OutDims);
  ReadTensor1dOpt1<T, OutDims>(bn->scale, params,
                               InDims * OutDims + OutDims);
  ReadTensor1dOpt1<T, OutDims>(bn->bias, params,
                               InDims * OutDims + OutDims * 2);
  ReadTensor1dOpt1<T, OutDims>(bn->mean, params,
                               InDims * OutDims + OutDims * 3);
}

基本的には元のナむヌブな実装ず同じですが、匕数の型がfloat*からap_uint<64>*に倉わっおいたす。 関数の䞭身も単玔で、指定したオフセットから、指定したサむズのパラメヌタを読み取るこずを繰り返すだけです。 䟋えばバッチ正芏化局のパラメヌタを読み取るずきは、スケヌル、バむアス、平均の順に読み取りたす。 DRAMバッファ䞊には予め、正しい䜍眮にこの順で䞊べおおく必芁がありたす。 䞭で䜿われおいる関数ReadTensor1dOpt1、ReadTensor2dOpt1、ReadTensor2dOpt2は次の通りです。

union conv32_t
{
  std::uint32_t u32;
  int i32;
  float f;
};

// Interpret float as std::uint32_t
inline std::uint32_t FloatToU32(const float f)
{
  conv32_t conv;
  conv.f = f;
  return conv.u32;
}

// Interpret std::uint32_t as float
inline float U32ToFloat(const std::uint32_t u32)
{
  conv32_t conv;
  conv.u32 = u32;
  return conv.f;
}

// Read a 1D tensor from a DDR memory
template <typename T, int D0>
void ReadTensor1dNaive(T tensor[D0],
                       const float* src,
                       const int offset)
{
#pragma HLS INLINE off

  for (int i = 0; i < D0; ++i) {
#pragma HLS PIPELINE II=1
    tensor[i] = T(src[offset + i]);
  }
}

// Read a 1D tensor from a DDR memory
template <typename T, int D0>
void ReadTensor1dOpt1(T tensor[D0],
                      const ap_uint<64>* src,
                      const int offset)
{
#pragma HLS INLINE off

  static_assert(D0 % 2 == 0, "`D0` must be a multiple of 2");
  assert(offset % 2 == 0);

  constexpr const int D0Over2 = D0 / 2;
  const int offset2 = offset / 2;

  for (int i = 0; i < D0Over2; ++i) {
#pragma HLS PIPELINE II=1
    const ap_uint<64> tensor_data = src[offset2 + i];
    tensor[i * 2 + 0] = T(U32ToFloat(tensor_data.range(31, 0)));
    tensor[i * 2 + 1] = T(U32ToFloat(tensor_data.range(63, 32)));
  }
}

// Read a 2D tensor from a DDR memory
template <typename T, int D0, int D1>
void ReadTensor2dNaive(T tensor[D0][D1],
                       const float* src,
                       const int offset)
{
#pragma HLS INLINE off

  for (int i = 0; i < D0; ++i) {
    for (int j = 0; j < D1; ++j) {
#pragma HLS PIPELINE II=1
      const int idx = i * D1 + j;
      tensor[i][j] = T(src[offset + idx]);
    }
  }
}

// Read a 2D tensor from a DDR memory
template <typename T, int D0, int D1>
void ReadTensor2dOpt1(T tensor[D0][D1],
                      const ap_uint<64>* src,
                      const int offset)
{
#pragma HLS INLINE off

  static_assert(D1 % 2 == 0, "`D1` must be a multiple of 2");
  assert(offset % 2 == 0);

  constexpr const int D1Over2 = D1 / 2;
  const int offset2 = offset / 2;

  for (int i = 0; i < D0; ++i) {
    for (int j = 0; j < D1Over2; ++j) {
#pragma HLS PIPELINE II=1
      const int idx = i * D1Over2 + j;
      const ap_uint<64> tensor_data = src[offset2 + idx];
      tensor[i][j * 2 + 0] = T(U32ToFloat(tensor_data.range(31, 0)));
      tensor[i][j * 2 + 1] = T(U32ToFloat(tensor_data.range(63, 32)));
    }
  }
}

// Read a 2D tensor of size (`D0`, 3) from a DDR memory
template <typename T, int D0, int D1>
void ReadTensor2dOpt2(T tensor[D0][D1],
                      const ap_uint<64>* src,
                      const int offset)
{
#pragma HLS INLINE off

  static_assert(D0 % 2 == 0, "`D0` must be a multiple of 2");
  static_assert(D1 == 3, "`D1` must be 3");
  assert(offset % 2 == 0);

  constexpr const int Iter = D0 * D1 / (2 * 3);
  const int offset2 = offset / 2;

  for (int i = 0; i < Iter; ++i) {
#pragma HLS PIPELINE
    const int src_idx = i * 3;
    const int dst_idx = i * 2;
    const ap_uint<64> tensor_data0 = src[offset2 + src_idx + 0];
    const ap_uint<64> tensor_data1 = src[offset2 + src_idx + 1];
    const ap_uint<64> tensor_data2 = src[offset2 + src_idx + 2];
    tensor[dst_idx + 0][0] = T(U32ToFloat(tensor_data0.range(31, 0)));
    tensor[dst_idx + 0][1] = T(U32ToFloat(tensor_data0.range(63, 32)));
    tensor[dst_idx + 0][2] = T(U32ToFloat(tensor_data1.range(31, 0)));
    tensor[dst_idx + 1][0] = T(U32ToFloat(tensor_data1.range(63, 32)));
    tensor[dst_idx + 1][1] = T(U32ToFloat(tensor_data2.range(31, 0)));
    tensor[dst_idx + 1][2] = T(U32ToFloat(tensor_data2.range(63, 32)));
  }
}

比范できるように、デヌタを1぀ず぀読み取る、元のナむヌブな実装も茉せたした。 各関数の動䜜をたずめたす。

  • ReadTensor1dOpt1<T, D0>(tensor, src, offset): 指定されたDRAMバッファsrcの、floatでoffset個分だけずらした堎所から (srcに4 * offsetバむト分だけ足したアドレスから)、D0個分のfloatを2぀ず぀読み取る。 読み取ったデヌタはfloatからT型にキャストしお、指定された1次元のオンチップバッファtensor (サむズ(D0))に2぀ず぀栌玍する。 1サむクルで2぀ず぀読み取るため、サむズD0は偶数ず仮定しおいる。
  • ReadTensor2dOpt1<T, D0, D1>(tensor, src, offset): 指定されたDRAMバッファsrcからデヌタを2぀ず぀読み取っお、2次元のオンチップバッファtensor (サむズ(D0, D1))に栌玍する。 1サむクルで2぀ず぀読み取るため、サむズD1は偶数ず仮定しおいる。
  • ReadTensor2dOpt2<T, D0, D1>(tensor, src, offset): D1が3である堎合の専甚の実装。 3サむクル掛けお、指定されたDRAMバッファsrcからデヌタを6぀読み取った埌、オンチップバッファtensorに栌玍しおいく。 実装を簡略化するため、サむズに関しおは、D1は3、D0は偶数であるこずを仮定しおいる (芁玠数が偶数)。

ReadTensor2dOpt2およびReadBlockParamsOpt2は、特城抜出ネットワヌクにおける最初の党結合局の重みを転送するために䜿われおいたす (InitializeFeatOpt3を参照)。 最初の党結合局は、3次元の点の座暙を64次元の特城に倉換するので、重みのサむズは(64, 3)ずなりたす。 デヌタを2぀ず぀読み取りたいのに、2番目の次元が奇数で、実装䞊の郜合が悪いので、専甚の関数を甚意したわけです。 ReadTensor2dOpt2では、重みを6぀ず぀読み取るこずで察凊しおいたす。 別の察凊法ずしおは、重みのバッファサむズを(64, 3)から(64, 4)に広げるこずが考えられたす (4番目の次元は単に䜿わない)。

ReadBlockParamsOpt1ずReadBlockParamsOpt2の違いは、ReadTensor2dOpt1ずReadTensor2dOpt2のどちらを䜿っおいるかだけです。 2぀の関数は、C++17に甚意されたif constexpr文を䜿えば、1぀にたずめられるず思いたすが、今回はC++14たでの機胜を䜿っおいるので、別々にしおいたす。

ap_uint型にはrange()ずいう䟿利なメ゜ッドが甚意されおおり、指定したビットの郚分を自由に取り出せたす。 range(31, 0)で䞋䜍32ビット、range(63, 32)で䞊䜍32ビットを取り出しおいたす。

U32ToFloat()、FloatToU32()は、ビット衚珟を維持したたた、別の型に解釈するための関数です (floatず笊号なし32ビット敎数)。 tensor_data.range(31, 0)は32ビットの笊号なし敎数型 (unsigned intやap_uint<32>) ですが、実際にはfloatのデヌタが入っおいるので、U32ToFloat()を䜿っおfloatに解釈し盎しおいたす。 2぀の関数は、共甚䜓を䜿っお実珟しおいたす。 C++20であれば、std::bit_castで同等の凊理ができたす。

特城抜出ネットワヌクの掚論に着目したす (InferenceFeatOpt2を参照)。 i番目の点をDRAMバッファから読み取るReadPointNaiveも、64ビット幅に合わせお曞き盎したす。 修正埌のバヌゞョンをReadPointOpt1ずしたした。

// Read a point from a DDR memory
template <typename T>
void ReadPointNaive(const float* point_cloud,
                    const int idx,
                    T x[3])
{
#pragma HLS INLINE off

  for (int i = 0; i < 3; ++i) {
#pragma HLS PIPELINE II=1
    x[i] = T(point_cloud[idx * 3 + i]);
  }
}

// Read a point from a DDR memory
template <typename T>
void ReadPointOpt1(const ap_uint<64>* point_cloud,
                   const int idx,
                   T x[3])
{
#pragma HLS INLINE off

  const ap_uint<64> point_data0 = point_cloud[idx * 2 + 0];
  const ap_uint<64> point_data1 = point_cloud[idx * 2 + 1];
  x[0] = T(U32ToFloat(point_data0.range(31, 0)));
  x[1] = T(U32ToFloat(point_data0.range(63, 32)));
  x[2] = T(U32ToFloat(point_data1.range(31, 0)));
}

ReadPointNaiveでは、DRAMバッファpoint_cloudのサむズが$(N, 3)$であるこずを想定しおいたした。 䞀方ReadPointOpt1では、実装を簡単にするため、バッファサむズが$(N, 4)$であるずしたす (4番目の次元に぀いおは䜿わない)。 i番目の点を読み取るずきは、バッファのidx * 2 + 0番目ずidx * 2 + 1番目を参照すればよいです。

最埌に、分類ネットワヌクの掚論を盎したす (InferenceClsOpt1を参照)。 点矀の特城量から、物䜓の各クラスに察するロゞットを蚈算し、WriteTensor1dNaiveによりDRAMバッファに曞き蟌んでいたす。 WriteTensor1dNaiveを、64ビット幅に合わせお曞き盎したす。 修正埌のバヌゞョンをWriteTensor1dOpt1ずしたした。

// Write a 1D tensor to a DDR memory
template <typename T, int D0>
void WriteTensor1dNaive(float* dst,
                        const T tensor[D0],
                        const int offset)
{
#pragma HLS INLINE off

  for (int i = 0; i < D0; ++i) {
#pragma HLS PIPELINE II=1
    dst[offset + i] = static_cast<float>(tensor[i]);
  }
}

// Write a 1D tensor to a DDR memory
template <typename T, int D0>
void WriteTensor1dOpt1(ap_uint<64>* dst,
                       const T tensor[D0],
                       const int offset)
{
#pragma HLS INLINE off

  static_assert(D0 % 2 == 0, "`D0` must be a multiple of 2");
  assert(offset % 2 == 0);

  constexpr const int D0Over2 = D0 / 2;
  const int offset2 = offset / 2;

  for (int i = 0; i < D0Over2; ++i) {
#pragma HLS PIPELINE II=1
    ap_uint<64> tensor_data;
    tensor_data.range(31, 0) = FloatToU32(
      static_cast<float>(tensor[i * 2 + 0]));
    tensor_data.range(63, 32) = FloatToU32(
      static_cast<float>(tensor[i * 2 + 1]));
    dst[offset2 + i] = tensor_data;
  }
}

オンチップバッファtensorに眮かれたサむズ(D0)のデヌタを、1サむクルに2぀ず぀、DRAMに曞き戻しおいたす。 実装を簡単にするため、D0は偶数であるず仮定したす。 2぀のデヌタはT型ですが、゜フトりェア偎から利甚しやすいようにfloatに盎し、曎にFloatToU32を䜿っお、ビット衚珟を維持したたた32ビットの笊号なし敎数型に再解釈しおいたす。 これら2぀を、ap_uint<64>型の䞊䜍32ビットず䞋䜍32ビットに詰めお、DRAMバッファに曞き戻しおいたす。

最初の2぀の党結合局 (LinearOpt1DDR) も盎しお、新たにLinearOpt2DDRを䜜りたす。 重みずバむアスの転送郚分を倉曎したす。 転送に芁するサむクル数が半分ほどになるので、分類ネットワヌクの掚論時間の削枛が期埅されたす。 実装を簡単にするため、入出力の次元がいずれも偶数であるこずを前提ずしおいたす。

// Parallel implementation of the fully-connected layer
// Weight and bias parameters are stored on the DDR memory
// Matrix-vector multiplication is parallelized along the output dimension
// `T` is the type for values
// `TParam` is the type for weight and bias
// `InDims` is the number of input dimensions
// `OutDims` is the number of output dimensions
// `ApplyReLU` is the flag to apply ReLU activation
// `B` is the block size for the output dimension
template <typename T, typename TParam,
          int InDims, int OutDims, bool ApplyReLU, int B>
void LinearOpt2DDR(const T x[InDims],
                   T y[OutDims],
                   const ap_uint<64>* params,
                   const int offset)
{
  // `x` is of size (1, `InDims`)
  // `y` is of size (1, `OutDims`)
  // `params` contains weight parameters of size (`OutDims`, `InDims`) and
  // bias parameters of size (`OutDims`) in a contiguous buffer

#pragma HLS INLINE off

  // `OutDims` must be a multiple of `B`
  static_assert(OutDims % B == 0, "`OutDims` must be a multiple of `B`");
  // `B` must be larger than 1
  static_assert(B > 1, "`B` must be larger than 1");
  // `InDims` must be a multiple of 2
  static_assert(InDims % 2 == 0, "`InDims` must be a multiple of 2");
  // `OutDims` must be a multiple of 2
  static_assert(OutDims % 2 == 0, "`OutDims` must be a multiple of 2");
  // `offset` must be a multiple of 2
  assert(offset % 2 == 0);

  constexpr const int BHalf = B / 2;
  constexpr const int OffsetToBias = OutDims * InDims / 2;
  constexpr const int InDims2 = InDims / 2;
  constexpr const int OutDims2 = OutDims / 2;
  const int offset2 = offset / 2;

  TParam bias[OutDims];
#pragma HLS ARRAY_PARTITION variable=bias type=cyclic factor=BHalf dim=1

  // Copy the bias parameters in advance
  for (int i = 0; i < OutDims2; ++i) {
#pragma HLS PIPELINE II=1
    const ap_uint<64> bias_data = params[offset2 + OffsetToBias + i];
    bias[i * 2 + 0] = TParam(U32ToFloat(bias_data.range(31, 0)));
    bias[i * 2 + 1] = TParam(U32ToFloat(bias_data.range(63, 32)));
  }

  for (int i0 = 0; i0 < OutDims; i0 += B) {
#pragma HLS PIPELINE off
    T vals[B];
#pragma HLS ARRAY_PARTITION variable=vals type=complete dim=1
    TParam weight[B][InDims];
#pragma HLS ARRAY_PARTITION variable=weight type=cyclic factor=BHalf dim=1

    // Copy the weight parameters for `B` outputs
    const int offset0 = offset2 + i0 * InDims2;
    for (int i1 = 0; i1 < B; ++i1) {
      for (int j = 0; j < InDims2; ++j) {
#pragma HLS PIPELINE
        const ap_uint<64> weight_data = params[offset0 + i1 * InDims2 + j];
        weight[i1][j * 2 + 0] = TParam(
          U32ToFloat(weight_data.range(31, 0)));
        weight[i1][j * 2 + 1] = TParam(
          U32ToFloat(weight_data.range(63, 32)));
      }
    }

    for (int j = 0; j < InDims; ++j) {
#pragma HLS PIPELINE
      for (int i1 = 0; i1 < B; ++i1) {
#pragma HLS UNROLL
        int i = i0 + i1;
        if (i < OutDims) {
          T last = (j == 0) ? T(bias[i]) : vals[i1];
          vals[i1] = last + x[j] * weight[i1][j];
        }
      }
    }

    for (int i1 = 0; i1 < B; ++i1) {
#pragma HLS UNROLL
      int i = i0 + i1;
      if (i < OutDims) {
        if (ApplyReLU)
          y[i] = vals[i1] > T(0) ? vals[i1] : T(0);
        else
          y[i] = vals[i1];
      }
    }
  }
}

2぀のネットワヌクに぀いお、デヌタの入出力に関連する郚分を修正したした。 InferenceFeatOpt2ずInferenceClsOpt1に察しお、修正を斜したものをInferenceFeatOpt3、InferenceClsOpt3ずしたす。 InferenceFeatOpt3では、点矀デヌタを読み取る際に、ReadPointNaiveの代わりにReadPointOpt1を䜿っおいたす (他は同じ)。 たたInferenceClsOpt3では、ロゞットを曞き蟌む際に、WriteTensor1dNaiveではなくWriteTensor1dOpt1を䜿い、最初の2぀の党結合局に぀いおは、LinearOpt1DDRの代わりにLinearOpt2DDRを䜿っおいたす。

template <typename T, typename U, int N>
void InferenceFeatOpt3(...)
{
#pragma HLS INLINE off

  // Zero-initialize the output feature
  VectorNdSetZero<T, kFeatDims5>(feature);

  // Compute the feature
  for (int i = 0; i < num_points; ++i) {
    // ...

    // Read a point from a DDR memory
    ReadPointOpt1<T>(point_cloud, i, x0);

    // Compute a point feature
    // ...

    // Update the output feature
    MaxPool1dOpt1<T, kFeatDims5, 2>(x10, feature);
  }
}

template <typename T, typename U>
void InferenceClsOpt3(...)
{
#pragma HLS INLINE off

  // ...

  // Compute logits
  LinearOpt2DDR<T, U, kClsDims0, kClsDims1, false, 16>(
    feature, x0, params1, 0);
  BatchNorm1dReLUOpt1<T, U, kClsDims1, 2>(
    x0, x1, bn1->scale, bn1->bias, bn1->mean);
  LinearOpt2DDR<T, U, kClsDims1, kClsDims2, false, 8>(
    x1, x2, params2, 0);
  BatchNorm1dReLUOpt1<T, U, kClsDims2, 2>(
    x2, x3, bn2->scale, bn2->bias, bn2->mean);
  LinearOpt1<T, U, kClsDims2, kClsDims3, false, 2>(
    x3, x4, fc3->weight, fc3->bias);

  // Write the result
  WriteTensor1dOpt1<T, kClsDims3>(out_logits, x4, 0);
}

入出力ポヌト幅によっお、どの皋床実行時間を削枛できたでしょうか。 特城抜出ネットワヌクInferenceFeatOpt2の実行サむクル数は1,112,259 (7.408ms)、新たに甚意したInferenceFeatOpt3は1,112,254 (7.408ms) でした。 ほが䞀緒です。 分類ネットワヌクに関しおは、ポヌト幅32ビット甚のInferenceClsOpt1は711,969サむクル (4.742ms) でしたが、64ビット甚のInferenceClsOpt3では383,885サむクル (2.557ms) に削枛されたした。 ポヌト幅を2倍に広げたこずで、分類ネットワヌクの掚論時間を1.85倍短瞮できたわけです。

圓初のナむヌブ実装 (InferenceFeatNaive + InferenceClsNaive) ず、ここに瀺す実装 (InferenceFeatOpt3 + InferenceClsOpt3) ずで、実行サむクル数はどの皋床倉化したでしょうか。 䞊がナむヌブ実装、䞋が最適化枈みの実装での結果です。 ナむヌブ実装では、掚論に163,279,213サむクル (1.087s) 芁しおいたすが、最適化によっお1,496,143サむクル (9.964ms) にたで削枛されおいたす。 およそ109倍の差ですね。

以䞊で、高䜍合成の実装ができあがりたした。 hls/src/top_opt3.cppをご芧ください。

ビットストリヌムの準備

高䜍合成の実装ができたので、Vitis HLSでコンパむルし、IPコアを䜜成したす。 今回は、以䞋のような環境で䜜業しおいたす (詊す人はいないず思いたすが曞いおおきたす)。

  • Ubuntu 20.04.5 LTS
  • Intel(R) Xeon(R) E-2186G CPU @ 3.80GHz
  • 64GB DRAM
  • Vivado ML Edition 2022.1 (むンストヌル堎所は/tools/Xilinx以䞋)
  • CMake 3.16.3

たた、察象のFPGAボヌドは、Xilinx ZCU104 Evaluation Board (XCZU7EV-2FFVC1156)です。

今回甚意したGitHubリポゞトリでは、以䞋のようにmakeするだけで、自動的にIPコアを䜜成できたす。 TclスクリプトずCMakeを組み合わせお実珟されおいたす。 䞊のスクリヌンショットのように、Vitis HLSにはGUIが甚意されおいたすが、Tclスクリプトを䜿えばコマンドラむン䞊でのバッチ凊理が可胜です。 適圓な堎所にリポゞトリをクロヌンしたら、hlsディレクトリに移っお、䜜業甚のディレクトリを準備したす。 続いおCMakeプロゞェクトを構成し、所望のIPコアをmakeで䜜成したす。

# 予めVivadoずVitis HLSを䜿えるようにsourceする
> source /tools/Xilinx/Vivado/2022.1/settings64.sh

# GitHubリポゞトリのクロヌン
> git clone git@github.com:sterngerlach/advent_2022_point_cloud_classification.git
> cd advent_2022_point_cloud_classification

# 䜜業甚ディレクトリの準備
> cd hls
> mkdir build
> mkdir work

> cd build

# CMakeプロゞェクトを構成
# settings64.shによっおCMakeが曞き換えられるので、システムのCMakeを䜿う
> /usr/bin/cmake ..

# ナむヌブ実装からIPコアを䜜成
# workディレクトリ内に䜜られる
> make pointnet_naive_150_csynth_export

# デヌタ䞊列性を掻甚した (ルヌプアンロヌリングず配列の分割を枈たせた) IPコアを䜜成
> make pointnet_opt1_csynth_export

# デヌタフロヌ最適化を枈たせたIPコアを䜜成
> make pointnet_opt2_csynth_export

# 入出力のポヌト幅を64ビットに広げたIPコアを䜜成
> make pointnet_opt3_csynth_export

IPコアを䜜成したら、GUIを起動しお、合成結果をみおみたしょう (䞊のスクリヌンショットのような画面が開きたす)。

> cd hls/work

# ナむヌブ実装甚のVitis HLSプロゞェクトをGUIで開く
> vitis_hls -p pointnet_naive_150

# 他も同様
> vitis_hls -p pointnet_opt1
> vitis_hls -p pointnet_opt2
> vitis_hls -p pointnet_opt3

Vitis HLSを䜿うのはここたでで、これ以降は、Vivadoを䜿った䜜業に移りたす。 続いお、このIPコアを、別のIPコアず組み合わせお、ボヌドデザむンを甚意したす。 今回は、ボヌドデザむンの䜜成に぀いおは省略したす。 最初に、vivadoディレクトリに移っお、䜜業甚のディレクトリを準備したす。 続いおCMakeプロゞェクトを構成し、所望のボヌドデザむンをmakeで䜜成したす。

# 䜜業甚ディレクトリの準備
> cd vivado
> mkdir build
> mkdir work
> mkdir bitstream

> cd build

# CMakeプロゞェクトを構成
# settings64.shによっおCMakeが曞き換えられるので、システムのCMakeを䜿う
# Vitis HLSによるIPコアの合成が終わっおいないず゚ラヌ
> /usr/bin/cmake ..

# ナむヌブ実装のIPコアから、ボヌドデザむンを䜜成
> make pointnet_naive_150_create

# 最適化枈みのIPコアから、ボヌドデザむンを䜜成
> make pointnet_opt1_create
> make pointnet_opt2_create
> make pointnet_opt3_create

ボヌドデザむンを䜜成したら、GUIを起動しお、ブロック図をみおみたしょう。

> cd vivado/work
> vivado -project pointnet_naive_150/pointnet_naive_150.xpr
> vivado -project pointnet_opt1/pointnet_opt1.xpr
> vivado -project pointnet_opt2/pointnet_opt2.xpr
> vivado -project pointnet_opt3/pointnet_opt3.xpr

巊偎のFlow Navigatorから、「Open Block Design」を遞択するず、ブロック図を衚瀺できたす。

ブロック図を拡倧したものが以䞋です。

ボヌドデザむンに察しお、論理合成ず配眮配線を行い、回路情報をたずめたビットストリヌム (Bitstream) を䜜成したしょう。 マシンのスペックにもよりたすが、こちらの環境では、1぀のボヌドデザむンの論理合成ず配眮配線に、30分以䞊掛かりたした (8コアを䜿った堎合)。 今回のGitHubリポゞトリには、ビットストリヌムも入れおあるので、この䜜業は必芁ありたせん (詊しおみおも倧䞈倫です)。

> cd vivado/build
> make pointnet_naive_150_impl && make pointnet_naive_150_copy_bitstream
> make pointnet_opt1_impl && make pointnet_opt1_copy_bitstream
> make pointnet_opt2_impl && make pointnet_opt2_copy_bitstream
> make pointnet_opt3_impl && make pointnet_opt3_copy_bitstream

もう䞀床GUIを起動しお、合成枈みの回路をみおみたしょう。 巊偎のFlow Navigatorから、「Open Implemented Design」を遞択したす。 個人的には、ニュヌペヌクのマンハッタンのようにみえお、矎しいず思いたす。 GUI䞊で、リ゜ヌスの䜿甚率 (Utilization) や、電力消費の芋積もり (Power)、タむミング (Timing) などを確認できたす。

vivado/bitstreamディレクトリ以䞋に、生成されたビットストリヌムがコピヌされたす。 ビットストリヌム (拡匵子.bit) の他に、Hardware Handoffファむル (拡匵子.hwh) もありたす。 Handoffファむルには、回路のメタデヌタが含たれたす。 FPGAボヌドにビットストリヌムをロヌドするためには、2぀のファむルがセットで必芁になりたす。 ビットストリヌムを読み盎せば、動かす回路を䜕床でも切り替えられるずいうのが、ASICに察するFPGAの倧きな利点です。 さお、これらのファむルをscpなどでFPGAボヌド䞊に転送すれば、回路を動かす準備が敎いたす。

> cd vivado/bitstream
> ls
-rw-rw-r-- 1 x x  19M Dec 14 23:34 pointnet_naive_150.bit
-rw-rw-r-- 1 x x 363K Dec 14 23:34 pointnet_naive_150.hwh
-rw-rw-r-- 1 x x  19M Dec 15 00:01 pointnet_opt1.bit
-rw-rw-r-- 1 x x 363K Dec 15 00:01 pointnet_opt1.hwh
-rw-rw-r-- 1 x x  19M Dec 14 23:20 pointnet_opt2.bit
-rw-rw-r-- 1 x x 363K Dec 14 23:20 pointnet_opt2.hwh
-rw-rw-r-- 1 x x  19M Dec 15 18:07 pointnet_opt3.bit
-rw-rw-r-- 1 x x 363K Dec 15 18:07 pointnet_opt3.hwh

回路を動かす

ビットストリヌムを甚意できたので、いよいよ回路を動かしおみたす。 今回䜿甚するFPGAボヌド、Xilinx ZCU104 Evaluation Kitは、SoC (System-on-Chip) ずよばれおいたす。 FPGAの他に、クアッドコア ARM Cortex-A53 CPU (1.2GHz)、2GBのDRAMや、様々な呚蟺回路が統合されおいお、Linuxが動䜜したす。 ここではOSずしお、Ubuntu 20.04をベヌスずしたPynq Linux 2.7を䜿いたす。 Pynq LinuxにはpynqずよばれるPythonのラむブラリが付属しおおり、PythonからFPGA関連の凊理を簡単に行えたす。

以䞋を詊すためには、Pynq Linux䞊に、PyTorch 1.11.0や、TorchVision 0.12.0、NumPy、SciPy、H5py、Tqdmなどのラむブラリを予めむンストヌルする必芁がありたすが、ここでは説明が長くなっおしたうため割愛したす。 基本的にはpipコマンドでむンストヌルできたす。 なお、Xilinx ZCU104、Pynq Linux 2.7甚にビルドされたPyTorch 1.11.0、TorchVision 0.12.0のWheelファむルは、こちらのリポゞトリに眮いおありたす。 ここたで苊劎しお、なぜFPGA䞊で機械孊習モデルを動かそうずするのか、たたに自問自答するこずがありたす。

これ以降はC/C++ではなく、Pythonのコヌドを曞いおいきたす。

最初に、PyTorchのモデルの定矩を再掲したす (net/model.py)。 䜕の捻りもなく、シンプルですね。

class PointNetFeat(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 64, 1)
        self.conv3 = torch.nn.Conv1d(64, 64, 1)
        self.conv4 = torch.nn.Conv1d(64, 128, 1)
        self.conv5 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = torch.nn.BatchNorm1d(64)
        self.bn2 = torch.nn.BatchNorm1d(64)
        self.bn3 = torch.nn.BatchNorm1d(64)
        self.bn4 = torch.nn.BatchNorm1d(128)
        self.bn5 = torch.nn.BatchNorm1d(1024)

    def forward(self, x: torch.Tensor):
        # `x` is of size [B, N, 3]
        N = x.shape[1]
        # `x` is of size [B, 3, N]
        x = x.transpose(1, 2)

        # `x` is of size [B, 1024, N]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))

        # `x` is of size [B, 1024]
        x = torch.max(x, dim=2)[0]

        return x

class PointNetCls(torch.nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()

        # Feature extraction
        self.feat = PointNetFeat()

        # Classification network
        self.fc1 = torch.nn.Linear(1024, 512)
        self.fc2 = torch.nn.Linear(512, 256)
        self.fc3 = torch.nn.Linear(256, num_classes)
        self.bn1 = torch.nn.BatchNorm1d(512)
        self.bn2 = torch.nn.BatchNorm1d(256)

    def forward(self, x):
        # `x` is of size [B, N, 3]
        # `x` is of size [B, 1024]
        x = self.feat(x)

        # `x` is of size [B, `num_classes`]
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.fc3(x)

        return x

次に、FPGAで高速化されたモデルを瀺したす (host/model_zcu104.py)。 モデルの名前はPointNetClsZCU104です。 䞊蚘のCPU版のモデル (PointNetCls) ず、䜿い勝手が同じになるようにしたした。

from net.model import PointNetCls

# Split the 64-bit address
def split_address(addr: int) -> Tuple[int, int]:
    mask = (1 << 32) - 1
    return addr & mask, addr >> 32

# Allocate a contiguous buffer for torch.nn.Conv1d (torch.nn.Linear)
def allocate_linear_buffer(in_dims: int, out_dims: int) \
    -> pynq.buffer.PynqBuffer:
    buf_size = in_dims * out_dims + out_dims
    return pynq.allocate(shape=(buf_size,), dtype=np.float32, cacheable=False)

# Allocate a contiguous buffer for a block with torch.nn.Conv1d
# (torch.nn.Linear) and torch.nn.BatchNorm1d
def allocate_block_buffer(in_dims: int, out_dims: int) \
    -> pynq.buffer.PynqBuffer:
    buf_size = 0
    buf_size += in_dims * out_dims + out_dims
    buf_size += out_dims * 3
    return pynq.allocate(shape=(buf_size,), dtype=np.float32, cacheable=False)

# Write the torch.nn.Conv1d parameters to the contiguous buffer
def write_conv1d_params(buf: pynq.buffer.PynqBuffer,
                        layer: torch.nn.Conv1d,
                        offset: int = 0) -> int:
    if layer.kernel_size != (1,):
        raise RuntimeError(f"Kernel size should be 1")

    weight_size = layer.out_channels * layer.in_channels
    bias_size = layer.out_channels

    buf[offset:offset+weight_size] = layer.weight.data.view(-1)
    offset += weight_size
    buf[offset:offset+bias_size] = layer.bias.data.view(-1)
    offset += bias_size

    return offset

# Write the torch.nn.Linear parameters to the contiguous buffer
def write_linear_params(buf: pynq.buffer.PynqBuffer,
                        layer: torch.nn.Linear,
                        offset: int = 0) -> int:
    weight_size = layer.out_features * layer.in_features
    bias_size = layer.out_features

    buf[offset:offset+weight_size] = layer.weight.data.view(-1)
    offset += weight_size
    buf[offset:offset+bias_size] = layer.bias.data.view(-1)
    offset += bias_size

    return offset

# Write the torch.nn.BatchNorm1d parameters to the contiguous buffer
def write_batchnorm1d_params(buf: pynq.buffer.PynqBuffer,
                             layer: torch.nn.BatchNorm1d,
                             offset: int = 0) -> int:
    dims = layer.num_features

    # `scale` is the multiplication of the weight and reciprocal of the
    # standard deviation (to reduce the on-chip memory consumption)
    std_inv = torch.sqrt(layer.running_var.data + layer.eps)
    std_inv = torch.reciprocal(std_inv)
    scale = std_inv * layer.weight.data

    buf[offset:offset+dims] = scale.data.view(-1)
    offset += dims
    buf[offset:offset+dims] = layer.bias.data.view(-1)
    offset += dims
    buf[offset:offset+dims] = layer.running_mean.data.view(-1)
    offset += dims

    return offset

# Write the block (torch.nn.Conv1d and torch.nn.BatchNorm1d) parameters
# to the contiguous buffer
def write_conv_batchnorm1d_params(buf: pynq.buffer.PynqBuffer,
                                  conv: torch.nn.Conv1d,
                                  bn: torch.nn.BatchNorm1d):
    offset = 0
    offset = write_conv1d_params(buf, conv, offset)
    offset = write_batchnorm1d_params(buf, bn, offset)

# Write the block (torch.nn.Linear and torch.nn.BatchNorm1d) parameters
# to the contiguous buffer
def write_linear_batchnorm1d_params(buf: pynq.buffer.PynqBuffer,
                                    linear: torch.nn.Linear,
                                    bn: torch.nn.BatchNorm1d):
    offset = 0
    offset = write_linear_params(buf, linear, offset)
    offset = write_batchnorm1d_params(buf, bn, offset)

class PointNetClsZCU104(torch.nn.Module):
    # Operation modes (refer to hls/src/op_modes.hpp)
    MODE_INIT_WEIGHTS = 100
    MODE_INFERENCE = 101

    def __init__(self, model_cpu: PointNetCls,
                 overlay_path: str, num_points: int):
        super().__init__()

        # Load an overlay
        self.overlay = self.load_overlay(overlay_path)
        # Get the IP core module
        self.net_ip: pynq.DefaultIP = self.overlay.PointNetClsTop
        # Get the control registers of the IP core
        self.registers = self.net_ip.register_map

        # Check the data width of the AXI master interface
        net_ip_params = self.overlay.ip_dict["PointNetClsTop"]["parameters"]
        self.axi_m_addr_width = int(net_ip_params["C_M_AXI_GMEM0_ADDR_WIDTH"])
        self.axi_m_data_width = int(net_ip_params["C_M_AXI_GMEM0_DATA_WIDTH"])

        # Allocate buffers for PointNet feature extraction network
        self.buf_feat_params1 = allocate_block_buffer(3, 64)
        self.buf_feat_params2 = allocate_block_buffer(64, 64)
        self.buf_feat_params3 = allocate_block_buffer(64, 64)
        self.buf_feat_params4 = allocate_block_buffer(64, 128)
        self.buf_feat_params5 = allocate_block_buffer(128, 1024)

        # Allocate buffers for classification network
        self.buf_cls_params1 = allocate_block_buffer(1024, 512)
        self.buf_cls_params2 = allocate_block_buffer(512, 256)
        self.buf_cls_params3 = allocate_linear_buffer(256, 40)

        # Allocate a buffer for point cloud
        self.num_points = num_points
        if self.axi_m_data_width == 32:
            self.buf_point_cloud: pynq.buffer.PynqBuffer = pynq.allocate(
                shape=(self.num_points, 3), dtype=np.float32, cacheable=False)
        elif self.axi_m_data_width == 64:
            self.buf_point_cloud: pynq.buffer.PynqBuffer = pynq.allocate(
                shape=(self.num_points, 4), dtype=np.float32, cacheable=False)
        else:
            raise RuntimeError(f"Unexpected data width for AXI master")

        # Allocate a buffer for output logits
        self.buf_out_logits: pynq.buffer.PynqBuffer = pynq.allocate(
            shape=(40,), dtype=np.float32, cacheable=False)

        # Copy parameters for PointNet feature extraction network
        write_conv_batchnorm1d_params(self.buf_feat_params1,
            model_cpu.feat.conv1, model_cpu.feat.bn1)
        write_conv_batchnorm1d_params(self.buf_feat_params2,
            model_cpu.feat.conv2, model_cpu.feat.bn2)
        write_conv_batchnorm1d_params(self.buf_feat_params3,
            model_cpu.feat.conv3, model_cpu.feat.bn3)
        write_conv_batchnorm1d_params(self.buf_feat_params4,
            model_cpu.feat.conv4, model_cpu.feat.bn4)
        write_conv_batchnorm1d_params(self.buf_feat_params5,
            model_cpu.feat.conv5, model_cpu.feat.bn5)

        # Copy parameters for classification network
        write_linear_batchnorm1d_params(self.buf_cls_params1,
            model_cpu.fc1, model_cpu.bn1)
        write_linear_batchnorm1d_params(self.buf_cls_params2,
            model_cpu.fc2, model_cpu.bn2)
        write_linear_params(self.buf_cls_params3, model_cpu.fc3)

        # Set the physical addresses of the buffers
        self.registers.point_cloud_1, self.registers.point_cloud_2 = \
            split_address(self.buf_point_cloud.device_address)
        self.registers.out_logits_1, self.registers.out_logits_2 = \
            split_address(self.buf_out_logits.device_address)
        self.registers.feat_params1_1, self.registers.feat_params1_2 = \
            split_address(self.buf_feat_params1.device_address)
        self.registers.feat_params2_1, self.registers.feat_params2_2 = \
            split_address(self.buf_feat_params2.device_address)
        self.registers.feat_params3_1, self.registers.feat_params3_2 = \
            split_address(self.buf_feat_params3.device_address)
        self.registers.feat_params4_1, self.registers.feat_params4_2 = \
            split_address(self.buf_feat_params4.device_address)
        self.registers.feat_params5_1, self.registers.feat_params5_2 = \
            split_address(self.buf_feat_params5.device_address)
        self.registers.cls_params1_1, self.registers.cls_params1_2 = \
            split_address(self.buf_cls_params1.device_address)
        self.registers.cls_params2_1, self.registers.cls_params2_2 = \
            split_address(self.buf_cls_params2.device_address)
        self.registers.cls_params3_1, self.registers.cls_params3_2 = \
            split_address(self.buf_cls_params3.device_address)

        # Synchronize the buffers
        self.buf_feat_params1.sync_to_device()
        self.buf_feat_params2.sync_to_device()
        self.buf_feat_params3.sync_to_device()
        self.buf_feat_params4.sync_to_device()
        self.buf_feat_params5.sync_to_device()
        self.buf_cls_params1.sync_to_device()
        self.buf_cls_params2.sync_to_device()
        self.buf_cls_params3.sync_to_device()

        # Initialize the weights (transfer the weights to the on-chip buffers)
        self.registers.op_mode = PointNetClsZCU104.MODE_INIT_WEIGHTS
        self.registers.CTRL.AP_START = 1
        self.wait_for_ip()

    def load_overlay(self, overlay_path):
        overlay = pynq.Overlay(overlay_path)

        if not overlay.is_loaded():
            raise RuntimeError(f"Unable to load overlay: {overlay_path}")

        return overlay

    def wait_for_ip(self):
        while self.registers.CTRL.AP_DONE == 0:
            pass

    def forward(self, x: torch.Tensor):
        # `x` is of size [B, N, 3]
        if x.ndim != 3 or x.shape[2] != 3:
            raise RuntimeError(f"Unexpected shape of the input: {x.shape}")

        batch_size = x.shape[0]
        num_points = x.shape[1]

        # Reallocate the buffer for point cloud if necessary
        if num_points > self.num_points:
            self.num_points = num_points
            self.buf_point_cloud.freebuffer()
            if self.axi_m_data_width == 32:
                self.buf_point_cloud: pynq.buffer.PynqBuffer = pynq.allocate(
                    shape=(self.num_points, 3),
                    dtype=np.float32, cacheable=False)
            elif self.axi_m_data_width == 64:
                self.buf_point_cloud: pynq.buffer.PynqBuffer = pynq.allocate(
                    shape=(self.num_points, 4),
                    dtype=np.float32, cacheable=False)
            else:
                raise RuntimeError(f"Unexpected data width for AXI master")
            self.registers.point_cloud_1, self.registers.point_cloud_2 = \
                split_address(self.buf_point_cloud.device_address)

        # Allocate the Tensor for output
        out = torch.empty(size=(batch_size, 40),
                          dtype=x.dtype, device=x.device)

        # Run the inference
        self.registers.op_mode = PointNetClsZCU104.MODE_INFERENCE
        self.registers.num_points = num_points

        for i in range(batch_size):
            # Copy the input point cloud
            self.buf_point_cloud[:num_points, :3] = x[i].view(-1, 3)
            self.buf_point_cloud.sync_to_device()

            # Run the inference
            self.registers.CTRL.AP_START = 1
            self.wait_for_ip()

            # Copy the output logits
            self.buf_out_logits.sync_from_device()
            out[i, :] = torch.from_numpy(self.buf_out_logits)

        return out

IPコアの初期化

PointNetClsZCU104クラスのコンストラクタで、以䞋のような手順で初期化し、IPコアを䜿えるようにしたす。 この手順で行う必芁はありたせん。 各手順に぀いお、順番に説明したす。 詳しくは、Pynqの公匏ドキュメントをご芧ください。

  1. ビットストリヌムのロヌド (load_overlay)
  2. DRAMバッファの確保 (allocate_block_buffer、pynq.allocate)
  3. DRAMバッファぞのパラメヌタのコピヌ (write_conv_batchnorm1d_params、write_linear_batchnorm1d_params、write_linear_params)
  4. DRAMバッファの物理アドレスを、ポヌトのレゞスタに察しお蚭定
  5. DRAMバッファの内容を同期 (sync_to_device)
  6. 重み初期化モヌドで、IPコアを動䜜させ、DRAMバッファに眮かれたパラメヌタをオンチップバッファ䞊にコピヌ
  7. IPコアの動䜜終了を埅機 (wait_for_ip)

ビットストリヌムを操䜜するためのクラスはpynq.Overlayであり、ファむルパスを䞎えお、指定したビットストリヌムをロヌドしたす。 拡匵子が.bitのビットストリヌムの他に、.hwhのHandoffファむルも必芁です。 ビットストリヌムがpath/to/X.bitであれば、察応するHandoffがpath/to/X.hwhになければ゚ラヌずなりたす。 pynq.Overlayクラスのむンスタンスself.overlayを起点ずしお、FPGAに察する様々な凊理を行っおいきたす。

オヌバヌレむ (ビットストリヌム) をロヌドしたら、自䜜のIPコアPointNetClsTopを取り出しお、self.net_ipに栌玍したす。 IPコアのプロパティ名は、ボヌドデザむンにおける各IPの名前ず察応しおいたす (こちらの画像を参照。) 䟋えば、割蟌みコントロヌラ (AXI Interrupt Controller) には、axi_intc_0プロパティを通じおアクセスできたす。 IPコアを操䜜するためのクラスは、デフォルトではpynq.DefaultIPずなっおいたす。 このクラスを継承しお、自䜜のIPコアをより䟿利に䜿えるように、様々なメ゜ッドを远加するこずも可胜です。 さらに、IPコアの制埡レゞスタにアクセスするためのむンタフェヌスregister_map (pynq.registers.RegisterMapのサブクラス) を取り出しお、self.registersに栌玍したす。

次の3行で、IPコアの入出力ポヌトのアドレス幅ずデヌタ幅を調べお、self.axi_m_addr_widthおよびself.axi_m_data_widthに栌玍したす。 前者は64、埌者は32たたは64です (入出力ポヌトの型をap_uint<64>*ずした堎合は64、float*のたたであれば32)。 前述の通り、ポヌト幅が32ビットであれば、点矀バッファのサむズは$(N, 3)$でよいのですが、64ビットの堎合は、デヌタを2぀ず぀読み取る関係䞊、バッファサむズを$(N, 4)$にする必芁がありたす。 self.axi_m_data_widthを参照すれば、点矀バッファのサむズを決定できたす。

続いお、パラメヌタや入出力を保持するためのDRAMバッファを確保したす。 このバッファは少し特殊なもので、LinuxカヌネルのCMA (Contiguous Memory Allocator) ずいう機胜により確保されたす。 通垞のmalloc()やnewを䜿っおバッファを確保するず、そのバッファぞの仮想アドレスしか分かりたせん。 䞀方、FPGA偎からは、物理アドレスを䜿甚しおバッファにアクセスするので、仮想アドレスだけでなく、物理アドレスも予め知っおおく必芁がありたす。

allocate_linear_buffer関数は、その名の通り、党結合局 (入力次元in_dims、出力次元out_dims) のパラメヌタ甚のバッファを確保したす。 最初に、党結合局の重み (in_dims * out_dims) ずバむアス (out_dims) の芁玠数を足しお、バッファサむズを決定したす。 続いお、pynq.allocate関数を呌び出しお、指定したサむズおよびデヌタ型np.float32 (float) の、1次元のバッファを確保したす。 このバッファはDRAMの特殊な領域に眮かれお、メモリ䞊で連続しおいるこずが保蚌されたす。 allocate_block_buffer関数は、党結合局ずバッチ正芏化局のパラメヌタを保持するためのバッファを確保したす。 党パラメヌタの芁玠数を足し合わせおサむズを決定し、pynq.allocate関数を䜿っお、1次元のバッファを確保したす。 これらのバッファはpynq.buffer.PynqBufferクラスのむンスタンスですが、NumPy配列 (np.ndarray) ず同じように利甚できたす。 䟋えば、torch.from_numpy関数により、PyTorchのテン゜ルに倉換できたす。

特城抜出ネットワヌク (buf_feat_params1からbuf_feat_params5) ず、分類ネットワヌク (buf_cls_params1からbuf_cls_params3) のパラメヌタ甚のバッファを確保したす。 その埌、入力 (点矀) ず出力 (ロゞット) 甚のバッファも確保したす。 入力に぀いおは䞊述の通り、ポヌトのビット幅が64であれば(self.num_points, 4)、32であれば(self.num_points, 3)ずしたす。

DRAMバッファを確保し終えたら、次はモデルのパラメヌタをバッファぞコピヌしたす。 モデルはPointNetClsクラスのむンスタンスで、コンストラクタの匕数model_cpuずしお枡されたす。 write_conv1d_params、write_linear_paramsは、それぞれtorch.nn.Conv1d、torch.nn.Linearのパラメヌタのコピヌに䜿われたす。 write_conv1d_paramsでは、カヌネルサむズが1である (それゆえ党結合局torch.nn.Linearず動䜜が同じである) こずを前提ずしたす。 重みずバむアスの順で、指定された1次元のDRAMバッファに䞊べおゆきたす。 IPコア偎の期埅通りにデヌタが配眮されるように、现心の泚意を払う必芁がありたす。 これら2぀の関数は、高䜍合成の実装における、ReadLinearParamsNaiveやReadLinearParamsOpt1ず適合するように䜜られおいたす。

write_batchnorm1d_paramsは、torch.nn.BatchNorm1dのパラメヌタを、指定されたDRAMバッファにコピヌしたす。 IPコア偎では、ReadBatchNorm1dParamsNaiveやReadBatchNorm1dParamsOpt1に瀺すように、スケヌル、バむアス、平均の順で、パラメヌタが䞊ぶこずを期埅しおいたす。 バッチ正芏化局の分散ず重みから、スケヌルを蚈算しおいたす (蚈算匏に぀いおは先述)。

write_conv_batchnorm1d_paramsずwrite_linear_batchnorm1d_paramsは、党結合局 (torch.nn.Conv1d、torch.nn.Linear) ずバッチ正芏化局 (torch.nn.BatchNorm1d) のパラメヌタを、指定されたDRAMバッファにコピヌしたす。 党結合局の重み、バむアス、それからバッチ正芏化局のスケヌル、バむアス、平均を、この順で䞊べる必芁がありたす。 IPコア偎のReadBlockParamsNaive、ReadBlockParamsOpt1、ReadBlockParamsOpt2ず察応するこずが分かりたす。 モデルのパラメヌタはPyTorchのテン゜ルですが、そのたたDRAMバッファ (pynq.buffer.PynqBuffer) に代入できたす。

パラメヌタを無事にコピヌできたので、DRAMバッファの物理アドレスを蚭定したす。 IPコアのトップ関数PointNetClsTopは次のように宣蚀されおいたした (float*の代わりにap_uint<64>*もあり)。

void PointNetClsTop(const int op_mode,
                    const float* point_cloud,
                    const int num_points,
                    float* out_logits,
                    const float* feat_params1,
                    const float* feat_params2,
                    const float* feat_params3,
                    const float* feat_params4,
                    const float* feat_params5,
                    const float* cls_params1,
                    const float* cls_params2,
                    const float* cls_params3)
{
#pragma HLS INTERFACE m_axi port=point_cloud offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=out_logits offset=slave bundle=gmem0
#pragma HLS INTERFACE m_axi port=feat_params1 offset=slave bundle=gmem0
// ...
#pragma HLS INTERFACE m_axi port=cls_params3 offset=slave bundle=gmem0

#pragma HLS INTERFACE s_axilite port=op_mode bundle=control
#pragma HLS INTERFACE s_axilite port=point_cloud bundle=control
#pragma HLS INTERFACE s_axilite port=num_points bundle=control
#pragma HLS INTERFACE s_axilite port=out_logits bundle=control
#pragma HLS INTERFACE s_axilite port=feat_params1 bundle=control
// ...
#pragma HLS INTERFACE s_axilite port=cls_params3 bundle=control
#pragma HLS INTERFACE s_axilite port=return bundle=control
}

op_modeずnum_pointsを陀く、DRAMバッファ甚の入出力ポヌトに぀いお、#pragma HLS INTERFACE m_axiず#pragma HLS INTERFACE s_axiliteの蚘述がみられたす。 この2぀のHLSプラグマを付䞎するず、各ポヌトに察しお、DRAMバッファの物理アドレスを指定するための、制埡レゞスタが䜜成されたす。 アドレスは64ビットですが、制埡レゞスタのデヌタ幅は32ビットなので、䞊䜍32ビットず䞋䜍32ビット甚に、2぀の制埡レゞスタが甚意されたす。 䟋えば、point_cloudポヌトに぀いおは、point_cloud_1 (䞋䜍32ビット) ず、point_cloud_2 (䞊䜍32ビット) の、2぀です。 DRAMバッファの物理アドレスを蚭定すれば、ポヌトずDRAMバッファずが玐づけられ、FPGA偎からバッファにアクセスできるようになりたす。 Pynqラむブラリを䜿うず、普通に倀を代入しおいるようにみえたすが、実際には、メモリマップトI/Oで実珟されおいたす。 蚀い換えるず、各制埡レゞスタには専甚のアドレスが割り振られおおり、そのアドレスに察しお読み曞きしおいたす。 制埡レゞスタぞのアクセスには、先ほどのself.registersを利甚したす。

op_modeずnum_pointsに぀いおも、#pragma HLS INTERFACE s_axiliteの蚘述があるので、この2぀ (動䜜モヌドず点の個数) を蚭定するための制埡レゞスタが甚意されたす。

ここたで枈んだら、sync_to_deviceメ゜ッドによりDRAMバッファの内容を同期させお、FPGA偎から正しく読めるようにしたす。

最埌に、動䜜モヌドop_modeを重み初期化に蚭定し、制埡レゞスタのうちCTRL.AP_STARTを1にするこずで、IPコアの動䜜を開始したす。 重み初期化モヌドでは、DRAMバッファからパラメヌタを読み出しお、オンチップバッファに栌玍したす。 #pragma HLS INTERFACE s_axilite port=return bundle=controlの蚘述があるおかげで、゜フトりェア偎からIPコアを制埡するためのCTRLレゞスタが甚意されたす。 IPコアの動䜜を開始したら、wait_for_ipメ゜ッドを呌んで、動䜜終了 (パラメヌタの転送完了) を埅機したす。 wait_for_ipメ゜ッド内では、CTRLレゞスタのAP_DONEが1になるたで、ビゞヌりェむトしたす。 以䞊で初期化がおしたいです。

掚論

初期化には様々な工皋があっお面倒ですが、掚論は比范的簡単です。 PyTorchの通垞のモゞュヌルず同じく、forwardメ゜ッドに掚論凊理を蚘述したす。 入力点矀xは、サむズが$(B, N, 3)$のバッチであるずしたす ($B$はバッチサむズ、$N$は点の個数)。 今回のIPコアは、バッチデヌタを扱うようには䜜っおいないので、バッチ内の各サンプルを1぀ず぀凊理するこずになりたす。 出力outは、物䜓のクラス数を$K$ずするず、サむズが$(B, K)$ずなりたす。 今回はModelNet40ずよばれるデヌタセットを䜿うので、クラス数は$K = 40$です。

最初に、点矀のサむズ$N$が、点矀甚に確保しおある珟圚のDRAMバッファよりも倧きければ、DRAMバッファを確保し盎したす。 続いお、バッチ内の各サンプルに察しお掚論凊理を行っお、物䜓の各クラスに察するロゞット (スコア) を蚈算したす。 点矀甚のDRAMバッファbuf_point_cloudに点矀デヌタをコピヌしお、FPGA偎から正しく読み出せるように、バッファを同期したす。 ゜フトりェア偎からは、入出力ポヌトの幅 (32か64かどうか) はそれほど意識する必芁がありたせん。 2぀の制埡レゞスタ (動䜜モヌドop_modeず点の個数num_points) は、予め蚭定しおおきたす。

CTRLレゞスタのAP_STARTを1にするこずで、掚論モヌドでのIPコアの動䜜を開始したす。 wait_for_ipメ゜ッドにより動䜜の終了を埅機したす。 モデルの出力であるロゞットは、IPコア偎からDRAMバッファbuf_out_logitsに曞き蟌たれおいるので、それをPyTorchのテン゜ルに倉換したうえで、出力甚のテン゜ルoutに改めお曞き蟌みたす。 以䞊が掚論凊理の説明でした。

このように、IPコアの実装だけでなく、それを実際に䜿うためのドラむバも甚意する必芁があるので、手間が掛かりたすね。 今回は、Pynqラむブラリを䜿ったので、FPGAに関する凊理は、比范的容易に蚘述できたした。 たた、CPU・GPU版のモデルず同じように䜿いたいので、PyTorchのモゞュヌル (torch.nn.Module) ずしおドラむバを䜜成したした。 Pythonの代わりにC++を䜿うこずも、もちろん可胜です。 その堎合は、ビットストリヌムのロヌド (䟋えばこちら)、メモリマップトI/Oの準備 (䟋えばこちら)、DRAMバッファの確保 (䟋えばこちら)などを、C++で蚘述するこずになりたす (Pynqラむブラリをそのたた移怍したのを芚えおいたす)。

評䟡

掚論時間の比范

ようやく、IPコアを䜿った評䟡に入りたした。 最初に、掚論時間を比范しおみたしょう。 以䞋の゜ヌスコヌドを利甚したす (host/time_zcu104.py)。

def main():
    # Parse the command-line arguments
    args = parse_command_line()

    # Create a PointNet classification model
    model = PointNetCls(num_classes=40)
    # Create an FPGA model
    model_zcu104 = PointNetClsZCU104(model, args.bitstream, args.num_points)

    model.eval()
    model_zcu104.eval()

    # Test the output
    # Create a random input point cloud
    point_cloud = torch.rand(size=(1, args.num_points, 3))
    out_cpu = model(point_cloud)
    out_zcu104 = model_zcu104(point_cloud)

    print(f"Output (CPU):\n{out_cpu}")
    print(f"Output (FPGA):\n{out_zcu104}")

    # Measure the inference times
    times_cpu = []
    times_zcu104 = []

    for _ in range(args.runs):
        # Create a random input point cloud
        point_cloud = torch.rand(size=(1, args.num_points, 3))

        t0 = time.monotonic()
        model(point_cloud)
        elapsed_cpu = (time.monotonic() - t0) * 1e3

        t0 = time.monotonic()
        model_zcu104(point_cloud)
        elapsed_zcu104 = (time.monotonic() - t0) * 1e3

        times_cpu.append(elapsed_cpu)
        times_zcu104.append(elapsed_zcu104)

    time_avg_cpu = np.mean(times_cpu)
    time_std_cpu = np.std(times_cpu)
    time_avg_zcu104 = np.mean(times_zcu104)
    time_std_zcu104 = np.std(times_zcu104)
    speedup_factor = time_avg_cpu / time_avg_zcu104

    print(f"Inference time (CPU): " \
          f"mean: {time_avg_cpu:.3f}ms, " \
          f"std: {time_std_cpu:.3f}ms")
    print(f"Inference time (FPGA): " \
          f"mean: {time_avg_zcu104:.3f}ms, " \
          f"std: {time_std_zcu104:.3f}ms")
    print(f"Speedup: {speedup_factor:.3f}x")

ここでは粟床は気にしないので、孊習枈みのモデルをロヌドする凊理は省かれおいたす。 䜆し、CPU版のモデルPointNetClsず、FPGA版のモデルPointNetClsZCU104ずで、パラメヌタを揃える必芁はありたす。 たた、CPU版のモデルはevalモヌドで動䜜させたす。 バッチ正芏化局の挙動が蚓緎モヌドになり、バッチ数が1のずきに゚ラヌずなりたす。 たた、蚓緎枈みのパラメヌタではなく、入力のバッチから平均や暙準偏差が蚈算されるので、FPGA版のモデルず出力結果が合わなくなりたす。 指定された回数args.runsだけ、掚論時間の蚈枬を行い、平均ず暙準偏差、たた高速化率を算出したす。 たた最初に、双方のモデルの出力が合っおいるかどうか (倧䜓近い倀が出力されるか) を確認しおいたす (本圓は、IPコアの䜜成時にもテストしたす)。

FPGAボヌド䞊で以䞋のコマンドを実行したす。

> cd advent_2022_point_cloud_classification/host

# ナむヌブ実装 (動䜜呚波数150MHz)
> sudo XILINX_XRT=/usr ./time_zcu104.sh ../vivado/bitstream/pointnet_naive_150.bit

# デヌタ䞊列性を掻甚した (ルヌプアンロヌリングず配列の分割を枈たせた) 実装 (動䜜呚波数150MHz)
> sudo XILINX_XRT=/usr ./time_zcu104.sh ../vivado/bitstream/pointnet_opt1.bit

# デヌタフロヌ最適化を枈たせた実装 (動䜜呚波数150MHz)
> sudo XILINX_XRT=/usr ./time_zcu104.sh ../vivado/bitstream/pointnet_opt2.bit

# 入出力のポヌト幅を64ビットに広げた実装 (動䜜呚波数150MHz)
> sudo XILINX_XRT=/usr ./time_zcu104.sh ../vivado/bitstream/pointnet_opt3.bit

ナむヌブな実装でテストした堎合の出力䟋を以䞋に瀺したす。

$ sudo XILINX_XRT=/usr ./time_zcu104.sh ../vivado/bitstream/pointnet_naive_150.bit
Output (CPU):
tensor([[-0.0594, -0.0272,  0.0115, -0.0481, -0.0529,  0.0449, -0.0634, -0.0328,
          0.0348, -0.0071, -0.0228,  0.0412,  0.0128, -0.0175, -0.0086, -0.0023,
         -0.0192, -0.0101, -0.0072,  0.0520, -0.0106, -0.0110,  0.0113,  0.0499,
         -0.0563, -0.0523, -0.0711, -0.0104, -0.0048, -0.0404,  0.0375,  0.0089,
          0.0326, -0.0408, -0.0302, -0.0041,  0.0534, -0.0349,  0.0380, -0.0020]],
       grad_fn=<AddmmBackward0>)
Output (FPGA):
tensor([[-0.0592, -0.0274,  0.0114, -0.0491, -0.0527,  0.0446, -0.0632, -0.0335,
          0.0337, -0.0071, -0.0258,  0.0399,  0.0119, -0.0170, -0.0091, -0.0030,
         -0.0216, -0.0112, -0.0106,  0.0522, -0.0111, -0.0130,  0.0114,  0.0487,
         -0.0571, -0.0523, -0.0714, -0.0103, -0.0058, -0.0389,  0.0383,  0.0068,
          0.0306, -0.0421, -0.0314, -0.0052,  0.0539, -0.0360,  0.0399, -0.0031]])
Inference time (CPU): mean: 369.048ms, std: 1.086ms
Inference time (FPGA): mean: 1071.358ms, std: 0.023ms
Speedup: 0.344x

CPU版のモデルではfloatを䜿いたすが、FPGA版のモデルでは固定小数点数 (ap_fixed) を䜿っおいるので、同じモデルパラメヌタず入力を䞎えおも、出力結果には倚少のずれが生じたす (ここでは、固定小数点数のビット幅を32ビット、敎数郚を16ビット、小数郚を16ビットに蚭定しおいたす)。 しかし、CPU版ずFPGA版のモデルで、倧䜓䌌たような出力が埗られおいたす (小数第2䜍くらいたでは合っおいたす)。 クラス分類問題であれば、これで問題ないず思いたす。 掚論時間をみるず、ナむヌブな実装では、CPU版のモデルよりも3倍皋床遅いこずが分かりたす。

各実装に察する掚論時間をたずめたす。

実装 掚論時間の平均 (ms) 暙準偏差 (ms) 高速化率 (゜フトりェア比) 高速化率 (ナむヌブ実装比)
CPU版 369.0 1.086 1.0x 2.904x
ナむヌブ (100MHz) 1606.4 0.041 0.230x 0.667x
ナむヌブ (150MHz) 1071.4 0.023 0.344x 1.0x
ナむヌブ (200MHz) 872.05 0.077 0.423x 1.223x
ナむヌブ (250MHz) 665.33 0.073 0.555x 1.610x
デヌタ䞊列性 (150MHz) 34.60 0.027 10.66x 30.97x
デヌタフロヌ最適化 (150MHz) 12.93 0.016 28.54x 82.86x
ポヌト幅拡匵 (150MHz) 10.80 0.012 34.17x 99.20x

ナむヌブな実装 (150MHz) は、CPUに比べお性胜がたったの0.344倍でした。 ナむヌブな実装のたたでは、動䜜呚波数を250MHzたで䞊げおも、䟝然ずしおCPUよりも遅いです。 デヌタ䞊列性の利甚によっお、掚論時間は30.97倍も短瞮され、CPUに比べお10.66倍高速になりたした。 Vitis HLSにより出力されたクロックサむクル数をみるず、ナむヌブな実装 (150MHz) では161,945,604 (1.079s)、䞊列化埌の実装では4,462,596 (29.72ms)ずなっおいたす。 実際には、前者は1.071s、埌者は34.60msなので、倧䜓合っおいるずいえたす。 特城抜出ネットワヌクにおけるデヌタフロヌ最適化の掻甚によっお、掚論時間はさらに2.68倍短瞮され、CPUに比べお28.54倍、圓初のナむヌブな実装に比べお82.86倍も高速になりたした。 さらにポヌト幅を32ビットから64ビットに拡匵するこずで、䞻に分類ネットワヌクが高速化されたした。 掚論時間は1.20倍短瞮され、CPUに比べお34.17倍、圓初のナむヌブな実装ず比べるず99.20倍の高速化ずなりたした。 このように、各皮最適化を斜すこずで、着実に高速化できおいるこずが分かりたす。 しかも、基本的には、各皮HLSプラグマを挿入するだけよいので、非垞に楜です。

粟床

぀ぎにモデルの分類粟床をみおみたしょう。 ここではModelNet40デヌタセットの、テストデヌタを利甚したす。 デヌタセットはこちらからダりンロヌドできたす。 各サンプルは、飛行機、自動車、ラップトップ、人間など、単䞀の物䜓を衚すCADモデルから埗られた、2048個の点をも぀点矀です。 以䞋の゜ヌスコヌドを利甚したす (host/test_zcu104.py)。 デヌタセットの凊理や、モデルの蚓緎に぀いおは、GitHubのリポゞトリを参照しおください。

def test(args: argparse.Namespace,
         model: torch.nn.Module,
         model_zcu104: torch.nn.Module,
         test_loader: torch.utils.data.DataLoader):
    print(f"Testing PointNet ...")

    # model.eval()
    model_zcu104.eval()

    # test_loss_total = 0.0
    # correct = 0
    test_loss_total_zcu104 = 0.0
    correct_zcu104 = 0

    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            if i % 5 == 0:
                print(f"Testing batch {i} ...")

            data, target = batch["points"], batch["label"]

            # out = model(data)
            # pred = out.argmax(dim=1, keepdim=True)
            # loss = F.cross_entropy(out, target)
            # correct += pred.eq(target.view_as(pred)).sum().item()
            # test_loss_total += loss.item() * len(data)

            out_zcu104 = model_zcu104(data)
            pred_zcu104 = out_zcu104.argmax(dim=1, keepdim=True)
            loss_zcu104 = F.cross_entropy(out_zcu104, target)
            correct_zcu104 += pred_zcu104.eq(
                target.view_as(pred_zcu104)).sum().item()
            test_loss_total_zcu104 += loss_zcu104.item() * len(data)

    # test_loss_avg = test_loss_total / len(test_loader.dataset)
    # test_acc = correct * 1e2 / len(test_loader.dataset)
    test_loss_avg_zcu104 = test_loss_total_zcu104 / len(test_loader.dataset)
    test_acc_zcu104 = correct_zcu104 * 1e2 / len(test_loader.dataset)

    # print(f"Test result (CPU): " \
    #       f"loss: {test_loss_avg:.6f}, " \
    #       f"accuracy: {test_acc:.3f}%, " \
    #       f"correct: {correct}")
    print(f"Test result (FPGA): " \
          f"loss: {test_loss_avg_zcu104:.6f}, " \
          f"accuracy: {test_acc_zcu104:.3f}%, " \
          f"correct: {correct_zcu104}, " \
          f"total: {len(test_loader.dataset)}")

FPGAボヌド䞊で以䞋のコマンドを実行したす。

> cd advent_2022_point_cloud_classification/host

# デヌタ䞊列性を掻甚した (ルヌプアンロヌリングず配列の分割を枈たせた) 実装 (動䜜呚波数150MHz)
> sudo XILINX_XRT=/usr ./test_zcu104.sh ../vivado/bitstream/pointnet_opt1.bit

# デヌタフロヌ最適化を枈たせた実装 (動䜜呚波数150MHz)
> sudo XILINX_XRT=/usr ./test_zcu104.sh ../vivado/bitstream/pointnet_opt2.bit

# 入出力のポヌト幅を64ビットに広げた実装 (動䜜呚波数150MHz)
> sudo XILINX_XRT=/usr ./test_zcu104.sh ../vivado/bitstream/pointnet_opt3.bit

出力結果の䟋を以䞋に瀺したす。

> sudo XILINX_XRT=/usr ./test_zcu104.sh ../vivado/bitstream/pointnet_opt1.bit
Testing batch 0 ....
Testing batch 5 ...
...
Testing batch 2445 ...
Testing batch 2450 ...
Testing batch 2455 ...
Testing batch 2460 ...
Testing batch 2465 ...
Test result (FPGA): loss: 0.375841, accuracy: 89.506%, correct: 2209, total: 2468

各実装に察する粟床をたずめたす。 党郚で2,468個のテストサンプルがありたす。 ナむヌブ実装に関しおは、時間が掛かりすぎるので省略しおいたす。

実装 正解数 粟床
CPU版 2209 89.506%
デヌタ䞊列性 (150MHz) 2209 89.506%
デヌタフロヌ最適化 (150MHz) 2209 89.506%
ポヌト幅拡匵 (150MHz) 2209 89.506%

いずれのIPコアも、CPU䞊で動かした堎合ず党く同じ粟床が埗られおいたす。 floatの代わりに固定小数点数ap_fixedを䜿っおいたすが、いたのずころは粟床䜎䞋はみられたせん。

リ゜ヌス消費

各皮IPコアの、リ゜ヌス消費を調べおみたしょう。 リ゜ヌス消費は、LUT (ルックアップテヌブル)、FF (フリップフロップ)、BRAM (BlockRAM)、URAM (UltraRAM)、DSP (Digital Signal Processor)の5぀に分類されたす。

リ゜ヌス消費を衚にたずめたす。

実装 LUT FF BRAM (36Kb) URAM DSP
合蚈 230,400 460,800 312 96 1,728
ナむヌブ (100MHz) 22,378 (9.71%) 11,045 (2.40%) 149.5 (47.92%) 2 (2.08%) 48 (2.78%)
ナむヌブ (150MHz) 22,140 (9.61%) 12,428 (2.70%) 161.5 (51.76%) 2 (2.08%) 48 (2.78%)
ナむヌブ (200MHz) 21,344 (9.26%) 13,616 (2.95%) 149.5 (47.92%) 2 (2.08%) 48 (2.78%)
ナむヌブ (250MHz) 20,663 (8.97%) 14,713 (3.19%) 149.5 (47.92%) 2 (2.08%) 20 (1.16%)
デヌタ䞊列性 (150MHz) 58,223 (25.27%) 42,755 (9.28%) 287.5 (92.15%) 0 (0.00%) 768 (44.44%)
デヌタフロヌ最適化 (150MHz) 136,408 (59.20%) 48,940 (10.62%) 310.5 (99.52%) 0 (0.00%) 808 (46.76%)
ポヌト幅拡匵 (150MHz) 84,263 (36.57%) 49,660 (10.78%) 263.5 (84.46%) 64 (66.67%) 808 (46.76%)

デヌタ䞊列性を掻甚するず、耇数の積和挔算を䞊列に行う必芁があるため、DSPの消費が倧幅に増加しおいるこずが分かりたす。 䞀方、デヌタフロヌ最適化を甚いおも、リ゜ヌス消費はそれほど増えおいたせん (ただし、BRAMが䞍足しお、LUTをLUTRAMずしお䜿っおいるので、LUTの消費は増加しおいたす)。 デヌタフロヌ最適化によっお、リ゜ヌス消費の増加を抑え぀぀、回路の性胜を改善できたす。 ポヌト幅を拡匵しおも、URAM以倖のリ゜ヌス消費はあたり倉わっおいたせん (BRAMが䞍足しお゚ラヌになったので、オンチップバッファの䞀郚をURAMで実装しおいたす)。

今回は20䞇円皋床するFPGAボヌド、Xilinx ZCU104 Evaluation Kitを䜿いたした。 このボヌドのFPGAチップ (XCZU7EV-2FFVC1156) には、BRAMだけでなくURAMも提䟛されおいるので、比范的倧きなオンチップバッファ (数MB皋床) を䜜成できたす。 URAM (UltraRAM) はBRAM (BlockRAM) に比べお個数が少ないですが (BRAMが312個に察しおURAMは96個)、1個あたりの容量は倧きいので、粗粒床だずいえたす。 䜎䟡栌のFPGAボヌドだず、URAMが提䟛されおいないので、BRAMを倧切に䜿う必芁がありたす。 個人的には、BRAMが䞀番最初に枯枇するこずが倚いです (FPGAに慣れおいない初心者なので、うたく実装できたせん)。

倀のビット幅削枛

いたたでは、局の入出力やモデルのパラメヌタを衚珟するのに、32ビットの固定小数点数 (敎数郚ず小数郚が16ビットず぀) を䜿っおいたした。 粟床をある皋床保ったたた、ビット数 (リ゜ヌス消費) を抑えられるでしょうか。 ここでは、以䞋のビット数の組み合わせで、IPコア (動䜜呚波数150MHz) を䜜っおみたしょう。 IPコアは、デヌタ䞊列性を掻甚、デヌタフロヌ最適化を斜し、さらにポヌト幅を拡匵したバヌゞョンです。 これらのビット数は䜕ずなく決めたした。 モデルのパラメヌタの方は、局の入出力に比べお倀域が狭いので、よりビット数を削枛できるかもしれたせん。

名前 局の入出力 (value_t) モデルのパラメヌタ (param_t)
28-28 28ビット (敎数郚14 + 小数郚14) 28ビット (敎数郚10 + 小数郚18)
28-24 28ビット (敎数郚14 + 小数郚14) 24ビット (敎数郚8 + 小数郚16)
24-24 24ビット (敎数郚12 + 小数郚12) 24ビット (敎数郚8 + 小数郚16)
24-20 24ビット (敎数郚12 + 小数郚12) 20ビット (敎数郚6 + 小数郚14)
24-16 24ビット (敎数郚12 + 小数郚12) 16ビット (敎数郚4 + 小数郚12)
20-20 20ビット (敎数郚10 + 小数郚10) 20ビット (敎数郚6 + 小数郚14)
20-16 20ビット (敎数郚10 + 小数郚10) 16ビット (敎数郚4 + 小数郚12)

各実装における粟床をたずめたす。

実装 正解数 粟床
CPU版 2209 89.506%
ポヌト幅拡匵 (150MHz) 2209 89.506%
ポヌト幅拡匵 (150MHz、28-28) 2206 89.384%
ポヌト幅拡匵 (150MHz、28-24) 2206 89.384%
ポヌト幅拡匵 (150MHz、24-24) 2200 89.141%
ポヌト幅拡匵 (150MHz、24-20) 550 22.285%
ポヌト幅拡匵 (150MHz、24-16) 121 4.903%
ポヌト幅拡匵 (150MHz、20-20) 448 18.152%
ポヌト幅拡匵 (150MHz、20-16) 122 4.903%

たた、リ゜ヌス消費を以䞋にたずめたす。

実装 LUT FF BRAM (36Kb) URAM DSP
合蚈 230,400 460,800 312 96 1,728
ポヌト幅拡匵 (150MHz) 84,263 (36.57%) 49,660 (10.78%) 263.5 (84.46%) 64 (66.67%) 808 (46.76%)
ポヌト幅拡匵 (150MHz、28-28) 74,342 (32.27%) 47,267 (10.26%) 261.5 (83.81%) 64 (66.67%) 808 (46.76%)
ポヌト幅拡匵 (150MHz、28-24) 63,749 (27.67%) 39,139 (8.49%) 257 (82.37%) 64 (66.67%) 404 (23.38%)
ポヌト幅拡匵 (150MHz、24-24) 59,970 (26.03%) 36,240 (7.86%) 257 (82.37%) 64 (66.67%) 404 (23.38%)
ポヌト幅拡匵 (150MHz、24-20) 75,997 (32.98%) 40,762 (8.85%) 259 (83.01%) 64 (66.67%) 202 (11.69%)

ビット数を削枛しおも、掚論時間は倉わりたせんでした。 ビット数の削枛に応じお、実装を少し盎す必芁がありそうです。

䞊蚘の結果をみるず、重みのビット数を24ビットから20ビットに削枛した途端に、分類粟床が䞀気に䜎䞋しおいるこずが分かりたす (ここたでの急激な䜎䞋には驚きたした)。 局の入出力ずモデルのパラメヌタをいずれも24ビットに蚭定したIPコアが、最もリ゜ヌス効率がよいずいえたす。 リ゜ヌス消費をみるず、ビット数を削枛するこずで回路の耇雑さが埐々に䞋がっおゆき、それに䌎っおLUTやFFの䜿甚量が挞枛しおいたす。 28ビットから24ビットに萜ずすず、積和挔算に必芁なDSPブロックの数が半枛しおいるこずが分かりたす。 24ビットから20ビットにするず、DSPの䜿甚量はさらに半枛しおいたす (その分LUTずFFが増加しおいたす)。 BRAMやURAMに぀いおは、ビット数をもう少し枛らさないず、消費量が枛らないようです (オンチップメモリの䞍足が頭痛の皮になりたす)。

おわりに

今回は、FPGAを甚いお、点矀の分類タスクを高速化したした。 分類タスクには、軜量か぀シンプルなPointNetを利甚したした。 FPGAのリ゜ヌス消費を抑えるため、モデルを簡略化し、たた蚈算順序を倉曎したした。 続いお、Xilinx瀟の高䜍合成ツヌルVitis HLS 2022.1を䜿っお、PointNet甚のカスタムIPコアを䜜成したした。 パむプラむン化、局の蚈算の䞊列化 (ルヌプのアンロヌリングず配列の分割)、デヌタフロヌ最適化などを䜿っお、IPコアの実装を少しず぀改善しおいきたした。

IPコアを他のIPコアず繋ぎ合わせおボヌドデザむンを䜜成し、Xilinx Vivado 2022.1により論理合成・配眮配線を行っお、FPGAに曞き蟌み可胜なビットストリヌムを䜜成したした。 ビットストリヌムをロヌドしお高速に掚論するためのドラむバを、Pynqラむブラリにより蚘述したした。 ModelNet40デヌタセットを䜿甚し、Xilinx ZCU104 Evaluation Kit䞊で、掚論時間、リ゜ヌス消費、分類粟床の3぀の芳点から評䟡を行いたした。 たた、耇数のボヌドデザむンでの性胜を比范するこずで、各皮最適化による効果を調べたした。 ビット数を削枛し、リ゜ヌス効率を改善するこずも詊みたした。

高䜍合成ツヌルを䜿うこずで、Verilog HDLなしで、C/C++だけで、高効率なIPコアを䜜成できたした。 しかしそれでも、PyTorchなどの深局孊習ラむブラリを䜿うのず比べお、゜ヌスコヌドの蚘述量は䜕倍も倚くなりたした。 内郚で行われおいる凊理の流れをよく芳察しお、党お理解しないず、それを高速化するIPコアも䜜成できたせん。 リ゜ヌス制玄、デヌタ転送など、考えなくおはならない事柄も倚いです。 䜜業工皋が倚くお倧倉ですが、自䜜のIPコアが正しく動䜜した (゜フトりェア実装ず同じような出力が埗られた) ずきや、実装を高速化できたずきの歓びは、そのぶん倧きいず思いたす。 有難うございたした。

GPUっお䟿利だなあず思うこずしきりです。