Merge branch 'pre-release' into master for the first 2020.03 release
ivannz committed Mar 12, 2020
2 parents ec26a6e + cb5b7e6 commit 0a6bad6
Showing 44 changed files with 1,781 additions and 994 deletions.
29 changes: 29 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,29 @@
# Version 2020.03

## Major changes in `.nn`
* The structure of the `.nn` sub-module now more closely resembles that of `torch`:
- `.base` : `CplxToCplx` and parameter type `CplxParameter`
- `.casting` : real-Cplx tensor conversion layers
- `.linear`, `.conv`, `.activation` : essential layers and activations
- `.container` : sequential container which explicitly checks types of internal layers
- `.extra` : 1-dim Bernoulli Dropout for complex-valued tensors (Cplx)
* `CplxToCplx` can now promote torch's univariate functions and modules to split-complex activations, e.g. use `CplxToCplx[AvgPool1d]` instead of `CplxAvgPool1d` (a minimal sketch follows this list)
* Niche complex-valued containers were removed, and dedicated activations such as `CplxLog` and `CplxExp` were dropped
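
A minimal sketch of the promotion mechanism described above; the local name, shapes and pooling parameters are illustrative:

```python
import torch

from cplxmodule import cplx
from cplxmodule.nn import CplxToCplx

# promote a real-valued torch module to a split complex-valued layer:
# the same AvgPool1d is applied to the real and imaginary parts independently
SplitAvgPool1d = CplxToCplx[torch.nn.AvgPool1d]

z = cplx.Cplx(torch.randn(8, 4, 16), torch.randn(8, 4, 16))
out = SplitAvgPool1d(kernel_size=2)(z)  # Cplx with real/imag parts of shape 8 x 4 x 8
```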


## Major changes in `.nn.relevance`
* misnamed Bayesian layers in `.nn.relevance` were moved around and corrected
- layers in `.real` and `.complex` were renamed to Variational Dropout (`*VD`), with deprecation warnings for the old names (an import sketch follows this list)
- `.ard` implements the Bayesian Dropout methods with Automatic Relevance Determination priors
* `.extensions` submodule contains relaxations, approximations, and related but non-Bayesian layers
- `\ell_0` stochastic regularization layer was moved to `.real`
- `Lasso` was kept to illustrate extensibility, but similarly moved to `.real`
- Variational Dropout approximations and speed-ups were moved to `.complex`
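
A hedged import sketch of the renaming; the exact class names (`LinearVD`, `LinearARD`) are assumed from the naming scheme above rather than quoted from the release:

```python
# Variational Dropout layers (log-uniform prior), formerly misnamed *ARD
from cplxmodule.nn.relevance import LinearVD

# Automatic Relevance Determination layers (ARD prior), now living in `.ard`
from cplxmodule.nn.relevance import LinearARD

# the old import path is assumed to keep working for now, with a DeprecationWarning
from cplxmodule.nn.relevance.real import LinearARD as LegacyLinearARD
```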


## Enhancements
* `CplxParameter` now supports real-to-complex promotion during `load_state_dict` (a hedged sketch follows this list)
* added submodule-specific READMEs explaining typical use cases and peculiarities
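
A hedged sketch of the real-to-complex promotion; the layer sizes and the use of `strict=False` are illustrative assumptions, not documented API:

```python
import torch

from cplxmodule.nn import CplxLinear

# a plain real-valued layer and its state dict: {'weight': ..., 'bias': ...}
real_state = torch.nn.Linear(16, 32).state_dict()

# loading a real tensor under the bare 'weight' key (no '.real'/'.imag' entries)
# is expected to promote it to the complex-valued CplxParameter of CplxLinear
clayer = CplxLinear(16, 32)
clayer.load_state_dict(real_state, strict=False)
```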

# Prior to 2020.03
Prior versions used a different numbering scheme, and although the layers are backwards compatible, their locations within the library were quite different.
83 changes: 83 additions & 0 deletions PLAN.md
@@ -0,0 +1,83 @@
# Plan for the first release

[+] make `load_state_dict` respect the components of `CplxParameter` and allow promoting real tensors to complex tensors, provided the state dict has no `.real` or `.imag` entries but does have a correct key referring to the parameter.

[+] fix the incorrect naming of Bayesian methods in `nn.relevance`
* rename the `*ARD`-named layers in `.real` and `.complex` to `*VD` layers, since they use a log-uniform prior and are therefore in fact Variational Dropout layers
* start deprecating imports of the `*ARD`-named layers from `.real` and `.complex`
* fix aliases of imported layers in `.extensions`
* expose all base VD/ARD layers in `__init__.py` and require importing the modifications from `.extensions`
* fix the text in nn/relevance/README.md

[+] fix the name of the L0-regularized layer, which in fact performs probabilistic sparsification and is not related to variational inference

[+] check if `setup.py` has correct requirements and specify them explicitly
* `requires` is not a keyword, use `install_requires` and `tests_require`

[+] investigate reordering base classes in `LinearMasked(MaskedWeightMixin, Linear, _BaseRealMixin)` and similar in `nn.masked`.
* could moving it further into the bases result in slower property lookup? It seems not:
* from the Python descriptors [doc](https://docs.python.org/3/howto/descriptor.html):
> The implementation works through a precedence chain that gives data descriptors priority over instance variables, instance variables priority over non-data descriptors, and assigns lowest priority to `__getattr__`
* lookup order via `__getattribute__` is thus: data descriptors (e.g. `@property`), the instance `__dict__`, the class `__dict__` (class attributes), and lastly `__getattr__` (a standalone illustration follows this item)
* moved `MaskedWeightMixin` into `_BaseMixin`
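
A standalone illustration of that precedence chain (plain Python, independent of the library):

```python
class DataDescr:
    """A data descriptor: defines both __get__ and __set__."""
    def __get__(self, obj, objtype=None):
        return "data descriptor"

    def __set__(self, obj, value):
        pass


class A:
    attr = DataDescr()

    def __getattr__(self, name):
        # consulted only when the normal lookup fails
        return f"__getattr__ fallback for {name!r}"


a = A()
a.__dict__["attr"] = "instance dict"
print(a.attr)     # data descriptor -- wins over the instance __dict__
print(a.missing)  # __getattr__ fallback for 'missing' -- lowest priority
```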

[+] get rid of `torch_module` from `.utils` and declare `activations` explicitly

[+] clean up the `nn` module itself
* remove crap from `.sequential`: `CplxResidualBottleneck`, `CplxResidualSequential` and `CplxBusResidualSequential` must go, and move `CplxSequential` to base layers
* split `.layers`, `.activation`, and `.sequential`
* `.modules.base` : base classes (CplxToCplx, BaseRealToCplx, BaseCplxToReal), and parameter type (CplxParameter, CplxParameterAccessor)
* `.modules.casting` : converting real tensors in various formats to and from Cplx (InterleavedRealToCplx, ConcatenatedRealToCplx, CplxToInterleavedReal, CplxToConcatenatedReal, AsTypeCplx)
* `.modules.linear` : Linear, Bilinear, Identity, PhaseShift
* `.modules.conv` : everything convolutional
* `.modules.activation` : activations (CplxModReLU, CplxAdaptiveModReLU, CplxModulus, CplxAngle) and layers (CplxReal, CplxImag)
* `.modules.container` : CplxSequential
* `.modules.extra` : Dropout, AvgPool1d
* move `.batchnorm` to modules, keep `.init` in `.nn`
* fix imports from adjacent modules: `nn.masked` and `nn.relevance`.

[+] in `nn.relevance.complex` : drop `Cplx(*map(torch.randn_like, (s2, s2)))` and write `Cplx(torch.randn_like(s2), torch.randn_like(s2))` explicitly
* implemented `cplx.randn` and `cplx.randn_like` (a short usage sketch follows this item)
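
A short usage sketch of the new helpers (shapes are illustrative):

```python
import torch

from cplxmodule import cplx

z = cplx.randn(3, 4)                    # Cplx with 3 x 4 real and imaginary parts
w = cplx.randn_like(torch.zeros(3, 4))  # matches the size, dtype and device of the input
```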

[+] residual clean up in `nn` module
* `.activation` : `CplxActivation` is the same as CplxToCplx[...]
* CplxActivation promotes classic (real) torch functions to split activations, so yes.
* See if it is possible to implement function promotion through CplxToCplx[...]
* it is possible: just reuse `CplxActivation`
* currently `CplxToCplx` promotes layers and real functions to independently applied layers/functions (split)
* how should we proceed with the `cplx` trig functions? a wrapper, or hardcoded activations?
* the latter seems more natural, as the trig functions are vendored by this module
* since torch is the base, implements a great number of univariate tensor functions, and could potentially be extended, it is more natural to use a wrapper (the rationale behind `CplxToCplx[...]`)
* `.modules.extra` : this needs thorough cleaning
* drop CplxResidualBottleneck, CplxResidualSequential and CplxBusResidualSequential
* abandon `torch_module` and code the trig activations by hand.
* remove alias CplxDropout1d : use torch.nn names as much as possible
* deprecate CplxAvgPool1d: it can be created at runtime with `CplxToCplx[torch.nn.AvgPool1d]`

[+] documentation for Bayesian and maskable layers
* in `nn.relevance.base`, making it like in `nn.masked`
* classes in `nn.relevance` `.real` and `.complex` should also be documented properly; the same goes for `.extensions`

[+] restructure the extensions and non-Bayesian layers
* new folder structure
* take the ARD-related declarations and move them to `relevance/ard.py`, everything else to a submodule
* `.extensions` submodule:
* `complex` for cplx-specific extended layers: bogus penalties, approximations and other stuff not directly related to variational dropout or automatic relevance determination
* `real` for supplementary real-valued layers
* decide the fate of `lasso` class in `nn.relevance`:
* it is irrelevant to Bayesian methods: move it to `extensions/real`

[+] documentation
* go through the READMEs in each submodule to make sure the info there is correct and the typical use cases are described
* `nn.init` : document the initializations according to Trabelsi et al. (2018) (a usage sketch follows this item)
* seems to be automatically documented using `functools.wraps` from the original `torch.nn.init` procedures
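
A hedged usage sketch of one of these initializers; calling it directly on a bare `Cplx` of empty tensors is an assumption based on its in-place, `torch.nn.init`-like signature, and the shapes are illustrative:

```python
import torch

from cplxmodule.cplx import Cplx
from cplxmodule.nn.init import cplx_trabelsi_standard_

# an uninitialized complex 64 x 32 weight with separate real/imaginary storage
weight = Cplx(torch.empty(64, 32), torch.empty(64, 32))

# in-place initialization in the style of Trabelsi et al. (2018), "glorot" variant
cplx_trabelsi_standard_(weight, kind="glorot")
```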

[+] add missing tests to the unit test suite
* tests for the `*state_dict` API compliance of `nn.masked` and `nn.base.CplxParameter`
* implementing these tests helped figure out and fix edge cases, so yay for TDD!

[ ] Improve implementation
* (Bernoulli Dropout) need 1d (exists), 2d and 3d versions
* (Convolutions) implement 3d convolutions and 3d var-dropout convolutions, both real and complex
* (Transposed Convolutions) figure out the math and implement variational dropout for transposed convolutions

2 changes: 1 addition & 1 deletion README.md
@@ -2,7 +2,7 @@

A lightweight extension for `torch.nn` that adds layers and activations, which respect algebraic operations over the field of complex numbers.

The core implementation of the complex-valued batch normalization and weight initialization layers is based on the ICLR 2018 paper by Chiheb Trabelsi et al. on Deep Complex Networks _[1]_ and borrows ideas from the [implementation](https://github.com/ChihebTrabelsi/deep_complex_networks). Real-valued variational dropout and automatic relevance determination are based on the profound works by Diederik Kingma et al. (2015) _[2]_ and Dmitry Molchanov et al. (2017) _[3]_. Complex-valued Bayesian sparsification layers are based on original research.
The core implementation of the complex-valued batch normalization and weight initialization layers is based on the ICLR 2018 paper by Chiheb Trabelsi et al. on Deep Complex Networks _[1]_ and borrows ideas from the [implementation](https://github.com/ChihebTrabelsi/deep_complex_networks). Real-valued variational dropout and automatic relevance determination are original implementations based on the profound works by Diederik Kingma et al. (2015) _[2]_ and Dmitry Molchanov et al. (2017) _[3]_. Complex-valued Bayesian sparsification layers are based on original research.

# Installation

19 changes: 19 additions & 0 deletions cplxmodule/cplx.py
@@ -3,6 +3,7 @@
import torch
import torch.nn.functional as F

from math import sqrt
from .utils import complex_view, fix_dim


@@ -462,6 +463,24 @@ def tanh(input):
    return sinh(input) / cosh(input)


def randn(*size, dtype=None, device=None, requires_grad=False):
    """Generate standard complex Gaussian noise."""
    normal = torch.randn(2, *size, dtype=dtype, layout=torch.strided,
                         device=device, requires_grad=False) / sqrt(2)
    z = Cplx(normal[0], normal[1])
    return z.requires_grad_(True) if requires_grad else z


def randn_like(input, dtype=None, device=None, requires_grad=False):
    """Returns a Cplx tensor with the same size as `input` that is filled
    with standard complex Gaussian random numbers.
    """
    return randn(*input.size(),
                 dtype=input.dtype if dtype is None else dtype,
                 device=input.device if device is None else device,
                 requires_grad=requires_grad)


def modrelu(input, threshold=0.5):
    r"""Compute the modulus relu of the complex tensor in re-im pair."""
    # scale = (1 - \tfrac{b}{|z|})_+
62 changes: 34 additions & 28 deletions cplxmodule/nn/README.md
@@ -4,25 +4,31 @@

### Real-Complex Conversion layers

* ConcatenatedRealToCplx
* CplxToConcatenatedReal
* InterleavedRealToCplx
* CplxToInterleavedReal
* ConcatenatedRealToCplx, CplxToConcatenatedReal
* InterleavedRealToCplx (RealToCplx), CplxToInterleavedReal (CplxToReal)
* AsTypeCplx
* CplxReal
* CplxImag

### Base building blocks
### Basic building blocks

* CplxLinear
* CplxConv1d
* CplxConv2d
* CplxBilinear
* CplxReal, CplxImag
* CplxIdentity, CplxLinear, CplxBilinear
* CplxConv1d, CplxConv2d
* CplxSequential

### Complex activation layers

* CplxModulus, CplxAngle
* CplxModReLU, CplxAdaptiveModReLU

### Complex batch normalization

Batch normalization layers, based on 2d vector whitening proposed in _[1]_, are provided by `nn.modules.batchnorm`.

* CplxBatchNorm1d, CplxBatchNorm2d, CplxBatchNorm3d

### Miscellaneous layers

* CplxDropout
* CplxAvgPool1d
* CplxPhaseShift

### Complex-valued parameter representation
@@ -38,6 +44,8 @@ The base class for complex-valued layers is `CplxToCplx`. It does not have `__in
It is possible to promote an existing real-valued module to a complex-valued module, which is *shared* between the real and imaginary parts and acts on them independently, i.e. the same layer is applied twice. For example, the typical use case is to convert a real-valued activation to a split complex-valued activation:

```python
import torch

from cplxmodule import cplx
from cplxmodule.nn import CplxToCplx

@@ -54,17 +62,16 @@ z = cplx.Cplx(torch.ones(1, 1), - torch.ones(1, 1))
CplxSharedLinear(1, 3, bias=False)(z)
```

## Initialization

Functions in `nn.init` implement various random initialization strategies suitable for complex-valued layers, that were researched in _[1]_.
It is also possible to promote a unary, non-inplace real-valued function from the `torch` namespace to a complex-valued split activation or transformation, e.g.
```python
CplxSplitSin = CplxToCplx[torch.sin]

## BatchNorm layers
CplxSplitSin()(z)
```

Whitening-based batch normalization layers proposed in _[1]_ are provided by `nn.batchnorm`.
## Initialization

* CplxBatchNorm1d
* CplxBatchNorm2d
* CplxBatchNorm3d
Functions in `nn.init` implement various random initialization strategies suitable for complex-valued layers, that were researched in _[1]_.

## Usage

@@ -73,7 +80,6 @@ Basically the module is designed in such a way as to be ready for plugging into existing
Importing the building blocks.
```python
import torch
import torch.nn

# complex valued tensor class
from cplxmodule import cplx
@@ -82,13 +88,13 @@ from cplxmodule import cplx
from cplxmodule.nn import RealToCplx, CplxToReal

# layers of encapsulating other complex valued layers
from cplxmodule.nn.sequential import CplxSequential
from cplxmodule.nn import CplxSequential

# common layers
from cplxmodule.nn.layers import CplxConv1d, CplxLinear
from cplxmodule.nn import CplxConv1d, CplxLinear

# activation layers
from cplxmodule.nn.activation import CplxModReLU, CplxActivation
from cplxmodule.nn import CplxModReLU
```

After a `RealToCplx` layer the intermediate inputs are `Cplx` objects, which are abstractions for complex-valued tensors, represented by real and imaginary parts, and which obey complex arithmetic (currently there is no support for mixed-type arithmetic like `torch.Tensor +/-* Cplx`).
@@ -100,10 +106,10 @@ cplx = RealToCplx()(z)
print(cplx)
```

Stacking and constructing purely complex-to-complex pipelines with troch.nn.Sequential:
Stacking and constructing purely complex-to-complex pipelines with `torch.nn.Sequential`:
```python
n_features, n_channels = 16, 4
z = torch.randn(256, n_features*2)
z = torch.randn(256, n_channels, n_features * 2)

complex_model = CplxSequential(
CplxLinear(n_features, n_features, bias=True),
@@ -117,7 +123,7 @@ complex_model = CplxSequential(
# complex: batch x (3 * n_channels) x (n_features - (4-1))
CplxToCplx[torch.nn.Flatten](start_dim=-2),

CplxActivation(torch.tanh),
CplxToCplx[torch.tanh](),
)
```

@@ -141,7 +147,7 @@ real_input_model = torch.nn.Sequential(
)

print(real_input_model(z).shape)
# >>> torch.Size([256, 312])
# >>> torch.Size([256, 4, 32])
```

# References
13 changes: 3 additions & 10 deletions cplxmodule/nn/__init__.py
@@ -1,14 +1,7 @@
from .layers import RealToCplx, AsTypeCplx
from .layers import CplxToCplx, CplxToReal
from .modules import *

from .layers import CplxParameter

from .layers import CplxLinear
from .layers import CplxBilinear
from .conv import CplxConv1d, CplxConv2d

from .activation import CplxModulus, CplxAngle
from .sequential import CplxSequential
from .modules.base import CplxParameter
from . import init

# from .relevance.real import LinearARD
# from .relevance.real import Conv1dARD, Conv2dARD
Expand Down
4 changes: 2 additions & 2 deletions cplxmodule/nn/init.py
@@ -63,7 +63,7 @@ def cplx_xavier_uniform_(tensor, gain=1.0):


def cplx_trabelsi_standard_(cplx, kind="glorot"):
"""Standard complex initialization proposed in Trabelsi et al. (2017)."""
"""Standard complex initialization proposed in Trabelsi et al. (2018)."""
kind = kind.lower()
assert kind in ("glorot", "xavier", "kaiming", "he")

@@ -86,7 +86,7 @@ def cplx_trabelsi_standard_(cplx, kind="glorot"):


def cplx_trabelsi_independent_(cplx, kind="glorot"):
"""Orthogonal complex initialization proposed in Trabelsi et al. (2017)."""
"""Orthogonal complex initialization proposed in Trabelsi et al. (2018)."""
kind = kind.lower()
assert kind in ("glorot", "xavier", "kaiming", "he")
