diff --git a/se3cnn/point/operations.py b/se3cnn/point/operations.py index 7e9f0d4..b701532 100644 --- a/se3cnn/point/operations.py +++ b/se3cnn/point/operations.py @@ -14,6 +14,7 @@ def forward(self, features, geometry): :param geometry: tensor [batch, point, xyz] :return: tensor [batch, point, channel] """ + assert features.size()[:2] == geometry.size()[:2], "features size ({}) and geometry size ({}) should match".format(features.size(), geometry.size()) rb = geometry.unsqueeze(1) # [batch, 1, b, xyz] ra = geometry.unsqueeze(2) # [batch, a, 1, xyz] k = self.kernel(rb - ra) # [batch, a, b, i, j] @@ -31,6 +32,7 @@ def forward(self, features, geometry): :param geometry: tensor [batch, point, xyz] :return: tensor [batch, point, point, channel] """ + assert features.size()[:2] == geometry.size()[:2], "features size ({}) and geometry size ({}) should match".format(features.size(), geometry.size()) rb = geometry.unsqueeze(1) # [batch, 1, b, xyz] ra = geometry.unsqueeze(2) # [batch, a, 1, xyz] k = self.kernel(rb - ra) # [batch, a, b, i, j] @@ -50,6 +52,7 @@ def forward(self, features, geometry): :param geometry: tensor [batch, point, xyz] :return: tensor [batch, point, channel] """ + assert features.size()[:2] == geometry.size()[:2], "features size ({}) and geometry size ({}) should match".format(features.size(), geometry.size()) batch, n, _ = geometry.size() rb = geometry.unsqueeze(1) # [batch, 1, b, xyz]