comments | difficulty | edit_url | tags | ||||
---|---|---|---|---|---|---|---|
true |
Hard |
|
You are given a string `num`. A string of digits is called balanced if the sum of the digits at even indices is equal to the sum of the digits at odd indices.
Return the number of distinct permutations of `num` that are balanced.
Since the answer may be very large, return it modulo 10^9 + 7.
A permutation is a rearrangement of all the characters of a string.
Example 1:
Input: num = "123"
Output: 2
Explanation:
- The distinct permutations of `num` are "123", "132", "213", "231", "312" and "321".
- Among them, "132" and "231" are balanced. Thus, the answer is 2.
Example 2:
Input: num = "112"
Output: 1
Explanation:
- The distinct permutations of `num` are "112", "121", and "211".
- Only "121" is balanced. Thus, the answer is 1.
Example 3:
Input: num = "12345"
Output: 0
Explanation:
- None of the permutations of `num` are balanced, so the answer is 0.
Constraints:
2 <= num.length <= 80
`num` consists of digits '0' to '9' only.
First, we count the occurrences of each digit in the string `num` and record them in the array `cnt`, and we compute the sum `s` of all digits of `num`.
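If `s` is odd, the digits cannot be split into two groups with equal sums, so we return 0 immediately.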
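Next, we define a memoization search function `dfs(i, j, a, b)`, where `i` is the digit value currently being placed, `j` is the digit sum that still has to be placed at odd positions, and `a` and `b` are the numbers of odd and even positions that are still empty. With `n` the length of `num`, the answer is `dfs(0, s / 2, n / 2, (n + 1) / 2)`.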
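In the function `dfs(i, j, a, b)`, we first check whether all digit values have been processed, i.e. `i > 9`. If so, the arrangement is balanced exactly when `j = 0`, `a = 0` and `b = 0`, in which case we return 1; otherwise we return 0.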
Next, we check if the remaining number of digits to be filled in odd positions `a` is 0 while `j > 0`. If so, the remaining sum can no longer be reached, so we return 0 early.
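Otherwise, we can enumerate the number of current digits assigned to odd positions `l`; the remaining `r = cnt[i] - l` copies go to even positions, subject to `0 <= r <= b` and `l * i <= j`. Each choice contributes `C(a, l) * C(b, r) * dfs(i + 1, j - l * i, a - l, b - r)` arrangements, and the result is the sum over all valid `l`, taken modulo 10^9 + 7.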
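The time complexity is O(|Σ| × n² × (n + |Σ|)), where n is the length of `num` and |Σ| = 10 is the number of distinct digit values.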
class Solution:
    """Counts balanced digit permutations with a digit-by-digit memoized search."""

    def countBalancedPermutations(self, num: str) -> int:
        """Return the number of distinct balanced permutations of ``num``, modulo 10**9 + 7.

        A permutation is balanced when the digits at even indices and the
        digits at odd indices have the same sum.

        Args:
            num: String whose characters are converted with ``int``.

        Returns:
            The count of distinct balanced permutations modulo 10**9 + 7;
            0 when the digit sum is odd.
        """
        # Function-scope imports so the snippet does not depend on names
        # being pre-imported by the surrounding environment.
        from collections import Counter
        from functools import cache
        from math import comb

        mod = 10**9 + 7
        digits = list(map(int, num))
        total = sum(digits)
        if total % 2:
            # An odd total cannot be split into two equal halves.
            return 0
        n = len(digits)
        cnt = Counter(digits)

        @cache
        def dfs(i: int, j: int, a: int, b: int) -> int:
            """Ways to place digit values i..9.

            j is the sum still required from the first position group,
            a is the number of free positions left in that group and
            b is the number of free positions left in the other group.
            """
            if i > 9:
                # Every digit value handled: valid only if nothing is left over.
                return int(j == 0 and a == 0 and b == 0)
            if a == 0 and j:
                # No free position remains to supply the outstanding sum.
                return 0
            ans = 0
            for l in range(min(cnt[i], a) + 1):
                # l copies of digit i go to the first group, r to the other.
                r = cnt[i] - l
                if r <= b and l * i <= j:
                    t = comb(a, l) * comb(b, r) * dfs(i + 1, j - l * i, a - l, b - r)
                    ans = (ans + t) % mod
            return ans

        return dfs(0, total // 2, n // 2, (n + 1) // 2)
/**
 * Counts the distinct permutations of a digit string whose even-index digit
 * sum equals its odd-index digit sum, modulo 1_000_000_007.
 *
 * <p>All per-call state is rebuilt inside {@link #countBalancedPermutations},
 * so one instance can be used for several inputs one after another.
 */
class Solution {
    private static final int MOD = 1_000_000_007;

    /** cnt[d] is the number of occurrences of digit d in the current input. */
    private int[] cnt;
    /** memo[i][j][a] caches dfs(i, j, a, b); b is implied by i and a (see dfs). */
    private Integer[][][] memo;
    /** binom[x][y] is C(x, y) mod MOD. */
    private long[][] binom;

    /**
     * @param num string of characters '0'..'9'
     * @return number of distinct balanced permutations of {@code num} modulo 1_000_000_007
     */
    public int countBalancedPermutations(String num) {
        // Fresh counters on every call; a shared array would accumulate
        // digits from earlier calls on the same instance.
        cnt = new int[10];
        int s = 0;
        for (char ch : num.toCharArray()) {
            cnt[ch - '0']++;
            s += ch - '0';
        }
        if (s % 2 == 1) {
            return 0;
        }
        int n = num.length();
        int oddSlots = n / 2;
        int evenSlots = n - oddSlots;
        memo = new Integer[10][s / 2 + 1][oddSlots + 1];
        // evenSlots >= oddSlots, so one table of this size serves both groups.
        binom = new long[evenSlots + 1][evenSlots + 1];
        binom[0][0] = 1;
        for (int i = 1; i <= evenSlots; i++) {
            binom[i][0] = 1;
            for (int j = 1; j <= i; j++) {
                binom[i][j] = (binom[i - 1][j] + binom[i - 1][j - 1]) % MOD;
            }
        }
        return dfs(0, s / 2, oddSlots, evenSlots);
    }

    /**
     * Ways to place digit values i..9 when the odd-index group still needs sum
     * j and has a free positions, and the even-index group has b free positions.
     *
     * <p>a + b always equals the number of digits with value >= i, so for a
     * fixed input b is determined by (i, a) and is left out of the memo key.
     */
    private int dfs(int i, int j, int a, int b) {
        if (i > 9) {
            return (j == 0 && a == 0 && b == 0) ? 1 : 0;
        }
        if (a == 0 && j != 0) {
            // No odd-index position left to supply the outstanding sum.
            return 0;
        }
        if (memo[i][j][a] != null) {
            return memo[i][j][a];
        }
        int ans = 0;
        int maxL = Math.min(cnt[i], a);
        for (int l = 0; l <= maxL; ++l) {
            // l copies of digit i go to odd indices, r copies to even indices.
            int r = cnt[i] - l;
            if (r <= b && l * i <= j) {
                long ways = binom[a][l] * binom[b][r] % MOD;
                int t = (int) (ways * dfs(i + 1, j - l * i, a - l, b - r) % MOD);
                ans = (ans + t) % MOD;
            }
        }
        return memo[i][j][a] = ans;
    }
}
using ll = long long;

// Exclusive upper bound on the indices of the binomial table.
const int MX = 80;
// Modulus the answer is reported under.
const int MOD = 1e9 + 7;
// c[i][j] == C(i, j) mod MOD for 0 <= j <= i < MX; other cells stay 0.
ll c[MX][MX];

// Fills the binomial table once, during static initialization.
auto init = [] {
    c[0][0] = 1;
    for (int i = 1; i < MX; ++i) {
        c[i][0] = 1;
        for (int j = 1; j <= i; ++j) {
            c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % MOD;
        }
    }
    return 0;
}();

class Solution {
public:
    // Returns the number of distinct permutations of `num` whose even-index
    // digit sum equals the odd-index digit sum, modulo MOD.
    // `num` is expected to contain only the characters '0'..'9'.
    int countBalancedPermutations(string num) {
        int cnt[10]{};
        int s = 0;
        for (char ch : num) {
            ++cnt[ch - '0'];
            s += ch - '0';
        }
        if (s % 2) {
            return 0;
        }
        const int n = num.size();
        const int half = s / 2;
        const int oddSlots = n / 2;
        const int evenSlots = n - oddSlots;
        const int stride = oddSlots + 1;
        // Heap-allocated memo instead of a variable-length array on the stack.
        // a + b always equals the number of digits with value >= i, so b is
        // determined by (i, a) and the key only needs (i, j, a).
        vector<int> memo(10 * (half + 1) * stride, -1);
        // self(i, j, a, b): ways to place digit values i..9 when the odd-index
        // group still needs sum j and has a free positions, and the even-index
        // group has b free positions.
        auto dfs = [&](auto&& self, int i, int j, int a, int b) -> int {
            if (i > 9) {
                return (j == 0 && a == 0 && b == 0) ? 1 : 0;
            }
            if (a == 0 && j) {
                // No odd-index position left to supply the outstanding sum.
                return 0;
            }
            int& slot = memo[(i * (half + 1) + j) * stride + a];
            if (slot != -1) {
                return slot;
            }
            int ans = 0;
            const int maxL = min(cnt[i], a);
            for (int l = 0; l <= maxL; ++l) {
                // l copies of digit i go to odd indices, r copies to even indices.
                int r = cnt[i] - l;
                if (r <= b && l * i <= j) {
                    int t = c[a][l] * c[b][r] % MOD * self(self, i + 1, j - l * i, a - l, b - r) % MOD;
                    ans = (ans + t) % MOD;
                }
            }
            // memo is never resized, so the reference taken above is still valid.
            return slot = ans;
        };
        return dfs(dfs, 0, half, oddSlots, evenSlots);
    }
};
const (
	MX  = 80
	MOD = 1_000_000_007
)

// c[i][j] holds C(i, j) modulo MOD for 0 <= j <= i < MX.
var c [MX][MX]int

// init fills the binomial table row by row using Pascal's rule.
func init() {
	for row := 0; row < MX; row++ {
		c[row][0] = 1
		for col := 1; col <= row; col++ {
			c[row][col] = (c[row-1][col-1] + c[row-1][col]) % MOD
		}
	}
}

// countBalancedPermutations returns, modulo MOD, the number of distinct
// permutations of num whose even-index digit sum equals the odd-index digit sum.
func countBalancedPermutations(num string) int {
	var cnt [10]int
	digitSum := 0
	for _, ch := range num {
		d := int(ch - '0')
		cnt[d]++
		digitSum += d
	}
	if digitSum%2 == 1 {
		return 0
	}

	n := len(num)
	half := digitSum / 2
	m := n/2 + 1

	// Flat memo indexed by (digit, need, oddLeft, evenLeft); -1 marks "not computed".
	memo := make([]int, 10*(half+1)*m*(m+1))
	for k := range memo {
		memo[k] = -1
	}

	// walk counts the ways to place digit values digit..9 when the odd-index
	// group still needs sum `need` and has oddLeft free positions, and the
	// even-index group has evenLeft free positions.
	var walk func(digit, need, oddLeft, evenLeft int) int
	walk = func(digit, need, oddLeft, evenLeft int) int {
		if digit > 9 {
			if need != 0 || oddLeft != 0 || evenLeft != 0 {
				return 0
			}
			return 1
		}
		if oddLeft == 0 && need > 0 {
			return 0
		}
		slot := ((digit*(half+1)+need)*m+oddLeft)*(m+1) + evenLeft
		if memo[slot] != -1 {
			return memo[slot]
		}
		ways := 0
		for toOdd := 0; toOdd <= min(cnt[digit], oddLeft); toOdd++ {
			toEven := cnt[digit] - toOdd
			if toEven < 0 || toEven > evenLeft || toOdd*digit > need {
				continue
			}
			t := c[oddLeft][toOdd] * c[evenLeft][toEven] % MOD * walk(digit+1, need-toOdd*digit, oddLeft-toOdd, evenLeft-toEven) % MOD
			ways = (ways + t) % MOD
		}
		memo[slot] = ways
		return ways
	}

	return walk(0, half, n/2, (n+1)/2)
}
/** Exclusive upper bound on the indices of the binomial table. */
const MX = 80;
/** Modulus the answer is reported under. */
const MOD = 10 ** 9 + 7;
/** `c[i][j]` is C(i, j) mod MOD for 0 <= j <= i < MX; every other cell stays 0. */
const c: number[][] = Array.from({ length: MX }, () => Array<number>(MX).fill(0));
(function init() {
    c[0][0] = 1;
    for (let i = 1; i < MX; i++) {
        c[i][0] = 1;
        for (let j = 1; j <= i; j++) {
            c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % MOD;
        }
    }
})();

/**
 * Counts the distinct permutations of `num` whose even-index digit sum equals
 * the odd-index digit sum.
 *
 * @param num - String of the characters '0'..'9'.
 * @returns The count modulo 10^9 + 7; 0 when the digit sum is odd.
 */
function countBalancedPermutations(num: string): number {
    const cnt: number[] = Array<number>(10).fill(0);
    let s = 0;
    for (const ch of num) {
        const d = Number(ch);
        cnt[d]++;
        s += d;
    }
    if (s % 2 !== 0) {
        return 0;
    }
    const n = num.length;
    const half = s / 2;
    const oddSlots = Math.floor(n / 2);
    const evenSlots = n - oddSlots;

    // (x * y) % MOD for 0 <= x, y < MOD without leaving the exact-integer range
    // of `number`: y is split into 16-bit halves so every intermediate value
    // stays below 2^47.
    const mulMod = (x: number, y: number): number =>
        (((x * (y >>> 16)) % MOD) * 65536 + x * (y & 0xffff)) % MOD;

    // a + b always equals the number of digits with value >= i, so b is
    // determined by (i, a) and the memo key only needs (i, j, a).
    const stride = oddSlots + 1;
    const memo = new Int32Array(10 * (half + 1) * stride).fill(-1);

    // dfs(i, j, a, b): ways to place digit values i..9 when the odd-index group
    // still needs sum j and has a free positions, and the even-index group has
    // b free positions.
    const dfs = (i: number, j: number, a: number, b: number): number => {
        if (i > 9) {
            return j === 0 && a === 0 && b === 0 ? 1 : 0;
        }
        if (a === 0 && j > 0) {
            // No odd-index position left to supply the outstanding sum.
            return 0;
        }
        const key = (i * (half + 1) + j) * stride + a;
        const cached = memo[key];
        if (cached !== -1) {
            return cached;
        }
        let ans = 0;
        const maxL = Math.min(cnt[i], a);
        for (let l = 0; l <= maxL; l++) {
            // l copies of digit i go to odd indices, r copies to even indices.
            const r = cnt[i] - l;
            if (r > b || l * i > j) {
                continue;
            }
            const ways = mulMod(c[a][l], c[b][r]);
            ans = (ans + mulMod(ways, dfs(i + 1, j - l * i, a - l, b - r))) % MOD;
        }
        memo[key] = ans;
        return ans;
    };

    return dfs(0, half, oddSlots, evenSlots);
}