Skip to content

Commit

Permalink
Performance improvements of function synsNMF
Browse files Browse the repository at this point in the history
  • Loading branch information
alesantuz committed Dec 19, 2023
1 parent c479a6a commit 318b5aa
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 32 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: musclesyneRgies
Title: Extract Muscle Synergies from Electromyography
Version: 1.2.5.9004
Version: 1.2.5.9005
Authors@R:
person("Alessandro", "Santuz", , "alessandro.santuz@gmail.com", role = c("aut", "cre"),
comment = c(ORCID = "0000-0002-6577-5101"))
Expand Down
12 changes: 12 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
# musclesyneRgies 1.2.5.9005 (development version)
### How to install
```
install.packages("remotes")
remotes::install_github("alesantuz/musclesyneRgies")
```
### How to use
README and vignettes are available both on [CRAN](https://CRAN.R-project.org/package=musclesyneRgies) and on [GitHub](https://github.com/alesantuz/musclesyneRgies).

### What's changed
- The function `synsNMF` is now faster.

# musclesyneRgies 1.2.5.9004 (development version)
### How to install
```
Expand Down
78 changes: 47 additions & 31 deletions R/synsNMF.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ synsNMF <- function(V,
M_list <- list() # To save factorisation M matrices (muscle weights)
P_list <- list() # To save factorisation P matrices (activation patterns)
Vr_list <- list() # To save factorisation Vr matrices (reconstructed signals)
iters <- numeric() # To save the iterations number
iterations <- numeric() # To save the iterations number

# Remove time column and transpose for upcoming NMF
V <- subset(V, select = -time) |>
Expand Down Expand Up @@ -94,50 +94,67 @@ synsNMF <- function(V,
M_temp <- list()
P_temp <- list()
Vr_temp <- list()
iters <- numeric()

for (run in 1:runs) { # Run NMF multiple times for each syn and choose best run
# Initialise the two factorisation matrices with random values (uniform distribution)
P <- matrix(stats::runif(r * n, min = min(V), max = max(V)), nrow = r, ncol = n)
M <- matrix(stats::runif(m * r, min = min(V), max = max(V)), nrow = m, ncol = r)

# Iteration "zero"
P <- P * crossprod(M, V) / crossprod((crossprod(M, M)), P)
M <- M * tcrossprod(V, P) / tcrossprod(M, tcrossprod(P, P))
Vr <- M %*% P # Reconstructed matrix

# Iteration "zero" and first updates
MV <- crossprod(M, V)
MM <- crossprod(M, M)
P <- P * MV / MM %*% P
VP <- tcrossprod(V, P)
PP <- tcrossprod(P, P)
M <- M * VP / tcrossprod(M, PP)
# Reconstruction
Vr <- M %*% P
# Reconstruction quality (coefficient of determination)
R2 <- 1 - (sum((V - Vr)^2) / sum((V - mean(V))^2))
R2 <- 1 - sum((V - Vr)^2) / sum((V - mean(V))^2)

# l2-norm normalisation which eliminates trivial scale indeterminacies
# l2-norm normalisation eliminates trivial scale indeterminacies
# See Févotte, C., Idier, J. (2011)
l2_norms <- apply(M, 2, function(nn) sqrt(sum(nn^2)))
M <- sweep(M, 2, l2_norms, FUN = "/")
P <- sweep(P, 1, l2_norms, FUN = "*")
# Calculate l2-norm of the columns of M and normalise M and P
l2_norms <- sqrt(colSums(M^2))
M <- M / matrix(l2_norms, nrow = nrow(M), ncol = ncol(M), byrow = TRUE)
P <- P * l2_norms

# Start iterations for NMF convergence
for (iter in 2:max_iter) {
P <- P * crossprod(M, V) / crossprod((crossprod(M, M)), P)
M <- M * tcrossprod(V, P) / tcrossprod(M, tcrossprod(P, P))
iter <- 1
while (iter < max_iter) {
iter <- iter + 1

# Updates
MV <- crossprod(M, V)
MM <- crossprod(M, M)
P <- P * MV / MM %*% P
VP <- tcrossprod(V, P)
PP <- tcrossprod(P, P)
M <- M * VP / tcrossprod(M, PP)

# Reconstruction
Vr <- M %*% P
R2[iter] <- 1 - (sum((V - Vr)^2) / sum((V - mean(V))^2))
# Reconstruction quality
R2[iter] <- 1 - sum((V - Vr)^2) / sum((V - mean(V))^2)

# l2-norm normalisation
l2_norms <- apply(M, 2, function(nn) sqrt(sum(nn^2)))
M <- sweep(M, 2, l2_norms, FUN = "/")
P <- sweep(P, 1, l2_norms, FUN = "*")
l2_norms <- sqrt(colSums(M^2))
M <- M / matrix(l2_norms, nrow = nrow(M), ncol = ncol(M), byrow = TRUE)
P <- P * l2_norms

# Check if the increase of R2 in the last "last_iter" iterations
# is less than the target
if (iter > last_iter &&
R2[iter] - R2[iter - last_iter] < R2[iter] * R2_target / 100) {
R2[iter] - R2[iter - last_iter] < R2[iter] * R2_target / 100) {
break
}
}
R2_choice[run] <- R2[iter]

M_temp[[run]] <- M
P_temp[[run]] <- P
Vr_temp[[run]] <- Vr
iters[run] <- iter
}

choice <- which.max(R2_choice)
Expand All @@ -146,7 +163,7 @@ synsNMF <- function(V,
M_list[[syn_index]] <- M_temp[[choice]]
P_list[[syn_index]] <- P_temp[[choice]]
Vr_list[[syn_index]] <- Vr_temp[[choice]]
iters[syn_index] <- iter
iterations[syn_index] <- iters[choice]
}

if (is.na(fixed_syns)) {
Expand All @@ -155,16 +172,15 @@ synsNMF <- function(V,
iter <- 0 # Initialise iterations
while (MSE > MSE_min) {
iter <- iter + 1
if (iter == r - 1) {
if (iter == max_syns - 1) {
break
}
R2_interp <- data.frame(
synergies = 1:(r - iter + 1),
R2_values = R2_cross[iter:r]
)

lin <- stats::lm(R2_values ~ synergies, R2_interp)$fitted.values
MSE <- sum((lin - R2_interp$R2_values)^2) / nrow(R2_interp)
synergies <- 1:(max_syns - iter + 1)
R2_values <- R2_cross[iter:max_syns]

lin <- stats::lm(R2_values ~ synergies)$fitted.values
MSE <- sum((lin - R2_values)^2) / length(R2_values)
}
syns_R2 <- iter

Expand All @@ -174,7 +190,7 @@ synsNMF <- function(V,
colnames(M_list[[syns_R2]]) <- paste0("Syn", seq_len(ncol(M_list[[syns_R2]])))
M_choice <- M_list[[syns_R2]]
Vr_choice <- Vr_list[[syns_R2]]
iters <- as.numeric(iters[syns_R2])
iterations <- iterations[syns_R2]
} else if (is.numeric(fixed_syns)) {
syns_R2 <- fixed_syns

Expand All @@ -184,7 +200,7 @@ synsNMF <- function(V,
colnames(M_list[[1]]) <- paste0("Syn", seq_len(ncol(M_list[[1]])))
M_choice <- M_list[[1]]
Vr_choice <- Vr_list[[1]]
iters <- as.numeric(iters[1])
iterations <- iterations[1]
}

SYNS <- list(
Expand All @@ -193,7 +209,7 @@ synsNMF <- function(V,
P = P_choice,
V = V,
Vr = Vr_choice,
iterations = iters,
iterations = iterations,
R2 = data.frame(
synergies = min_syns:max_syns,
R2 = R2_cross
Expand Down

0 comments on commit 318b5aa

Please sign in to comment.