Commit
Moved get_cluster_fill_rates() and get_writer_profiles() and other to handwriter

Moved `get_cluster_fill_rates()`, `get_writer_profiles()` and `plot_writer_profiles()` to handwriter because they fit better in that package.

`get_writer_profiles()` and `plot_writer_profiles()` were added since the last version, so these functions were simply removed from handwriterRF and not deprecated.
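For callers, the practical effect is a change of namespace. The following is a minimal sketch of an updated call site, mirroring the named arguments that `compare_documents()` passes in this commit. The wrapper name `estimate_profiles()` is hypothetical and not part of either package, and the sketch assumes that handwriter (>= 3.2.3) and handwriterRF are installed.

```r
# Hypothetical wrapper (not part of handwriter or handwriterRF) that shows
# the new call site after the move.
estimate_profiles <- function(input_dir, output_dir) {
  if (!requireNamespace("handwriter", quietly = TRUE)) {
    stop("Package 'handwriter' (>= 3.2.3) is required.")
  }
  # Same named arguments that compare_documents() uses in this commit.
  handwriter::get_writer_profiles(
    input_dir = input_dir,
    measure = "rates",
    template = handwriterRF::templateK40,
    num_cores = 1,
    output_dir = output_dir
  )
}
```

Calling the wrapper processes every document in `input_dir`, which the removed documentation describes as taking upwards of 30 seconds per document, so it is defined here but not run.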

`get_cluster_fill_rates()` was included in the last version, so it was deprecated.
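The deprecation follows the usual lifecycle pattern: the old function keeps working but signals a warning that points to its replacement. The sketch below illustrates that pattern on a toy function and is not the package's implementation. The names `old_fill_rates()`, `mypkg` and `newpkg` are placeholders, and the row-proportion body is only an assumption about what "fill rates" means, based on the removed documentation describing a profile as the proportion of graphs assigned to each cluster.

```r
# Hypothetical deprecated function; 'mypkg' and 'newpkg' are placeholder names.
old_fill_rates <- function(counts) {
  # Signal the deprecation, naming the version and the replacement.
  lifecycle::deprecate_warn(
    when = "1.0.3",
    what = "mypkg::old_fill_rates()",
    with = "newpkg::fill_rates()"
  )
  # The legacy behavior is kept: divide each row of counts by its row total.
  counts / rowSums(counts)
}

# Two documents (rows) and three clusters (columns) of made-up counts.
counts <- matrix(c(2, 2, 4, 1, 3, 0), nrow = 2, byrow = TRUE)
rates <- suppressWarnings(old_fill_rates(counts))
```

Keeping the old body in place, rather than removing the function outright, gives users of the previously released version time to switch before the function is dropped.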
stephaniereinders committed Dec 12, 2024
1 parent b4cbf1d commit 7ee55cd
Showing 23 changed files with 166 additions and 253 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
@@ -21,6 +21,7 @@ Depends:
 Imports:
     dplyr,
     handwriter (>= 3.2.3),
+    lifecycle,
     magrittr,
     purrr,
     ranger,
@@ -32,3 +33,4 @@ Config/testthat/edition: 3
 URL: https://github.com/CSAFE-ISU/handwriterRF
 BugReports: https://github.com/CSAFE-ISU/handwriterRF/issues
 VignetteBuilder: knitr
+Roxygen: list(markdown = TRUE)
3 changes: 1 addition & 2 deletions NAMESPACE
@@ -7,9 +7,8 @@ export(compare_writer_profiles)
 export(get_cluster_fill_rates)
 export(get_distances)
 export(get_ref_scores)
-export(get_writer_profiles)
 export(interpret_slr)
 export(plot_scores)
-export(plot_writer_profiles)
 export(train_rf)
+importFrom(lifecycle,deprecated)
 importFrom(magrittr,"%>%")
4 changes: 3 additions & 1 deletion NEWS.md
@@ -4,7 +4,7 @@
 
 * Created `compare_documents()` to compare two handwritten documents using either a similarity score or a score-based likelihood ratio as a comparison method.
 
-* Created functions `get_writer_profiles()` and `compare_writer_profiles()` to make experiments faster on large numbers of documents compared to `compare_documents()`. `get_writer_profiles()` estimates writer profiles for every handwritten document in a folder. Then `compare_writer_profiles()` calculates either a similarity score or score-based likelihood ratio for every pair of documents.
+* Created function `compare_writer_profiles()` to make experiments faster on large numbers of documents compared to `compare_documents()`. `compare_writer_profiles()` calculates either a similarity score or score-based likelihood ratio for every pair of documents.
 
 * Created new data frames of writer profiles `train`, `validation`, and `test`. Created a new `random_forest` from `train`. Created `ref_scores`, a list of same writer and different writer similarity scores, from `validation`.
 
@@ -14,6 +14,8 @@
 
 * Created `plot_scores()` to plot histograms of the reference same writer and different writer similarity scores in `random_forest$scores`.
 
+* Deprecated `get_cluster_fill_rates()` in favor of `handwriter::get_cluster_fill_rates()`.
+
 # handwriterRF 1.0.2
 
 * Removed quotes around "same writer" and "different writer" in documentation.
6 changes: 4 additions & 2 deletions R/clusters.R
@@ -20,8 +20,8 @@
 
 #' Get Cluster Fill Rates
 #'
-#' Calculate cluster fill rates from a data frame of cluster fill counts created
-#' with \code{\link[handwriter]{get_cluster_fill_counts}}.
+#' `r lifecycle::badge('deprecated')`
+#' `get_cluster_fill_rates` is deprecated. Use \code{\link[handwriter]{get_cluster_fill_rates}} instead.
 #'
 #' @param df A data frame of cluster fill rates created with
 #' \code{\link[handwriter]{get_cluster_fill_counts}}.
@@ -35,6 +35,8 @@
 #'
 #' @md
 get_cluster_fill_rates <- function(df) {
+  lifecycle::deprecate_warn("1.0.3", "get_cluster_fill_rates()", "handwriter::get_cluster_fill_rates()")
+
   # get label columns. docname is required for input data frames but writer and
   # doc are optional.
   label_cols <- df %>%
79 changes: 6 additions & 73 deletions R/compare.R
@@ -107,10 +107,12 @@ compare_documents <- function(sample1,
   params <- copy_samples_to_project_dir(params)
 
   message("Estimating writer profiles...")
-  profiles <- get_writer_profiles(input_dir = file.path(params$project_dir, "docs"),
-                                  template = templateK40,
-                                  num_cores = 1,
-                                  output_dir = params$project_dir)
+  profiles <- handwriter::get_writer_profiles(
+    input_dir = file.path(params$project_dir, "docs"),
+    measure = "rates",
+    template = templateK40,
+    num_cores = 1,
+    output_dir = params$project_dir)
 
   message("Calculating distance between samples...")
   dist_measures <- which_dists(rforest = params$rforest)
@@ -199,75 +201,6 @@ compare_writer_profiles <- function(
   return(df)
 }
 
-
-#' Estimate Writer Profiles
-#'
-#' Estimate writer profiles from handwritten documents scanned and saved as PNG
-#' files. Each file in `input_dir` is split into component shapes called graphs
-#' with [`handwriter::process_batch_dir`]. Then the graphs are sorted into
-#' clusters with similar shapes using the cluster `template` and
-#' [`handwriter::get_clusters_batch`]. An estimate of the writer profile for a
-#' document is the proportion of graphs from that document assigned to each of
-#' the clusters in `template`. The writer profiles are estimated by running
-#' [`handwriter::get_cluster_fill_counts`] and then [`get_cluster_fill_rates`].
-#'
-#' The functions [`handwriter::process_batch_dir`] and
-#' [`handwriter::get_clusters_batch`] take upwards of 30 seconds per document
-#' and the results are saved to RDS files in `project_dir` > graphs and
-#' `project_dir` > clusters, respectively.
-#'
-#' @param input_dir A filepath to a folder containing one or more handwritten
-#' documents, scanned and saved as PNG file(s).
-#' @param num_cores An integer number greater than or equal to 1 of cores to use
-#' for parallel processing.
-#' @param template Optional. A cluster template created with
-#' [`handwriter::make_clustering_template`]. The default is the cluster
-#' template `templateK40` included with 'handwriterRF'.
-#' @param output_dir Optional. A filepath to a folder to save the RDS files
-#' created by [`handwriter::process_batch_dir`] and
-#' [`handwriter::get_clusters_batch`]. If no folder is supplied, the RDS files
-#' will be saved to the temporary directory and then deleted before the
-#' function terminates.
-#'
-#' @return A data frame
-#' @export
-#'
-#' @examples
-#' \donttest{
-#' docs <- system.file(file.path("extdata", "docs"), package = "handwriterRF")
-#' profiles <- get_writer_profiles(docs)
-#'
-#' plot_writer_profiles(profiles)
-#' }
-#'
-get_writer_profiles <- function(input_dir, num_cores = 1, template = templateK40, output_dir = NULL) {
-  if (is.null(output_dir)) {
-    output_dir <- file.path(tempdir(), "writer_profiles")
-    create_dir(output_dir)
-  }
-
-  handwriter::process_batch_dir(
-    input_dir = input_dir,
-    output_dir = file.path(output_dir, "graphs")
-  )
-
-  clusters <- handwriter::get_clusters_batch(
-    template = template,
-    input_dir = file.path(output_dir, "graphs"),
-    output_dir = file.path(output_dir, "clusters"),
-    num_cores = num_cores,
-    save_master_file = FALSE
-  )
-  counts <- handwriter::get_cluster_fill_counts(clusters)
-  profiles <- get_cluster_fill_rates(counts)
-
-  if (output_dir == file.path(tempdir(), "writer_profiles")) {
-    unlink(file.path(tempdir(), "writer_profiles"), recursive = TRUE)
-  }
-
-  return(profiles)
-}
-
 # Internal Functions ------------------------------------------------------
 
 handle_null_values <- function(params) {
3 changes: 2 additions & 1 deletion R/data.R
@@ -184,7 +184,8 @@
 #' was assigned on each iteration. The output of \code{\link[handwriter]{make_clustering_template}} stores
 #' the within cluster distances on each iteration, but the previous iterations were removed here to reduce the file size.}
 #' \item{wcss}{A vector of the
-#' within-cluster sum of squares on each iteration of the K-means algorithm.}}
+#' within-cluster sum of squares on each iteration of the K-means algorithm.}
+#' }
 #' @examples
 #' # view number of clusters
 #' templateK40$K
7 changes: 7 additions & 0 deletions R/handwriterRF-package.R
@@ -0,0 +1,7 @@
+#' @keywords internal
+"_PACKAGE"
+
+## usethis namespace: start
+#' @importFrom lifecycle deprecated
+## usethis namespace: end
+NULL
61 changes: 0 additions & 61 deletions R/plots.R
@@ -120,64 +120,3 @@ plot_scores <- function(scores, obs_score = NULL, n_bins = 50) {
 
   return(p)
 }
-
-#' Plot Writer Profiles
-#'
-#' Create a line plot of cluster fill rates for one or more documents, where the
-#' cluster fill rates serve as writer profiles. Each cluster fill rates for each
-#' document are plotted as different colored lines.
-#'
-#' @param rates A data frame of cluster fill rates created with
-#' \code{\link[handwriterRF]{get_cluster_fill_rates}}
-#' @param color_by A column name. 'ggplot2' will always group by docname, but
-#' will use this column to assign colors.
-#' @param ... Additional arguments passed to `ggplot2::facet_wrap`, such as
-#' `facets`, `nrow`, etc.
-#'
-#' @return A line plot
-#'
-#' @export
-#'
-#' @examples
-#' plot_writer_profiles(rates = test[1:4, ])
-#'
-#' plot_writer_profiles(rates = test[1:4, ], facets = "writer")
-#'
-#' plot_writer_profiles(rates = test[1:4, ], facets = "writer~docname")
-#'
-#' @md
-plot_writer_profiles <- function(rates, color_by = "docname", ...) {
-  # prevent note: "no visible binding for global variable"
-  docname <- cluster <- rate <- .data <- NULL
-
-  rates <- rates %>%
-    tidyr::pivot_longer(
-      cols = tidyselect::starts_with("cluster"),
-      names_to = "cluster",
-      values_to = "rate"
-    ) %>%
-    dplyr::mutate(
-      docname = factor(docname),
-      cluster = as.integer(stringr::str_replace(cluster, "cluster", ""))
-    )
-
-  p <- rates %>%
-    ggplot2::ggplot(ggplot2::aes(
-      x = cluster,
-      y = rate,
-      group = docname,
-      color = .data[[color_by]]
-    )) +
-    ggplot2::geom_line() +
-    ggplot2::geom_point() +
-    ggplot2::theme_bw()
-
-  # optional. facet by writer or docname
-  extra_args <- list(...)
-  if (length(extra_args) > 0) {
-    p <- p +
-      ggplot2::facet_wrap(...)
-  }
-
-  return(p)
-}
2 changes: 1 addition & 1 deletion data-raw/make_cluster_fill_rates.R
@@ -101,7 +101,7 @@ all_clusters <- rbind(LND, WOZ, PHR)
 
 cfc <- get_combined_cfc(all_clusters)
 
-cfr <- get_cluster_fill_rates(cfc = cfc)
+cfr <- handwriter::get_cluster_fill_rates(cfc = cfc)
 
 usethis::use_data(cfc, overwrite = TRUE)
 usethis::use_data(cfr, overwrite = TRUE)
3 changes: 1 addition & 2 deletions data-raw/train_valid_test.R
@@ -10,8 +10,7 @@ load_cluster_fill_rates <- function(clusters_dir) {
   files <- list.files(clusters_dir, full.names = TRUE)
   dfs <- lapply(files, readRDS)
   clusters <- do.call(rbind, dfs)
-  counts <- handwriter::get_cluster_fill_counts(clusters)
-  rates <- get_cluster_fill_rates(counts)
+  rates <- handwriter::get_cluster_fill_rates(clusters)
   return(rates)
 }
 
21 changes: 21 additions & 0 deletions man/figures/lifecycle-deprecated.svg
21 changes: 21 additions & 0 deletions man/figures/lifecycle-experimental.svg
29 changes: 29 additions & 0 deletions man/figures/lifecycle-stable.svg
21 changes: 21 additions & 0 deletions man/figures/lifecycle-superseded.svg
4 changes: 2 additions & 2 deletions man/get_cluster_fill_rates.Rd