-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdeepfreq.py
92 lines (86 loc) · 4.17 KB
/
deepfreq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
%=============================================================================
% DESCRIPTION:
% This is the Python implementation of the frequency estimation module from
% DeepFreq, for line spectral estimation (LSE) problems. This was written for
% the submitted manuscript:
%
% "Real-time sinusoidal parameter estimation for damage growth monitoring
% during ultrasonic very high cycle fatigue tests."
%
%=============================================================================
% Version 1.1.0, Authored by:
% Shawn L. KISER (Msc) @ https://www.linkedin.com/in/shawn-kiser/
% Laboratoire PIMM, Arts et Metiers Institute of Technology, CNRS, Cnam,
% HESAM Universite, 151 boulevard de l’Hopital, 75013 Paris (France)
%
% Based on:
% [1] G. Izacard, S. Mohan, and C. Fernandez-Granda, “Data-driven Estimation
% of Sinusoid Frequencies,” arXiv:1906.00823 [cs, eess, stat], Feb. 2021,
% Available: http://arxiv.org/abs/1906.00823
%
%=============================================================================
% The MIT License (MIT)
%
% Copyright (c) 2021 Shawn L. KISER
%
% Permission is hereby granted, free of charge, to any person obtaining a copy
% of this software and associated documentation files (the "Software"), to deal
% in the Software without restriction, including without limitation the rights
% to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
% copies of the Software, and to permit persons to whom the Software is
% furnished to do so, subject to the following conditions:
%
% The above copyright notice and this permission notice shall be included in
% all copies or substantial portions of the Software.
%
% THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
% IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
% FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
% AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
% LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
% OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
% THE SOFTWARE.
%=============================================================================
"""
import torch.nn as nn
def set_fr_module(args):
    """
    Create a frequency-representation module from a namespace of options.

    Parameters
    ----------
    args : namespace-like object
        ``args.fr_module_type`` is always read. When it equals ``'fr'`` the
        attributes ``fr_size``, ``fr_inner_dim``, ``fr_upsampling``,
        ``signal_dim``, ``fr_n_filters``, ``fr_n_layers``, ``fr_kernel_size``
        and ``fr_kernel_out`` are read as well.

    Returns
    -------
    FrequencyRepresentationModule
        A module constructed from the ``fr_*`` / ``signal_dim`` options.

    Raises
    ------
    NotImplementedError
        If ``args.fr_module_type`` is anything other than ``'fr'``.
    AssertionError
        If ``args.fr_size`` differs from ``args.fr_inner_dim * args.fr_upsampling``.
    """
    if args.fr_module_type != 'fr':
        raise NotImplementedError('Frequency representation module type not implemented')

    # Raised explicitly instead of through an `assert` statement so the check
    # still runs when Python is started with -O (which strips asserts). The
    # exception type and message are unchanged for existing callers.
    if args.fr_size != args.fr_inner_dim * args.fr_upsampling:
        raise AssertionError(
            'The desired size of the frequency representation (fr_size) must be equal to inner_dim*upsampling')

    return FrequencyRepresentationModule(
        signal_dim=args.signal_dim,
        n_filters=args.fr_n_filters,
        inner_dim=args.fr_inner_dim,
        n_layers=args.fr_n_layers,
        upsampling=args.fr_upsampling,
        kernel_size=args.fr_kernel_size,
        kernel_out=args.fr_kernel_out,
    )
class FrequencyRepresentationModule(nn.Module):
    """
    Network that maps a batch of signals to a 1-D frequency representation.

    The forward pass is made of three stages:

    1. ``in_layer``: a bias-free linear map from ``2 * signal_dim`` inputs to
       ``inner_dim * n_filters`` features, viewed as
       ``(batch, n_filters, inner_dim)``.
    2. ``mod``: ``n_layers`` repetitions of circular ``Conv1d`` ->
       ``BatchNorm1d`` -> ``ReLU`` that keep ``n_filters`` channels.
    3. ``out_layer``: a bias-free ``ConvTranspose1d`` with stride
       ``upsampling`` that reduces to one channel; the result is flattened to
       ``(batch, -1)``.

    Parameters
    ----------
    signal_dim : int
        Each sample must flatten to ``2 * signal_dim`` values (presumably the
        real and imaginary parts of a complex signal - confirm with caller).
    n_filters : int
        Number of channels used by the convolutional stack.
    n_layers : int
        Number of Conv1d/BatchNorm1d/ReLU triples.
    inner_dim : int
        Length of each channel after the input linear layer.
    kernel_size : int
        Kernel size of the inner convolutions (padding is ``kernel_size // 2``).
    upsampling : int
        Stride of the output transposed convolution.
    kernel_out : int
        Kernel size of the output transposed convolution.

    Attributes
    ----------
    fr_size : int
        ``inner_dim * upsampling``, the intended output length per sample.
    n_filters : int
        Stored copy of the constructor argument, used by ``forward``.
    """

    def __init__(self, signal_dim=50, n_filters=8, n_layers=3, inner_dim=125,
                 kernel_size=3, upsampling=8, kernel_out=25):
        super().__init__()
        self.fr_size = inner_dim * upsampling
        self.n_filters = n_filters
        self.in_layer = nn.Linear(2 * signal_dim, inner_dim * n_filters, bias=False)

        # Submodules are created in the same order as before so that seeded
        # weight initialisation and state_dict keys ("mod.0", "mod.1", ...)
        # stay unchanged.
        layers = []
        for _ in range(n_layers):
            layers.extend([
                nn.Conv1d(n_filters, n_filters, kernel_size=kernel_size,
                          padding=kernel_size // 2, bias=False,
                          padding_mode='circular'),
                nn.BatchNorm1d(n_filters),
                nn.ReLU(),
            ])
        self.mod = nn.Sequential(*layers)

        # NOTE(review): by PyTorch's ConvTranspose1d length formula this
        # padding/output_padding choice yields exactly inner_dim * upsampling
        # outputs when kernel_size is odd and (kernel_out - upsampling) is odd,
        # as with the defaults; other combinations may be off by one - confirm.
        self.out_layer = nn.ConvTranspose1d(n_filters, 1, kernel_out, stride=upsampling,
                                            padding=(kernel_out - upsampling + 1) // 2,
                                            output_padding=1, bias=False)

    def forward(self, inp):
        """
        Compute the frequency representation of a batch.

        Parameters
        ----------
        inp : tensor
            Tensor whose first dimension is the batch and whose remaining
            dimensions flatten to ``2 * signal_dim`` values per sample.

        Returns
        -------
        tensor
            Tensor of shape ``(batch, L)`` where ``L`` is the length produced
            by ``out_layer`` (``fr_size`` for the default configuration).
        """
        bsz = inp.size(0)
        # reshape (rather than view) so non-contiguous inputs, e.g. transposed
        # or sliced tensors, are accepted; contiguous inputs behave as before.
        flat = inp.reshape(bsz, -1)
        x = self.in_layer(flat).view(bsz, self.n_filters, -1)
        x = self.mod(x)
        return self.out_layer(x).view(bsz, -1)