Remove channels #340

Open · wants to merge 11 commits into base: main
28 changes: 11 additions & 17 deletions cirkit/backend/torch/circuits.py
@@ -52,10 +52,16 @@ def lookup(
if in_graph is None:
yield layer, ()
continue
# in_graph: An input batch (assignments to variables) of shape (B, C, D)
# in_graph: An input batch (assignments to variables) of shape (B, D)
# scope_idx: The scope of the layers in each fold, a tensor of shape (F, D'), D' < D
# x: (B, C, D) -> (B, C, F, D') -> (F, C, B, D')
x = in_graph[..., layer.scope_idx].permute(2, 1, 0, 3)
# x: (B, D) -> (B, F, D') -> (F, B, D')
if len(in_graph.shape) != 2:
raise ValueError(
"The input to the circuit should have shape (B, D), "
"where B is the batch size and D is the number of variables "
"the circuit is defined on"
)
x = in_graph[..., layer.scope_idx].permute(1, 0, 2)
yield layer, (x,)
continue

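To make the new (B, D) path above concrete, here is a minimal, self-contained sketch of the scope indexing and permutation performed in lookup; the shapes and values are made up for illustration and are not taken from the library.

    import torch

    # Hypothetical shapes, for illustration only: route a (B, D) batch to the
    # folded input layers by indexing the last dimension with scope_idx.
    B, D = 8, 5                                  # batch size, number of variables
    in_graph = torch.rand(B, D)                  # assignments to the variables
    scope_idx = torch.tensor([[0, 1], [2, 3]])   # (F, D'): scope of each fold

    x = in_graph[..., scope_idx]                 # (B, F, D') via advanced indexing
    x = x.permute(1, 0, 2)                       # (F, B, D'), fold-major as layers expect
    print(x.shape)                               # torch.Size([2, 8, 2])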
@@ -119,7 +125,6 @@ class AbstractTorchCircuit(TorchDiAcyclicGraph[TorchLayer]):
def __init__(
self,
scope: Scope,
num_channels: int,
layers: Sequence[TorchLayer],
in_layers: dict[TorchLayer, Sequence[TorchLayer]],
outputs: Sequence[TorchLayer],
@@ -131,7 +136,6 @@ def __init__(

Args:
scope: The variables scope.
num_channels: The number of channels per variable.
layers: The sequence of layers.
in_layers: A dictionary mapping layers to their inputs, if any.
outputs: A list of output layers.
@@ -146,7 +150,6 @@ def __init__(
fold_idx_info=fold_idx_info,
)
self._scope = scope
self._num_channels = num_channels
self._properties = properties

@property
@@ -167,15 +170,6 @@ def num_variables(self) -> int:
"""
return len(self.scope)

@property
def num_channels(self) -> int:
"""Retrieve the number of channels of each variable.

Returns:
The number of variables.
"""
return self._num_channels

@property
def properties(self) -> StructuralProperties:
"""Retrieve the structural properties of the circuit.
@@ -270,8 +264,8 @@ def forward(self, x: Tensor) -> Tensor:
following the topological ordering.

Args:
x: The tensor input of the circuit, with shape $(B, C, D)$, where B is the batch size,
$C$ is the number of channels, and $D$ is the number of variables.
x: The tensor input of the circuit, with shape $(B, D)$, where $B$ is the batch size
and $D$ is the number of variables.

Returns:
Tensor: The tensor output of the circuit, with shape $(B, O, K)$,
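A brief usage sketch of the updated forward interface follows; circuit stands for an already-compiled AbstractTorchCircuit (assumed to behave as a torch nn.Module), and the batch size is arbitrary, so this is illustrative rather than library code.

    import torch

    # Illustrative only: evaluating a compiled circuit on a plain (B, D) batch.
    batch_size = 16
    x = torch.rand(batch_size, circuit.num_variables)  # (B, D); no channel dimension
    y = circuit(x)                                      # (B, O, K) output tensor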
4 changes: 1 addition & 3 deletions cirkit/backend/torch/compiler.py
@@ -228,7 +228,6 @@ def _compile_circuit(self, sc: Circuit) -> AbstractTorchCircuit:
layers = list(compiled_layers_map.values())
cc = cc_cls(
sc.scope,
sc.num_channels,
layers=layers,
in_layers=in_layers,
outputs=outputs,
@@ -275,7 +274,6 @@ def _fold_circuit(compiler: TorchCompiler, cc: AbstractTorchCircuit) -> Abstract
# Instantiate a folded circuit
return type(cc)(
cc.scope,
cc.num_channels,
layers,
in_layers,
outputs,
@@ -507,7 +505,7 @@ def match_optimizer_fuse(match: LayerOptMatch) -> tuple[TorchLayer, ...]:
if optimize_result is None:
return cc, False
layers, in_layers, outputs = optimize_result
cc = type(cc)(cc.scope, cc.num_channels, layers, in_layers, outputs, properties=cc.properties)
cc = type(cc)(cc.scope, layers, in_layers, outputs, properties=cc.properties)
return cc, True


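For reference, an annotated sketch of the channel-free constructor call used throughout the compiler; cc, layers, in_layers and outputs are assumed to come from an existing compiled circuit, so this only restates the updated signature.

    # Sketch (assumed context): rebuilding a compiled circuit without num_channels.
    new_cc = type(cc)(
        cc.scope,                  # variable scope only; num_channels is gone
        layers,                    # sequence of TorchLayer objects
        in_layers,                 # mapping: layer -> its input layers
        outputs,                   # output layers
        properties=cc.properties,  # structural properties are carried over
    )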
47 changes: 25 additions & 22 deletions cirkit/backend/torch/layers/inner.py
@@ -155,12 +155,8 @@ def __init__(
num_folds: The number of folds.

Raises:
NotImplementedError: If the arity is not 2.
ValueError: If the number of input units is not the same as the number of output units.
"""
# TODO: generalize kronecker layer as to support a greater arity
if arity != 2:
raise NotImplementedError("Kronecker only implemented for binary product units.")
super().__init__(
num_input_units,
num_input_units**arity,
@@ -177,18 +173,25 @@ def config(self) -> Mapping[str, Any]:
}

def forward(self, x: Tensor) -> Tensor:
x0 = x[:, 0].unsqueeze(dim=-1) # shape (F, B, Ki, 1).
x1 = x[:, 1].unsqueeze(dim=-2) # shape (F, B, 1, Ki).
# shape (F, B, Ki, Ki) -> (F, B, Ko=Ki**2).
return self.semiring.mul(x0, x1).flatten(start_dim=-2)
# x: (F, H, B, Ki)
y0 = x[:, 0]
for i in range(1, x.shape[1]):
y0 = y0.unsqueeze(dim=-1) # (F, B, K, 1).
y1 = x[:, i].unsqueeze(dim=-2) # (F, B, 1, Ki).
# y0: (F, B, K=K * Ki).
y0 = torch.flatten(self.semiring.mul(y0, y1), start_dim=-2)
# y0: (F, B, Ko=Ki ** arity)
return y0

def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
# x: (F, H, C, K, num_samples, D)
x0 = x[:, 0].unsqueeze(dim=3) # (F, C, Ki, 1, num_samples, D)
x1 = x[:, 1].unsqueeze(dim=2) # (F, C, 1, Ki, num_samples, D)
# shape (F, C, Ki, Ki, num_samples, D) -> (F, C, Ko=Ki**2, num_samples, D)
x = x0 + x1
return torch.flatten(x, start_dim=2, end_dim=3), None
y0 = x[:, 0]
for i in range(1, x.shape[1]):
y0 = y0.unsqueeze(dim=3) # (F, C, K, 1, num_samples, D)
y1 = x[:, i].unsqueeze(dim=2) # (F, C, 1, Ki, num_samples, D)
y0 = torch.flatten(y0 + y1, start_dim=2, end_dim=3)
# y0: (F, C, Ko=Ki ** arity, num_samples, D)
return y0, None


class TorchSumLayer(TorchInnerLayer):
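Before the sum-layer changes, a standalone sketch of the generalized Kronecker product implemented in the new forward loop above; it assumes the standard sum-product semiring (mul is plain elementwise multiplication) and checks the iterated unsqueeze/flatten against torch.kron.

    import torch

    # Standalone sketch, assuming the sum-product semiring (mul = elementwise *).
    def kronecker_forward(x: torch.Tensor) -> torch.Tensor:
        # x: (F, H, B, Ki), where H is the arity and Ki the number of input units
        y = x[:, 0]                                      # (F, B, Ki)
        for i in range(1, x.shape[1]):
            y = y.unsqueeze(-1) * x[:, i].unsqueeze(-2)  # (F, B, K, Ki)
            y = y.flatten(start_dim=-2)                  # (F, B, K * Ki)
        return y                                         # (F, B, Ki ** H)

    # Sanity check against torch.kron for one fold and one batch element
    x = torch.rand(1, 3, 1, 4)  # F=1, H=3, B=1, Ki=4
    expected = torch.kron(torch.kron(x[0, 0, 0], x[0, 1, 0]), x[0, 2, 0])
    assert torch.allclose(kronecker_forward(x)[0, 0], expected)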
@@ -273,11 +276,11 @@ def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
if negative or not normalized:
raise TypeError("Sampling in sum layers only works with positive weights summing to 1")

# x: (F, H, C, Ki, num_samples, D) -> (F, C, H * Ki, num_samples, D)
x = x.permute(0, 2, 1, 3, 4, 5).flatten(2, 3)
c = x.shape[1]
num_samples = x.shape[3]
d = x.shape[4]
# x: (F, H, Ki, num_samples, D) -> (F, H * Ki, num_samples, D)
x = x.flatten(1, 2)

num_samples = x.shape[2]
d = x.shape[3]

# mixing_distribution: (F, Ko, H * Ki)
mixing_distribution = torch.distributions.Categorical(probs=weight)
Expand All @@ -286,9 +289,9 @@ def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
mixing_samples = mixing_distribution.sample((num_samples,))
mixing_samples = E.rearrange(mixing_samples, "n f k -> f k n")

# mixing_indices: (F, C, Ko, num_samples, D)
mixing_indices = E.repeat(mixing_samples, "f k n -> f c k n d", c=c, d=d)
# mixing_indices: (F, Ko, num_samples, D)
mixing_indices = E.repeat(mixing_samples, "f k n -> f k n d", d=d)

# x: (F, C, Ko, num_samples, D)
x = torch.gather(x, dim=2, index=mixing_indices)
# x: (F, Ko, num_samples, D)
x = torch.gather(x, dim=1, index=mixing_indices)
return x, mixing_samples
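As a worked illustration of the channel-free mixing step above, the snippet below reproduces it with made-up shapes and normalized random weights; all names and sizes here are assumptions, not library code.

    import torch
    import einops as E

    # Made-up shapes; weights are normalized so categorical sampling is valid.
    F_, H, Ki, Ko, num_samples, D = 2, 3, 4, 5, 6, 7
    x = torch.rand(F_, H, Ki, num_samples, D)                   # (F, H, Ki, N, D)
    weight = torch.softmax(torch.rand(F_, Ko, H * Ki), dim=-1)  # (F, Ko, H*Ki)

    x = x.flatten(1, 2)                                      # (F, H*Ki, N, D)
    mixing = torch.distributions.Categorical(probs=weight)   # batch shape (F, Ko)
    samples = mixing.sample((num_samples,))                  # (N, F, Ko)
    samples = E.rearrange(samples, "n f k -> f k n")         # (F, Ko, N)
    idx = E.repeat(samples, "f k n -> f k n d", d=D)         # (F, Ko, N, D)
    out = torch.gather(x, dim=1, index=idx)                  # (F, Ko, N, D)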