Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 701795491
  • Loading branch information
tf-text-github-robot committed Dec 2, 2024
1 parent 1cdc3eb commit bc16e6e
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 76 deletions.
8 changes: 4 additions & 4 deletions tensorflow_text/core/kernels/constrained_sequence_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ namespace {

// Validate that a given constraint tensor is the proper shape (dimension
// 2, with shape [num_states + 1, num_states + 1]).
tensorflow::Status ValidateConstraintTensor(const Tensor &tensor,
const int num_states,
const bool use_start_end_states,
const string &name) {
absl::Status ValidateConstraintTensor(const Tensor &tensor,
const int num_states,
const bool use_start_end_states,
const string &name) {
if (tensor.shape().dims() != 2) {
return InvalidArgument(
tensorflow::strings::StrCat(name, " must be of rank 2"));
Expand Down
11 changes: 5 additions & 6 deletions tensorflow_text/core/kernels/mst_op_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,15 @@ class MaxSpanningTreeOpKernel : public tensorflow::OpKernel {
// Solve the batch of MST problems in parallel. Set a high cycles per unit
// to encourage finer sharding.
constexpr int64 kCyclesPerUnit = 1000 * 1000 * 1000;
std::vector<tensorflow::Status> statuses(batch_size);
std::vector<absl::Status> statuses(batch_size);
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
batch_size, kCyclesPerUnit, [&](int64 begin, int64 end) {
for (int64 problem = begin; problem < end; ++problem) {
statuses[problem] = RunSolver(problem, num_nodes_b, scores_bxmxm,
max_scores_b, argmax_sources_bxm);
}
});
for (const tensorflow::Status &status : statuses) {
for (const absl::Status &status : statuses) {
OP_REQUIRES_OK(context, status);
}
}
Expand All @@ -112,10 +112,9 @@ class MaxSpanningTreeOpKernel : public tensorflow::OpKernel {
// at index |problem| in |num_nodes_b| and |scores_bxmxm|. On success, sets
// the values at index |problem| in |max_scores_b| and |argmax_sources_bxm|.
// On error, returns non-OK.
tensorflow::Status RunSolver(int problem, BatchedSizes num_nodes_b,
BatchedScores scores_bxmxm,
BatchedMaxima max_scores_b,
BatchedSources argmax_sources_bxm) const {
absl::Status RunSolver(int problem, BatchedSizes num_nodes_b,
BatchedScores scores_bxmxm, BatchedMaxima max_scores_b,
BatchedSources argmax_sources_bxm) const {
// Check digraph size overflow.
const int32 num_nodes = num_nodes_b(problem);
const int32 input_dim = argmax_sources_bxm.dimension(1);
Expand Down
19 changes: 9 additions & 10 deletions tensorflow_text/core/kernels/mst_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class MstSolver {
// error. Discards existing state; call AddArc() and AddRoot() to add arcs
// and root selections. If |forest| is true, then this solves for a maximum
// spanning forest (i.e., a set of disjoint trees that span the digraph).
tensorflow::Status Init(bool forest, Index num_nodes);
absl::Status Init(bool forest, Index num_nodes);

// Adds an arc from the |source| node to the |target| node with the |score|.
// The |source| and |target| must be distinct node indices in [0,n), and the
Expand All @@ -116,10 +116,10 @@ class MstSolver {
//
// NB: If multiple spanning trees achieve the maximum score, |argmax| will be
// set to one of the maximal trees, but it is unspecified which one.
tensorflow::Status Solve(absl::Span<Index> argmax);
absl::Status Solve(absl::Span<Index> argmax);

// Convenience overload taking a std::vector; forwards to the span version.
tensorflow::Status Solve(std::vector<Index>* argmax) {
absl::Status Solve(std::vector<Index> *argmax) {
return Solve(absl::MakeSpan(argmax->data(), argmax->size()));
}

Expand Down Expand Up @@ -235,12 +235,12 @@ class MstSolver {
// phase finds the best inbound arc for each node, contracting cycles as they
// are formed. Stops when every node has selected an inbound arc and there
// are no cycles.
tensorflow::Status ContractionPhase();
absl::Status ContractionPhase();

// Runs the expansion phase of the solver, or returns non-OK on error. This
// phase expands each contracted node, breaks cycles, and populates |argmax|
// with the maximum spanning tree.
tensorflow::Status ExpansionPhase(absl::Span<Index> argmax);
absl::Status ExpansionPhase(absl::Span<Index> argmax);

// If true, solve for a spanning forest instead of a spanning tree.
bool forest_ = false;
Expand Down Expand Up @@ -303,7 +303,7 @@ class MstSolver {
// Implementation details below.

template <class Index, class Score>
tensorflow::Status MstSolver<Index, Score>::Init(bool forest, Index num_nodes) {
absl::Status MstSolver<Index, Score>::Init(bool forest, Index num_nodes) {
if (num_nodes <= 0) {
return tensorflow::errors::InvalidArgument("Non-positive number of nodes: ",
num_nodes);
Expand Down Expand Up @@ -374,7 +374,7 @@ Score MstSolver<Index, Score>::RootScore(Index root) const {
}

template <class Index, class Score>
tensorflow::Status MstSolver<Index, Score>::Solve(absl::Span<Index> argmax) {
absl::Status MstSolver<Index, Score>::Solve(absl::Span<Index> argmax) {
MaybePenalizeRootScoresForTree();
TF_RETURN_IF_ERROR(ContractionPhase());
TF_RETURN_IF_ERROR(ExpansionPhase(argmax));
Expand Down Expand Up @@ -510,7 +510,7 @@ void MstSolver<Index, Score>::ContractCycle(Index node) {
}

template <class Index, class Score>
tensorflow::Status MstSolver<Index, Score>::ContractionPhase() {
absl::Status MstSolver<Index, Score>::ContractionPhase() {
// Skip the artificial root since it has no inbound arcs.
for (Index target = 1; target < num_current_nodes_; ++target) {
// Find the maximum inbound arc for the current |target|, if any.
Expand Down Expand Up @@ -541,8 +541,7 @@ tensorflow::Status MstSolver<Index, Score>::ContractionPhase() {
}

template <class Index, class Score>
tensorflow::Status MstSolver<Index, Score>::ExpansionPhase(
absl::Span<Index> argmax) {
absl::Status MstSolver<Index, Score>::ExpansionPhase(absl::Span<Index> argmax) {
if (argmax.size() < num_original_nodes_) {
return tensorflow::errors::InvalidArgument(
"Argmax array too small: ", num_original_nodes_,
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_text/core/kernels/sentence_breaking_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ struct ErrorOptions {
bool error_on_malformatting = false;
};

Status GetErrorOptions(OpKernelConstruction* context, ErrorOptions* out) {
absl::Status GetErrorOptions(OpKernelConstruction* context, ErrorOptions* out) {
*out = ErrorOptions();

string error_policy;
Expand Down
30 changes: 15 additions & 15 deletions tensorflow_text/core/kernels/sentence_breaking_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ using ::tensorflow::Status;
namespace tensorflow {
namespace text {

Status UnicodeUtil::GetOneUChar(const absl::string_view& input,
bool* has_more_than_one_char,
UChar32* result) const {
absl::Status UnicodeUtil::GetOneUChar(const absl::string_view& input,
bool* has_more_than_one_char,
UChar32* result) const {
UErrorCode status = U_ZERO_ERROR;
const char* source = input.data();
const char* limit = input.data() + input.length();
Expand All @@ -54,8 +54,8 @@ Status UnicodeUtil::GetOneUChar(const absl::string_view& input,
return absl::OkStatus();
}

Status UnicodeUtil::IsTerminalPunc(const absl::string_view& input,
bool* result) const {
absl::Status UnicodeUtil::IsTerminalPunc(const absl::string_view& input,
bool* result) const {
*result = false;
const auto& ellipsis_status = IsEllipsis(input, result);
// If there was an error decoding, or if we found an ellipsis, then return.
Expand Down Expand Up @@ -89,8 +89,8 @@ Status UnicodeUtil::IsTerminalPunc(const absl::string_view& input,
return absl::OkStatus();
}

Status UnicodeUtil::IsClosePunc(const absl::string_view& input,
bool* result) const {
absl::Status UnicodeUtil::IsClosePunc(const absl::string_view& input,
bool* result) const {
*result = false;
if (input == "''") {
*result = true;
Expand Down Expand Up @@ -128,8 +128,8 @@ Status UnicodeUtil::IsClosePunc(const absl::string_view& input,
return absl::OkStatus();
}

Status UnicodeUtil::IsOpenParen(const absl::string_view& input,
bool* result) const {
absl::Status UnicodeUtil::IsOpenParen(const absl::string_view& input,
bool* result) const {
*result = false;
bool has_more_than_one_char = false;
UChar32 char_value;
Expand All @@ -155,8 +155,8 @@ Status UnicodeUtil::IsOpenParen(const absl::string_view& input,
return absl::OkStatus();
}

Status UnicodeUtil::IsCloseParen(const absl::string_view& input,
bool* result) const {
absl::Status UnicodeUtil::IsCloseParen(const absl::string_view& input,
bool* result) const {
*result = false;
bool has_more_than_one_char = false;
UChar32 char_value;
Expand All @@ -183,8 +183,8 @@ Status UnicodeUtil::IsCloseParen(const absl::string_view& input,
return absl::OkStatus();
}

Status UnicodeUtil::IsPunctuationWord(const absl::string_view& input,
bool* result) const {
absl::Status UnicodeUtil::IsPunctuationWord(const absl::string_view& input,
bool* result) const {
*result = false;
bool has_more_than_one_char = false;
UChar32 char_value;
Expand Down Expand Up @@ -213,8 +213,8 @@ Status UnicodeUtil::IsPunctuationWord(const absl::string_view& input,
return absl::OkStatus();
}

Status UnicodeUtil::IsEllipsis(const absl::string_view& input,
bool* result) const {
absl::Status UnicodeUtil::IsEllipsis(const absl::string_view& input,
bool* result) const {
*result = false;
if (input == "...") {
*result = true;
Expand Down
25 changes: 10 additions & 15 deletions tensorflow_text/core/kernels/sentence_breaking_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,34 +33,29 @@ class UnicodeUtil {
explicit UnicodeUtil(UConverter* converter) : converter_(converter) {}

// Returns true iff a string is terminal punctuation.
::tensorflow::Status IsTerminalPunc(const absl::string_view& input,
bool* result) const;
absl::Status IsTerminalPunc(const absl::string_view& input,
bool* result) const;

// Returns true iff a string is close punctuation (close quote or close
// paren).
::tensorflow::Status IsClosePunc(const absl::string_view& input,
bool* result) const;
absl::Status IsClosePunc(const absl::string_view& input, bool* result) const;

// Returns true iff a string is an open paren.
::tensorflow::Status IsOpenParen(const absl::string_view& input,
bool* result) const;
absl::Status IsOpenParen(const absl::string_view& input, bool* result) const;

// Returns true iff a string is a close paren.
::tensorflow::Status IsCloseParen(const absl::string_view& input,
bool* result) const;
absl::Status IsCloseParen(const absl::string_view& input, bool* result) const;

// Returns true iff a word is made of punctuation characters only.
::tensorflow::Status IsPunctuationWord(const absl::string_view& input,
bool* result) const;
absl::Status IsPunctuationWord(const absl::string_view& input,
bool* result) const;

// Returns true iff a string is an ellipsis token ("...").
::tensorflow::Status IsEllipsis(const absl::string_view& input,
bool* result) const;
absl::Status IsEllipsis(const absl::string_view& input, bool* result) const;

private:
::tensorflow::Status GetOneUChar(const absl::string_view&,
bool* has_more_than_one_char,
UChar32* result) const;
absl::Status GetOneUChar(const absl::string_view&,
bool* has_more_than_one_char, UChar32* result) const;

// not owned. mutable because UConverter contains some internal options and
// buffer.
Expand Down
26 changes: 13 additions & 13 deletions tensorflow_text/core/kernels/sentence_fragmenter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ bool IsPeriodSeparatedAcronym(const Token &token) {

// Returns true iff the token can appear after a space in a sentence-terminal
// token sequence.
Status SpaceAllowedBeforeToken(const UnicodeUtil *util, const Token &token,
bool *result) {
absl::Status SpaceAllowedBeforeToken(const UnicodeUtil *util,
const Token &token, bool *result) {
const tstring &word = token.word();
bool is_ellipsis = false;
TF_RETURN_IF_ERROR(util->IsEllipsis(word, &is_ellipsis));
Expand Down Expand Up @@ -77,8 +77,8 @@ class SentenceFragmenter::FragmentBoundaryMatch {

// Follows the state transition for the token at the given index. Returns
// true for success, or false if there was no valid transition.
Status Advance(const UnicodeUtil *util, const Document &document, int index,
bool *result) {
absl::Status Advance(const UnicodeUtil *util, const Document &document,
int index, bool *result) {
const Token &token = document.tokens()[index];
const tstring &word = token.word();
bool no_transition = false;
Expand Down Expand Up @@ -176,7 +176,7 @@ class SentenceFragmenter::FragmentBoundaryMatch {
int limit_index_ = -1;
};

Status SentenceFragmenter::FindFragments(
absl::Status SentenceFragmenter::FindFragments(
std::vector<SentenceFragment> *result) {
// Partition tokens into sentence fragments.
for (int i_start = 0; i_start < document_->tokens().size();) {
Expand Down Expand Up @@ -215,7 +215,7 @@ Status SentenceFragmenter::FindFragments(
// scan "!!!" looking for a fragment boundary. Since we failed to find one last
// time, we'll fail again this time and therefore continue past "y" to find the
// next boundary. We will not try to scan "!!!" a third time.
Status SentenceFragmenter::FindNextFragmentBoundary(
absl::Status SentenceFragmenter::FindNextFragmentBoundary(
int i_start, SentenceFragmenter::FragmentBoundaryMatch *result) const {
FragmentBoundaryMatch current_match;
FragmentBoundaryMatch previous_match;
Expand Down Expand Up @@ -276,8 +276,8 @@ Status SentenceFragmenter::FindNextFragmentBoundary(
// punctuation that turns out not to be a sentence boundary, e.g.,
// "Yahoo! (known for search, etc.) blah", but this is not expected to happen
// often.
Status SentenceFragmenter::UpdateLatestOpenParenForFragment(int i_start,
int i_end) {
absl::Status SentenceFragmenter::UpdateLatestOpenParenForFragment(int i_start,
int i_end) {
for (int i = i_end; i > i_start; --i) {
const auto &token = document_->tokens()[i - 1];
bool is_open_paren = false;
Expand All @@ -293,7 +293,7 @@ Status SentenceFragmenter::UpdateLatestOpenParenForFragment(int i_start,
return absl::OkStatus();
}

Status SentenceFragmenter::FillInFragmentFields(
absl::Status SentenceFragmenter::FillInFragmentFields(
int i_start, const FragmentBoundaryMatch &match,
SentenceFragment *fragment) const {
// Set the fragment's boundaries.
Expand Down Expand Up @@ -343,7 +343,7 @@ Status SentenceFragmenter::FillInFragmentFields(
//
// We treat "!" as the first terminal punctuation mark; the ellipsis acts as
// left context.
Status SentenceFragmenter::GetAdjustedFirstTerminalPuncIndex(
absl::Status SentenceFragmenter::GetAdjustedFirstTerminalPuncIndex(
const FragmentBoundaryMatch &match, int *result) const {
// Get terminal punctuation span.
int i1 = match.first_terminal_punc_index();
Expand Down Expand Up @@ -385,7 +385,7 @@ Status SentenceFragmenter::GetAdjustedFirstTerminalPuncIndex(
// true sentence boundary. The terminal punctuation mark must be unambiguous
// (.!?), as ambiguous ones (ellipsis/emoticon) do not necessarily imply a
// sentence boundary.
Status SentenceFragmenter::HasUnattachableTerminalPunc(
absl::Status SentenceFragmenter::HasUnattachableTerminalPunc(
const FragmentBoundaryMatch &match, bool *result) const {
*result = false;
// Get terminal punctuation span.
Expand Down Expand Up @@ -415,8 +415,8 @@ Status SentenceFragmenter::HasUnattachableTerminalPunc(
return absl::OkStatus();
}

Status SentenceFragmenter::HasCloseParen(const FragmentBoundaryMatch &match,
bool *result) const {
absl::Status SentenceFragmenter::HasCloseParen(
const FragmentBoundaryMatch &match, bool *result) const {
*result = false;
// Get close punctuation span.
int i1 = match.first_close_punc_index();
Expand Down
24 changes: 12 additions & 12 deletions tensorflow_text/core/kernels/sentence_fragmenter.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class SentenceFragmenter {

// Finds sentence fragments in the [start_, limit_) range of the associated
// document.
::tensorflow::Status FindFragments(std::vector<SentenceFragment> *result);
absl::Status FindFragments(std::vector<SentenceFragment> *result);

private:
// State for matching a fragment-boundary regexp against a token sequence.
Expand All @@ -180,33 +180,33 @@ class SentenceFragmenter {
// Matches a fragment-boundary regexp against the tokens starting at
// 'i_start'. Returns the longest match found; will be non-empty as long as
// 'i_start' was not already at the end of the associated token range.
::tensorflow::Status FindNextFragmentBoundary(
int i_start, FragmentBoundaryMatch *result) const;
absl::Status FindNextFragmentBoundary(int i_start,
FragmentBoundaryMatch *result) const;

// Updates 'latest_open_paren_is_sentential_' for the tokens in the given
// fragment.
::tensorflow::Status UpdateLatestOpenParenForFragment(int i_start, int i_end);
absl::Status UpdateLatestOpenParenForFragment(int i_start, int i_end);

// Populates a sentence fragment with the tokens from 'i_start' to the end
// of the given FragmentBoundaryMatch.
::tensorflow::Status FillInFragmentFields(int i_start,
const FragmentBoundaryMatch &match,
SentenceFragment *fragment) const;
absl::Status FillInFragmentFields(int i_start,
const FragmentBoundaryMatch &match,
SentenceFragment *fragment) const;

// Returns the adjusted first terminal punctuation index in a
// FragmentBoundaryMatch.
::tensorflow::Status GetAdjustedFirstTerminalPuncIndex(
absl::Status GetAdjustedFirstTerminalPuncIndex(
const FragmentBoundaryMatch &match, int *result) const;

// Returns true iff a FragmentBoundaryMatch has an "unattachable" terminal
// punctuation mark.
::tensorflow::Status HasUnattachableTerminalPunc(
const FragmentBoundaryMatch &match, bool *result) const;
absl::Status HasUnattachableTerminalPunc(const FragmentBoundaryMatch &match,
bool *result) const;

// Returns true iff a FragmentBoundaryMatch has a close paren in its closing
// punctuation.
::tensorflow::Status HasCloseParen(const FragmentBoundaryMatch &match,
bool *result) const;
absl::Status HasCloseParen(const FragmentBoundaryMatch &match,
bool *result) const;

// Whether the latest open paren seen so far appears to be sentence-initial.
// See UpdateLatestOpenParenForFragment() in the .cc file for details.
Expand Down

0 comments on commit bc16e6e

Please sign in to comment.