diff --git a/doc/source/nfloat.rst b/doc/source/nfloat.rst
index d708a16c6a..ffdb6638e4 100644
--- a/doc/source/nfloat.rst
+++ b/doc/source/nfloat.rst
@@ -412,4 +412,8 @@ real pairs.
               int _nfloat_complex_vec_set(nfloat_complex_ptr res, nfloat_complex_srcptr x, slong len, gr_ctx_t ctx)
               int _nfloat_complex_vec_add(nfloat_complex_ptr res, nfloat_complex_srcptr x, nfloat_complex_srcptr y, slong len, gr_ctx_t ctx)
               int _nfloat_complex_vec_sub(nfloat_complex_ptr res, nfloat_complex_srcptr x, nfloat_complex_srcptr y, slong len, gr_ctx_t ctx)
-
+              int nfloat_complex_mat_mul_fixed_classical(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
+              int nfloat_complex_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
+              int nfloat_complex_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx)
+              int nfloat_complex_mat_mul_reorder(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
+              int nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
diff --git a/src/gr_mat/test_approx_mul.c b/src/gr_mat/test_approx_mul.c
index e3c3c69424..9d510edaeb 100644
--- a/src/gr_mat/test_approx_mul.c
+++ b/src/gr_mat/test_approx_mul.c
@@ -57,9 +57,6 @@ void gr_mat_test_approx_mul_max_norm(gr_method_mat_binary_op mul_impl, gr_srcptr
         status |= gr_mat_randtest(A, state, ctx);
         status |= gr_mat_randtest(B, state, ctx);
 
-        status |= gr_mat_entrywise_unary_op(A, (gr_method_unary_op) gr_abs, A, ctx);
-        status |= gr_mat_entrywise_unary_op(B, (gr_method_unary_op) gr_abs, B, ctx);
-
         status |= gr_mat_randtest(C, state, ctx);
         status |= gr_mat_randtest(D, state, ctx);
 
diff --git a/src/nfloat.h b/src/nfloat.h
index 78b764a213..f0fdcf2985 100644
--- a/src/nfloat.h
+++ b/src/nfloat.h
@@ -563,6 +563,12 @@ int _nfloat_complex_vec_sub(nfloat_complex_ptr res, nfloat_complex_srcptr x, nfl
 int _nfloat_complex_vec_dot(nfloat_complex_ptr res, nfloat_complex_srcptr initial, int subtract, nfloat_complex_srcptr x, nfloat_complex_srcptr y, slong len, gr_ctx_t ctx);
 int _nfloat_complex_vec_dot_rev(nfloat_complex_ptr res, nfloat_complex_srcptr initial, int subtract, nfloat_complex_srcptr x, nfloat_complex_srcptr y, slong len, gr_ctx_t ctx);
 
+int nfloat_complex_mat_mul_fixed_classical(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
+int nfloat_complex_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
+int nfloat_complex_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx);
+int nfloat_complex_mat_mul_reorder(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
+int nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/nfloat/complex.c b/src/nfloat/complex.c
index 347210ebc2..ca80a50e6f 100644
--- a/src/nfloat/complex.c
+++ b/src/nfloat/complex.c
@@ -1792,8 +1792,8 @@ gr_method_tab_input _nfloat_complex_methods_input[] =
 /*
     {GR_METHOD_POLY_MULLOW, (gr_funcptr) nfloat_complex_poly_mullow},
     {GR_METHOD_POLY_ROOTS_OTHER,(gr_funcptr) nfloat_complex_poly_roots_other},
-    {GR_METHOD_MAT_MUL, (gr_funcptr) nfloat_complex_mat_mul},
 */
+    {GR_METHOD_MAT_MUL, (gr_funcptr) nfloat_complex_mat_mul},
     {GR_METHOD_MAT_DET, (gr_funcptr) gr_mat_det_generic_field},
     {GR_METHOD_MAT_FIND_NONZERO_PIVOT, (gr_funcptr) gr_mat_find_nonzero_pivot_large_abs},
diff --git a/src/nfloat/mat_mul.c b/src/nfloat/mat_mul.c
index 169722a3ec..dd496c9bc6 100644
--- a/src/nfloat/mat_mul.c
+++ b/src/nfloat/mat_mul.c
@@ -86,6 +86,24 @@ void nfixed_sub(nn_ptr res, nn_srcptr a, nn_srcptr b, slong nlimbs)
     }
 }
 
+FLINT_FORCE_INLINE
+void _nfixed_vec_add(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs)
+{
+    slong i;
+
+    for (i = 0; i < len; i++)
+        nfixed_add(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1), nlimbs);
+}
+
+FLINT_FORCE_INLINE
+void _nfixed_vec_sub(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs)
+{
+    slong i;
+
+    for (i = 0; i < len; i++)
+        nfixed_sub(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1), nlimbs);
+}
+
 FLINT_FORCE_INLINE
 void nfixed_mul(nn_ptr res, nn_srcptr a, nn_srcptr b, slong nlimbs)
 {
@@ -615,7 +633,203 @@ _nfloat_nbits(nfloat_srcptr x, slong nlimbs)
     return bits;
 }
 
+int
+nfloat_complex_mat_addmul_block_fallback(gr_mat_t C,
+    const gr_mat_t A, const gr_mat_t B,
+    slong block_start,
+    slong block_end,
+    gr_ctx_t ctx)
+{
+    slong M, P, n;
+    slong i, j, sz;
+    nn_ptr tmpB;
+    slong ndlimbs = NFLOAT_COMPLEX_CTX_DATA_NLIMBS(ctx);
+    sz = ctx->sizeof_elem;
+    int status = GR_SUCCESS;
+
+    M = A->r;
+    P = B->c;
+
+    n = block_end - block_start;
+
+    tmpB = flint_malloc(sizeof(ulong) * ndlimbs * (P * n));
+
+#define AA(ii, jj) GR_MAT_ENTRY(A, ii, block_start + (jj), sz)
+
+    for (i = 0; i < P; i++)
+        for (j = 0; j < n; j++)
+            flint_mpn_copyi(tmpB + (i * n + j) * ndlimbs, GR_MAT_ENTRY(B, block_start + j, i, sz), ndlimbs);
+
+    for (i = 0; i < M; i++)
+    {
+        for (j = 0; j < P; j++)
+        {
+            status |= _nfloat_complex_vec_dot(GR_MAT_ENTRY(C, i, j, sz),
+                (block_start == 0) ? NULL : GR_MAT_ENTRY(C, i, j, sz), 0,
+                GR_MAT_ENTRY(A, i, block_start, sz),
+                tmpB + j * n * ndlimbs, n, ctx);
+        }
+    }
+
+    flint_free(tmpB);
+
+    return status;
+}
+
+static void
+_nfloat_complex_2exp_get_fmpz_fmpz(fmpz_t res1, fmpz_t res2, nfloat_complex_srcptr x, slong fixexp, gr_ctx_t ctx)
+{
+    _nfloat_2exp_get_fmpz(res1, NFLOAT_COMPLEX_RE(x, ctx), fixexp, ctx);
+    _nfloat_2exp_get_fmpz(res2, NFLOAT_COMPLEX_IM(x, ctx), fixexp, ctx);
+}
+
+int
+nfloat_complex_mat_addmul_block_prescaled(gr_mat_t C,
+    const gr_mat_t A, const gr_mat_t B,
+    slong block_start,
+    slong block_end,
+    const slong * A_min, /* A per-row bottom exponent */
+    const slong * B_min, /* B per-column bottom exponent */
+    gr_ctx_t ctx)
+{
+    slong M, P, n;
+    slong i, j;
+    slong M0, M1, P0, P1, Mstep, Pstep;
+    int status = GR_SUCCESS;
+    slong sz = ctx->sizeof_elem;
+    ulong t[NFLOAT_MAX_ALLOC];
+    slong e;
+
+    M = A->r;
+    P = B->c;
+
+    n = block_end - block_start;
+
+    /* Create sub-blocks to keep matrices nearly square. Necessary? */
+#if 1
+    Mstep = (M < 2 * n) ? M : n;
+    Pstep = (P < 2 * n) ? P : n;
+#else
+    Mstep = M;
+    Pstep = P;
+#endif
+
+    for (M0 = 0; M0 < M; M0 += Mstep)
+    {
+        for (P0 = 0; P0 < P; P0 += Pstep)
+        {
+            fmpz_mat_t Aa, Ab, Ba, Bb, AaBa, AbBb, Cb;
+
+            M1 = FLINT_MIN(M0 + Mstep, M);
+            P1 = FLINT_MIN(P0 + Pstep, P);
+
+            fmpz_mat_init(Aa, M1 - M0, n);
+            fmpz_mat_init(Ab, M1 - M0, n);
+
+            fmpz_mat_init(Ba, n, P1 - P0);
+            fmpz_mat_init(Bb, n, P1 - P0);
+
+            fmpz_mat_init(AaBa, M1 - M0, P1 - P0);
+            fmpz_mat_init(AbBb, M1 - M0, P1 - P0);
+            fmpz_mat_init(Cb, M1 - M0, P1 - P0);
+
+            /* Convert to fixed-point matrices. */
+            for (i = M0; i < M1; i++)
+            {
+                if (A_min[i] == WORD_MIN) /* only zeros in this row */
+                    continue;
+
+                for (j = 0; j < n; j++)
+                    _nfloat_complex_2exp_get_fmpz_fmpz(fmpz_mat_entry(Aa, i - M0, j),
+                        fmpz_mat_entry(Ab, i - M0, j), GR_MAT_ENTRY(A, i, block_start + j, sz), A_min[i], ctx);
+            }
+
+            for (i = P0; i < P1; i++)
+            {
+                if (B_min[i] == WORD_MIN) /* only zeros in this column */
+                    continue;
+
+                for (j = 0; j < n; j++)
+                    _nfloat_complex_2exp_get_fmpz_fmpz(fmpz_mat_entry(Ba, j, i - P0),
+                        fmpz_mat_entry(Bb, j, i - P0), GR_MAT_ENTRY(B, block_start + j, i, sz), B_min[i], ctx);
+            }
+
+            fmpz_mat_mul(AaBa, Aa, Ba);
+            fmpz_mat_mul(AbBb, Ab, Bb);
+            fmpz_mat_add(Aa, Aa, Ab);
+            fmpz_mat_add(Ba, Ba, Bb);
+            fmpz_mat_mul(Cb, Aa, Ba);
+            fmpz_mat_sub(Cb, Cb, AaBa);
+            fmpz_mat_sub(Cb, Cb, AbBb);
+
+            /* Add to the result matrix */
+            for (i = M0; i < M1; i++)
+            {
+                for (j = P0; j < P1; j++)
+                {
+                    nn_ptr cc = NFLOAT_COMPLEX_IM(GR_MAT_ENTRY(C, i, j, sz), ctx);
+
+                    e = A_min[i] + B_min[j];
+
+                    /* The first time we write this Cij */
+                    if (block_start == 0)
+                    {
+                        status |= nfloat_set_fmpz(cc, fmpz_mat_entry(Cb, i - M0, j - P0), ctx);
+                        status |= nfloat_mul_2exp_si(cc, cc, e, ctx);
+                    }
+                    else
+                    {
+                        status |= nfloat_set_fmpz(t, fmpz_mat_entry(Cb, i - M0, j - P0), ctx);
+                        status |= nfloat_mul_2exp_si(t, t, e, ctx);
+                        status |= nfloat_add(cc, cc, t, ctx);
+                    }
+                }
+            }
+
+            fmpz_mat_sub(Cb, AaBa, AbBb);
+
+            /* Add to the result matrix */
+            for (i = M0; i < M1; i++)
+            {
+                for (j = P0; j < P1; j++)
+                {
+                    nn_ptr cc = NFLOAT_COMPLEX_RE(GR_MAT_ENTRY(C, i, j, sz), ctx);
+
+                    e = A_min[i] + B_min[j];
+
+                    /* The first time we write this Cij */
+                    if (block_start == 0)
+                    {
+                        status |= nfloat_set_fmpz(cc, fmpz_mat_entry(Cb, i - M0, j - P0), ctx);
+                        status |= nfloat_mul_2exp_si(cc, cc, e, ctx);
+                    }
+                    else
+                    {
+                        status |= nfloat_set_fmpz(t, fmpz_mat_entry(Cb, i - M0, j - P0), ctx);
+                        status |= nfloat_mul_2exp_si(t, t, e, ctx);
+                        status |= nfloat_add(cc, cc, t, ctx);
+                    }
+                }
+            }
+
+            fmpz_mat_clear(Aa);
+            fmpz_mat_clear(Ab);
+
+            fmpz_mat_clear(Ba);
+            fmpz_mat_clear(Bb);
+
+            fmpz_mat_clear(AaBa);
+            fmpz_mat_clear(AbBb);
+            fmpz_mat_clear(Cb);
+        }
+    }
+
+    return status;
+}
+
+
 /* todo: squaring optimizations */
+/* note: also supports complex contexts */
 int
 nfloat_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx)
 {
@@ -631,10 +845,11 @@ nfloat_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_b
     slong nlimbs = NFLOAT_CTX_NLIMBS(ctx);
     slong prec = NFLOAT_CTX_PREC(ctx);
     int status = GR_SUCCESS;
+    int complex_ctx = (ctx->which_ring == GR_CTX_NFLOAT_COMPLEX);
 
     M = A->r;
     N = A->c;
-    P = B->r;
+    P = B->c;
 
     if (N != B->r || M != C->r || P != C->c)
         return GR_DOMAIN;
@@ -676,44 +891,139 @@ nfloat_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_b
 
     /* Build table of bottom exponents (WORD_MIN signifies a zero),
       and also collect some statistics. */
-    for (i = 0; i < M; i++)
+    if (complex_ctx)
     {
-        for (j = 0; j < N; j++)
+        for (i = 0; i < M; i++)
         {
-            t = GR_MAT_ENTRY(A, i, j, sz);
-            if (NFLOAT_IS_ZERO(t))
+            for (j = 0; j < N; j++)
             {
-                A_bot[i * N + j] = WORD_MIN;
-                A_bits[i * N + j] = 0;
+                nfloat_srcptr re, im;
+                slong h, b1, b2;
+
+                t = GR_MAT_ENTRY(A, i, j, sz);
+                re = NFLOAT_COMPLEX_RE(t, ctx);
+                im = NFLOAT_COMPLEX_IM(t, ctx);
+
+                if (NFLOAT_IS_ZERO(re) && NFLOAT_IS_ZERO(im))
+                {
+                    A_bot[i * N + j] = WORD_MIN;
+                    A_bits[i * N + j] = 0;
+                }
+                else if (NFLOAT_IS_ZERO(im))
+                {
+                    b = _nfloat_nbits(re, nlimbs);
+                    A_bot[i * N + j] = NFLOAT_EXP(re) - b;
+                    A_bits[i * N + j] = b;
+                    A_max_bits = FLINT_MAX(A_max_bits, b);
+                    A_density++;
+                }
+                else if (NFLOAT_IS_ZERO(re))
+                {
+                    b = _nfloat_nbits(im, nlimbs);
+                    A_bot[i * N + j] = NFLOAT_EXP(im) - b;
+                    A_bits[i * N + j] = b;
+                    A_max_bits = FLINT_MAX(A_max_bits, b);
+                    A_density++;
+                }
+                else
+                {
+                    b1 = _nfloat_nbits(re, nlimbs);
+                    b2 = _nfloat_nbits(im, nlimbs);
+
+                    A_bot[i * N + j] = b = FLINT_MIN(NFLOAT_EXP(re) - b1, NFLOAT_EXP(im) - b2);
+                    A_bits[i * N + j] = h = FLINT_MAX(NFLOAT_EXP(re) - b, NFLOAT_EXP(im) - b);
+                    A_max_bits = FLINT_MAX(A_max_bits, h);
+                    A_density++;
+                }
             }
-            else
+        }
+
+        for (i = 0; i < N; i++)
+        {
+            for (j = 0; j < P; j++)
             {
-                b = _nfloat_nbits(t, nlimbs);
-                A_bot[i * N + j] = NFLOAT_EXP(t) - b;
-                A_bits[i * N + j] = b;
-                A_max_bits = FLINT_MAX(A_max_bits, b);
-                A_density++;
+                nfloat_srcptr re, im;
+                slong h, b1, b2;
+
+                t = GR_MAT_ENTRY(B, i, j, sz);
+                re = NFLOAT_COMPLEX_RE(t, ctx);
+                im = NFLOAT_COMPLEX_IM(t, ctx);
+
+                if (NFLOAT_IS_ZERO(re) && NFLOAT_IS_ZERO(im))
+                {
+                    B_bot[i * P + j] = WORD_MIN;
+                    B_bits[i * P + j] = 0;
+                }
+                else if (NFLOAT_IS_ZERO(im))
+                {
+                    b = _nfloat_nbits(re, nlimbs);
+                    B_bot[i * P + j] = NFLOAT_EXP(re) - b;
+                    B_bits[i * P + j] = b;
+                    B_max_bits = FLINT_MAX(B_max_bits, b);
+                    B_density++;
+                }
+                else if (NFLOAT_IS_ZERO(re))
+                {
+                    b = _nfloat_nbits(im, nlimbs);
+                    B_bot[i * P + j] = NFLOAT_EXP(im) - b;
+                    B_bits[i * P + j] = b;
+                    B_max_bits = FLINT_MAX(B_max_bits, b);
+                    B_density++;
+                }
+                else
+                {
+                    b1 = _nfloat_nbits(re, nlimbs);
+                    b2 = _nfloat_nbits(im, nlimbs);
+
+                    B_bot[i * P + j] = b = FLINT_MIN(NFLOAT_EXP(re) - b1, NFLOAT_EXP(im) - b2);
+                    B_bits[i * P + j] = h = FLINT_MAX(NFLOAT_EXP(re) - b, NFLOAT_EXP(im) - b);
+                    B_max_bits = FLINT_MAX(B_max_bits, h);
+                    B_density++;
+                }
             }
         }
     }
-
-    for (i = 0; i < N; i++)
+    else
     {
-        for (j = 0; j < P; j++)
+        for (i = 0; i < M; i++)
         {
-            t = GR_MAT_ENTRY(B, i, j, sz);
-            if (NFLOAT_IS_ZERO(t))
+            for (j = 0; j < N; j++)
             {
-                B_bot[i * P + j] = WORD_MIN;
-                B_bits[i * P + j] = 0;
+                t = GR_MAT_ENTRY(A, i, j, sz);
+                if (NFLOAT_IS_ZERO(t))
+                {
+                    A_bot[i * N + j] = WORD_MIN;
+                    A_bits[i * N + j] = 0;
+                }
+                else
+                {
+                    b = _nfloat_nbits(t, nlimbs);
+                    A_bot[i * N + j] = NFLOAT_EXP(t) - b;
+                    A_bits[i * N + j] = b;
+                    A_max_bits = FLINT_MAX(A_max_bits, b);
+                    A_density++;
+                }
             }
-            else
+        }
+
+        for (i = 0; i < N; i++)
+        {
+            for (j = 0; j < P; j++)
             {
-                b = _nfloat_nbits(t, nlimbs);
-                B_bot[i * P + j] = NFLOAT_EXP(t) - b;
-                B_bits[i * P + j] = b;
-                B_max_bits = FLINT_MAX(B_max_bits, b);
-                B_density++;
+                t = GR_MAT_ENTRY(B, i, j, sz);
+                if (NFLOAT_IS_ZERO(t))
+                {
+                    B_bot[i * P + j] = WORD_MIN;
+                    B_bits[i * P + j] = 0;
+                }
+                else
+                {
+                    b = _nfloat_nbits(t, nlimbs);
+                    B_bot[i * P + j] = NFLOAT_EXP(t) - b;
+                    B_bits[i * P + j] = b;
+                    B_max_bits = FLINT_MAX(B_max_bits, b);
+                    B_density++;
+                }
             }
         }
     }
@@ -830,11 +1140,18 @@ nfloat_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_b
         if (block_end - block_start < min_block_size)
         {
             block_end = FLINT_MIN(N, block_start + min_block_size);
-            status |= nfloat_mat_addmul_block_fallback(C, A, B, block_start, block_end, ctx);
+
+            if (complex_ctx)
+                status |= nfloat_complex_mat_addmul_block_fallback(C, A, B, block_start, block_end, ctx);
+            else
+                status |= nfloat_mat_addmul_block_fallback(C, A, B, block_start, block_end, ctx);
         }
         else
         {
-            status |= nfloat_mat_addmul_block_prescaled(C, A, B, block_start, block_end, A_min, B_min, ctx);
+            if (complex_ctx)
+                status |= nfloat_complex_mat_addmul_block_prescaled(C, A, B, block_start, block_end, A_min, B_min, ctx);
+            else
+                status |= nfloat_mat_addmul_block_prescaled(C, A, B, block_start, block_end, A_min, B_min, ctx);
         }
 
         block_start = block_end;
@@ -929,3 +1246,526 @@ nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
         return nfloat_mat_mul_block(C, A, B, 70, ctx);
     }
 }
+
+static void
+_nfloat_complex_mat_exp_range(slong * _Amin, slong * _Amax, const gr_mat_t A, gr_ctx_t ctx)
+{
+    slong Amax, Amin;
+    slong m = A->r;
+    slong n = A->c;
+    slong exp, i, j;
+    slong sz = ctx->sizeof_elem;
+    nfloat_srcptr a;
+
+    Amax = WORD_MIN;
+    Amin = WORD_MAX;
+
+    for (i = 0; i < m; i++)
+    {
+        for (j = 0; j < n; j++)
+        {
+            a = NFLOAT_COMPLEX_RE(GR_MAT_ENTRY(A, i, j, sz), ctx);
+            exp = NFLOAT_EXP(a);
+            Amax = FLINT_MAX(Amax, exp);
+            Amin = FLINT_MIN(Amin, exp);
+
+            a = NFLOAT_COMPLEX_IM(GR_MAT_ENTRY(A, i, j, sz), ctx);
+            exp = NFLOAT_EXP(a);
+            Amax = FLINT_MAX(Amax, exp);
+            Amin = FLINT_MIN(Amin, exp);
+        }
+    }
+
+    _Amin[0] = Amin;
+    _Amax[0] = Amax;
+}
+
+int
+_nfloat_complex_mat_mul_fixed_given_exp(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong Aexp, slong Bexp, slong fnlimbs, int waksman, gr_ctx_t ctx)
+{
+    nn_ptr T, Aa, Ab, Ba, Bb, AaBa, AbBb, Cb;
+    slong i, j;
+    slong sz = ctx->sizeof_elem;
+    slong fdnlimbs;
+
+    slong m = A->r;
+    slong n = A->c;
+    slong p = B->c;
+
+    /* limbs including sign limb */
+    fdnlimbs = fnlimbs + 1;
+
+    T = flint_calloc(fdnlimbs * (2 * m * n + 2 * n * p + 3 * m * p), sizeof(ulong));
+
+    Aa = T;
+    Ab = Aa + fdnlimbs * (m * n);
+    Ba = Ab + fdnlimbs * (m * n);
+    Bb = Ba + fdnlimbs * (n * p);
+    Cb = Bb + fdnlimbs * (n * p);
+    AaBa = Cb + fdnlimbs * (m * p);
+    AbBb = AaBa + fdnlimbs * (m * p);
+
+    for (i = 0; i < m; i++)
+    {
+        for (j = 0; j < n; j++)
+        {
+            _nfloat_get_nfixed(Aa + i * fdnlimbs * n + j * fdnlimbs, NFLOAT_COMPLEX_RE(GR_MAT_ENTRY(A, i, j, sz), ctx), Aexp, fnlimbs, ctx);
+            _nfloat_get_nfixed(Ab + i * fdnlimbs * n + j * fdnlimbs, NFLOAT_COMPLEX_IM(GR_MAT_ENTRY(A, i, j, sz), ctx), Aexp, fnlimbs, ctx);
+        }
+    }
+
+    for (i = 0; i < n; i++)
+    {
+        for (j = 0; j < p; j++)
+        {
+            _nfloat_get_nfixed(Ba + i * fdnlimbs * p + j * fdnlimbs, NFLOAT_COMPLEX_RE(GR_MAT_ENTRY(B, i, j, sz), ctx), Bexp, fnlimbs, ctx);
+            _nfloat_get_nfixed(Bb + i * fdnlimbs * p + j * fdnlimbs, NFLOAT_COMPLEX_IM(GR_MAT_ENTRY(B, i, j, sz), ctx), Bexp, fnlimbs, ctx);
+        }
+    }
+
+    /* (Aa+Ab i)(Ba+Bb i) = (Aa Ba - Ab Bb) + (Aa Bb + Ab Ba) i */
+
+    /* (Aa Ba - Ab Bb) + ((Aa + Ab)(Ba + Bb) - Aa Ba - Ab Bb) i */
+
+    if (waksman)
+    {
+        _nfixed_mat_mul_waksman(AaBa, Aa, Ba, m, n, p, fnlimbs);
+        _nfixed_mat_mul_waksman(AbBb, Ab, Bb, m, n, p, fnlimbs);
+        _nfixed_vec_add(Aa, Aa, Ab, m * n, fnlimbs);
+        _nfixed_vec_add(Ba, Ba, Bb, n * p, fnlimbs);
+        _nfixed_mat_mul_waksman(Cb, Aa, Ba, m, n, p, fnlimbs);
+        _nfixed_vec_sub(Cb, Cb, AaBa, m * p, fnlimbs);
+        _nfixed_vec_sub(Cb, Cb, AbBb, m * p, fnlimbs);
+
+        for (i = 0; i < m; i++)
+            for (j = 0; j < p; j++)
+                _nfloat_set_nfixed(NFLOAT_COMPLEX_IM(GR_MAT_ENTRY(C, i, j, sz), ctx), Cb + i * fdnlimbs * p + j * fdnlimbs, Aexp + Bexp, fnlimbs, ctx);
+
+        _nfixed_vec_sub(Cb, AaBa, AbBb, m * p, fnlimbs);
+
+        for (i = 0; i < m; i++)
+            for (j = 0; j < p; j++)
+                _nfloat_set_nfixed(NFLOAT_COMPLEX_RE(GR_MAT_ENTRY(C, i, j, sz), ctx), Cb + i * fdnlimbs * p + j * fdnlimbs, Aexp + Bexp, fnlimbs, ctx);
+    }
+    else
+    {
+        _nfixed_mat_mul_classical(AaBa, Aa, Ba, m, n, p, fnlimbs);
+        _nfixed_mat_mul_classical(AbBb, Ab, Bb, m, n, p, fnlimbs);
+        _nfixed_vec_add(Aa, Aa, Ab, m * n, fnlimbs);
+        _nfixed_vec_add(Ba, Ba, Bb, n * p, fnlimbs);
+        _nfixed_mat_mul_classical(Cb, Aa, Ba, m, n, p, fnlimbs);
+        _nfixed_vec_sub(Cb, Cb, AaBa, m * p, fnlimbs);
+        _nfixed_vec_sub(Cb, Cb, AbBb, m * p, fnlimbs);
+
+        for (i = 0; i < m; i++)
+            for (j = 0; j < p; j++)
+                _nfloat_set_nfixed(NFLOAT_COMPLEX_IM(GR_MAT_ENTRY(C, i, j, sz), ctx), Cb + i * fdnlimbs * p + j * fdnlimbs, Aexp + Bexp, fnlimbs, ctx);
+
+        _nfixed_vec_sub(Cb, AaBa, AbBb, m * p, fnlimbs);
+
+        for (i = 0; i < m; i++)
+            for (j = 0; j < p; j++)
+                _nfloat_set_nfixed(NFLOAT_COMPLEX_RE(GR_MAT_ENTRY(C, i, j, sz), ctx), Cb + i * fdnlimbs * p + j * fdnlimbs, Aexp + Bexp, fnlimbs, ctx);
+    }
+
+    flint_free(T);
+
+    return GR_SUCCESS;
+}
+
+int
+_nfloat_complex_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, int waksman, slong max_extra_bits, gr_ctx_t ctx)
+{
+    slong Amax, Amin, Bmax, Bmin, Adelta, Bdelta, Aexp, Bexp;
+    slong prec;
+    slong pad_top, pad_bot, extra_bits, fbits, fnlimbs;
+    slong n = A->c;
+
+    if (NFLOAT_CTX_HAS_INF_NAN(ctx))
+        return GR_UNABLE;
+
+    prec = NFLOAT_CTX_PREC(ctx);
+
+    _nfloat_complex_mat_exp_range(&Amin, &Amax, A, ctx);
+    _nfloat_complex_mat_exp_range(&Bmin, &Bmax, B, ctx);
+
+    if (Amax < NFLOAT_MIN_EXP || Bmax < NFLOAT_MIN_EXP)
+        return gr_mat_zero(C, ctx);
+
+    /* Currently, we don't handle zeros. (They pose no problem, but zero entries in
+       the output may not be exact. To be done.) */
+    if (Amin < NFLOAT_MIN_EXP || Bmin < NFLOAT_MIN_EXP)
+        return gr_mat_mul_classical(C, A, B, ctx);
+
+    Adelta = Amax - Amin;
+    Bdelta = Bmax - Bmin;
+
+    /* sanity check */
+    if (Adelta > 10 * prec || Bdelta > 10 * prec)
+        return gr_mat_mul_classical(C, A, B, ctx);
+
+    /*
+       To double check: for Waksman,
+       * The intermediate entries are bounded by 8n max(|A|,|B|)^2.
+       * The error, including error from converting
+         the input matrices, is bounded by 8n ulps.
+    */
+
+    pad_top = 3 + FLINT_BIT_COUNT(n);
+    pad_bot = 3 + FLINT_BIT_COUNT(n);
+
+    extra_bits = Adelta + Bdelta + pad_top + pad_bot;
+
+    if (extra_bits >= max_extra_bits)
+        return gr_mat_mul_classical(C, A, B, ctx);
+
+    Aexp = Amax + pad_top;
+    Bexp = Bmax + pad_top;
+    fbits = prec + extra_bits;
+    fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS;
+
+    return _nfloat_complex_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, waksman, ctx);
+}
+
+int
+nfloat_complex_mat_mul_fixed_classical(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
+{
+    return _nfloat_complex_mat_mul_fixed(C, A, B, 0, 100000, ctx);
+}
+
+int
+nfloat_complex_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
+{
+    return _nfloat_complex_mat_mul_fixed(C, A, B, 1, 100000, ctx);
+}
+
+FLINT_FORCE_INLINE slong
+_nfloat_complex_nbits(nfloat_srcptr x, slong nlimbs)
+{
+    nn_srcptr ad;
+    slong bits;
+
+    ad = NFLOAT_D(x);
+    bits = FLINT_BITS * nlimbs;
+
+    while (ad[0] == 0)
+    {
+        bits -= FLINT_BITS;
+        ad++;
+    }
+
+    bits -= flint_ctz(ad[0]);
+
+    return bits;
+}
+
+int
+nfloat_complex_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx)
+{
+    return nfloat_mat_mul_block(C, A, B, min_block_size, ctx);
+}
+
+static int
+_nfloat_complex_mat_is_real(const gr_mat_t A, gr_ctx_t ctx)
+{
+    slong i, j;
+    slong m = A->r;
+    slong n = A->c;
+    slong sz = ctx->sizeof_elem;
+
+    for (i = 0; i < m; i++)
+        for (j = 0; j < n; j++)
+            if (!NFLOAT_IS_ZERO(NFLOAT_COMPLEX_IM(GR_MAT_ENTRY(A, i, j, sz), ctx)))
+                return 0;
+
+    return 1;
+}
+
+/* check if for all entries with nonzero real and imaginary parts, the
+   components don't differ too much in magnitude */
+static int
+_nfloat_complex_mat_parts_are_well_scaled(const gr_mat_t A, gr_ctx_t ctx)
+{
+    slong i, j;
+    slong m = A->r;
+    slong n = A->c;
+    slong sz = ctx->sizeof_elem;
+    nn_srcptr a, b;
+    slong aexp, bexp, max;
+
+    max = NFLOAT_CTX_PREC(ctx) / 8 + 64;
+
+    for (i = 0; i < m; i++)
+    {
+        for (j = 0; j < n; j++)
+        {
+            a = NFLOAT_COMPLEX_RE(GR_MAT_ENTRY(A, i, j, sz), ctx);
+            if (NFLOAT_IS_ZERO(a))
+                continue;
+
+            b = NFLOAT_COMPLEX_IM(GR_MAT_ENTRY(A, i, j, sz), ctx);
+            if (NFLOAT_IS_ZERO(b))
+                continue;
+
+            aexp = NFLOAT_EXP(a);
+            bexp = NFLOAT_EXP(b);
+
+            if (FLINT_ABS(aexp - bexp) > max)
+                return 0;
+        }
+    }
+
+    return 1;
+}
+
+static void
+_real_part(gr_mat_t res, const gr_mat_t A, gr_ctx_t ctx)
+{
+    slong i, j;
+    slong m = A->r;
+    slong n = A->c;
+    slong ndlimbs = NFLOAT_CTX_DATA_NLIMBS(ctx);
+
+    for (i = 0; i < m; i++)
+        for (j = 0; j < n; j++)
+            flint_mpn_copyi(((nn_ptr) res->rows[i]) + j * ndlimbs, ((nn_srcptr) A->rows[i]) + j * 2 * ndlimbs, ndlimbs);
+}
+
+static void
+_imag_part(gr_mat_t res, const gr_mat_t A, gr_ctx_t ctx)
+{
+    slong i, j;
+    slong m = A->r;
+    slong n = A->c;
+    slong ndlimbs = NFLOAT_CTX_DATA_NLIMBS(ctx);
+
+    for (i = 0; i < m; i++)
+        for (j = 0; j < n; j++)
+            flint_mpn_copyi(((nn_ptr) res->rows[i]) + j * ndlimbs, ((nn_srcptr) A->rows[i]) + (j * 2 + 1) * ndlimbs, ndlimbs);
+}
+
+static void
+_set_real_part(gr_mat_t res, const gr_mat_t A, gr_ctx_t ctx)
+{
+    slong i, j;
+    slong m = A->r;
+    slong n = A->c;
+    slong ndlimbs = NFLOAT_CTX_DATA_NLIMBS(ctx);
+
+    for (i = 0; i < m; i++)
+        for (j = 0; j < n; j++)
+            flint_mpn_copyi(((nn_ptr) res->rows[i]) + j * 2 * ndlimbs, ((nn_srcptr) A->rows[i]) + j * ndlimbs, ndlimbs);
+}
+
+static void
+_set_imag_part(gr_mat_t res, const gr_mat_t A, gr_ctx_t ctx)
+{
+    slong i, j;
+    slong m = A->r;
+    slong n = A->c;
+    slong ndlimbs = NFLOAT_CTX_DATA_NLIMBS(ctx);
+
+    for (i = 0; i < m; i++)
+        for (j = 0; j < n; j++)
+            flint_mpn_copyi(((nn_ptr) res->rows[i]) + (j * 2 + 1) * ndlimbs, ((nn_srcptr) A->rows[i]) + j * ndlimbs, ndlimbs);
+}
+
+int
+_nfloat_complex_mat_mul_reorder(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, int Areal, int Breal, gr_ctx_t ctx)
+{
+    gr_mat_t X, Y, Z, W;
+    gr_ctx_t ctx2;
+    slong i, j, ndlimbs;
+    int status = GR_SUCCESS;
+
+    slong M = A->r;
+    slong N = A->c;
+    slong P = B->c;
+
+    nfloat_ctx_init(ctx2, NFLOAT_CTX_PREC(ctx), NFLOAT_CTX_FLAGS(ctx));
+    ndlimbs = NFLOAT_CTX_DATA_NLIMBS(ctx2);
+
+    gr_mat_init(X, M, N, ctx2);
+    gr_mat_init(Y, N, P, ctx2);
+    gr_mat_init(Z, M, P, ctx2);
+
+    if (Areal && Breal)
+    {
+        _real_part(X, A, ctx2);
+        _real_part(Y, B, ctx2);
+        status |= gr_mat_mul(Z, X, Y, ctx2);
+        _set_real_part(C, Z, ctx2);
+
+        for (i = 0; i < M; i++)
+            for (j = 0; j < P; j++)
+                status |= nfloat_zero(((nn_ptr) C->rows[i]) + (j * 2 + 1) * ndlimbs, ctx2);
+    }
+    else if (Areal)
+    {
+        _real_part(X, A, ctx2);
+        _real_part(Y, B, ctx2);
+        status |= gr_mat_mul(Z, X, Y, ctx2);
+        _set_real_part(C, Z, ctx2);
+        _imag_part(Y, B, ctx2);
+        status |= gr_mat_mul(Z, X, Y, ctx2);
+        _set_imag_part(C, Z, ctx2);
+    }
+    else if (Breal)
+    {
+        _real_part(X, A, ctx2);
+        _real_part(Y, B, ctx2);
+        status |= gr_mat_mul(Z, X, Y, ctx2);
+        _set_real_part(C, Z, ctx2);
+        _imag_part(X, A, ctx2);
+        status |= gr_mat_mul(Z, X, Y, ctx2);
+        _set_imag_part(C, Z, ctx2);
+    }
+    else
+    {
+        gr_mat_init(W, M, P, ctx2);
+
+        /* Z = re(A) * re(B) */
+        _real_part(X, A, ctx2);
+        _real_part(Y, B, ctx2);
+        status |= gr_mat_mul(Z, X, Y, ctx2);
+
+        /* W = im(A) * im(B) */
+        _imag_part(X, A, ctx2);
+        _imag_part(Y, B, ctx2);
+        status |= gr_mat_mul(W, X, Y, ctx2);
+
+        if (A == C || B == C)
+        {
+            gr_mat_t T;
+            gr_mat_init(T, M, P, ctx2);
+
+            /* T = Z - W */
+            status |= gr_mat_sub(T, Z, W, ctx2);
+
+            /* Z = re(A) * im(B) */
+            /* W = im(A) * re(B) */
+            _real_part(X, A, ctx2);
+            status |= gr_mat_mul(Z, X, Y, ctx2);
+            _imag_part(X, A, ctx2);
+            _real_part(Y, B, ctx2);
+            status |= gr_mat_mul(W, X, Y, ctx2);
+
+            /* re(C) = T */
+            _set_real_part(C, T, ctx2);
+
+            gr_mat_clear(T, ctx2);
+
+            /* im(C) = Z + W */
+            /* todo: add directly to the output */
+            status |= gr_mat_add(Z, Z, W, ctx2);
+            _set_imag_part(C, Z, ctx2);
+        }
+        else
+        {
+            /* re(C) = Z - W */
+            status |= gr_mat_sub(Z, Z, W, ctx2);
+            _set_real_part(C, Z, ctx2);
+
+            _real_part(X, A, ctx2);
+            status |= gr_mat_mul(Z, X, Y, ctx2);
+            _imag_part(X, A, ctx2);
+            _real_part(Y, B, ctx2);
+            status |= gr_mat_mul(W, X, Y, ctx2);
+
+            /* im(C) = Z + W */
+            /* todo: add directly to the output */
+            status |= gr_mat_add(Z, Z, W, ctx2);
+            _set_imag_part(C, Z, ctx2);
+        }
+
+        gr_mat_clear(W, ctx2);
+    }
+
+    gr_mat_clear(X, ctx2);
+    gr_mat_clear(Y, ctx2);
+    gr_mat_clear(Z, ctx2);
+
+    gr_ctx_clear(ctx2);
+
+    return status;
+}
+
+int
+nfloat_complex_mat_mul_reorder(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
+{
+    return _nfloat_complex_mat_mul_reorder(C, A, B, _nfloat_complex_mat_is_real(A, ctx),
+        _nfloat_complex_mat_is_real(B, ctx), ctx);
+}
+
+int
+nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
+{
+    slong dim;
+    slong block_cutoff;
+    int use_waksman = 0;
+    slong prec;
+    slong max_extra_prec;
+    int A_real = 0, B_real = 0;
+
+    slong m = A->r;
+    slong n = A->c;
+    slong p = B->c;
+
+    dim = FLINT_MIN(n, FLINT_MIN(m, p));
+
+    if (dim <= 2 || NFLOAT_CTX_HAS_INF_NAN(ctx))
+        return gr_mat_mul_classical(C, A, B, ctx);
+
+    if (dim >= 6)
+    {
+        A_real = _nfloat_complex_mat_is_real(A, ctx);
+        B_real = _nfloat_complex_mat_is_real(B, ctx);
+
+        if (A_real || B_real)
+            return _nfloat_complex_mat_mul_reorder(C, A, B, A_real, B_real, ctx);
+    }
+
+    prec = NFLOAT_CTX_PREC(ctx);
+
+    if (prec <= 256)
+        block_cutoff = 80;
+    else if (prec <= 512)
+        block_cutoff = 160;
+    else if (prec <= 3072)
+        block_cutoff = 100;
+    else
+        block_cutoff = 80;
+
+    if (dim < block_cutoff)
+    {
+        if (prec <= 128 || (prec <= 256 && n <= 4) || (prec == 576 && n <= 6))
+            return gr_mat_mul_classical(C, A, B, ctx);
+
+        if (prec <= 320)
+            use_waksman = 0;
+        else if (prec <= 448)
+            use_waksman = (dim >= 20);
+        else if (prec <= 640)
+            use_waksman = (dim >= 6);
+        else if (prec <= 1536)
+            use_waksman = (dim >= 4);
+        else
+            use_waksman = (dim >= 3);
+
+        max_extra_prec = (prec < 768) ? 64 : prec / 4;
+
+        return _nfloat_complex_mat_mul_fixed(C, A, B, use_waksman, max_extra_prec, ctx);
+    }
+    else
+    {
+        if (_nfloat_complex_mat_parts_are_well_scaled(A, ctx) &&
+            _nfloat_complex_mat_parts_are_well_scaled(B, ctx))
+        {
+            return nfloat_complex_mat_mul_block(C, A, B, block_cutoff - 10, ctx);
+        }
+        else
+        {
+            return _nfloat_complex_mat_mul_reorder(C, A, B, A_real, B_real, ctx);
+        }
+    }
+}
diff --git a/src/nfloat/test/main.c b/src/nfloat/test/main.c
index 8520994f89..a4473c5693 100644
--- a/src/nfloat/test/main.c
+++ b/src/nfloat/test/main.c
@@ -13,6 +13,7 @@
 
 #include "t-add_sub_n.c"
 #include "t-addmul_submul.c"
+#include "t-complex_mat_mul.c"
 #include "t-mat_mul.c"
 #include "t-nfloat.c"
 #include "t-nfloat_complex.c"
@@ -23,6 +24,7 @@ test_struct tests[] =
 {
     TEST_FUNCTION(add_sub_n),
     TEST_FUNCTION(addmul_submul),
+    TEST_FUNCTION(complex_mat_mul),
     TEST_FUNCTION(mat_mul),
    TEST_FUNCTION(nfloat),
     TEST_FUNCTION(nfloat_complex),
diff --git a/src/nfloat/test/t-complex_mat_mul.c b/src/nfloat/test/t-complex_mat_mul.c
new file mode 100644
index 0000000000..8d5e38a978
--- /dev/null
+++ b/src/nfloat/test/t-complex_mat_mul.c
@@ -0,0 +1,70 @@
+/*
+    Copyright (C) 2024 Fredrik Johansson
+
+    This file is part of FLINT.
+
+    FLINT is free software: you can redistribute it and/or modify it under
+    the terms of the GNU Lesser General Public License (LGPL) as published
+    by the Free Software Foundation; either version 3 of the License, or
+    (at your option) any later version. See <https://www.gnu.org/licenses/>.
+*/
+
+#include "test_helpers.h"
+#include "gr.h"
+#include "gr_mat.h"
+#include "nfloat.h"
+
+int
+nfloat_complex_mat_mul_block1(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
+{
+    return nfloat_complex_mat_mul_block(C, A, B, 2, ctx);
+}
+
+TEST_FUNCTION_START(complex_mat_mul, state)
+{
+    gr_ctx_t ctx;
+    slong prec;
+    slong iter;
+    gr_ptr tol;
+
+    for (iter = 0; iter < 10 * flint_test_multiplier(); iter++)
+    {
+        if (n_randint(state, 5))
+            prec = FLINT_BITS * (1 + n_randint(state, 4));
+        else
+            prec = FLINT_BITS * (1 + n_randint(state, NFLOAT_MAX_LIMBS));
+
+        nfloat_complex_ctx_init(ctx, prec, 0);
+
+        tol = gr_heap_init(ctx);
+        GR_MUST_SUCCEED(gr_one(tol, ctx));
+        GR_MUST_SUCCEED(gr_mul_2exp_si(tol, tol, -prec + 2, ctx));
+
+        gr_mat_test_approx_mul_max_norm(
+            (gr_method_mat_binary_op) nfloat_complex_mat_mul_reorder,
+            tol, state, 10, 4, ctx);
+
+        gr_mat_test_approx_mul_max_norm(
+            (gr_method_mat_binary_op) nfloat_complex_mat_mul_waksman,
+            tol, state, (prec <= 256) ? 100 : 1, 10, ctx);
+
+        gr_mat_test_approx_mul_max_norm(
+            (gr_method_mat_binary_op) nfloat_complex_mat_mul_block1,
+            tol, state, (prec <= 256) ? 10 : 1,
+            (prec <= 256) ? 40 : 20, ctx);
+
+        gr_mat_test_approx_mul_max_norm(
+            (gr_method_mat_binary_op) nfloat_complex_mat_mul_fixed_classical,
+            tol, state, (prec <= 256) ? 10 : 1,
+            (prec <= 256) ? 40 : 20, ctx);
+
+        gr_mat_test_approx_mul_max_norm(
+            (gr_method_mat_binary_op) nfloat_complex_mat_mul,
+            tol, state, 1, 120, ctx);
+
+        gr_heap_clear(tol, ctx);
+        gr_ctx_clear(ctx);
+    }
+
+    TEST_FUNCTION_END(state);
+}
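
Reviewer note (not part of the patch): both nfloat_complex_mat_addmul_block_prescaled and _nfloat_complex_mat_mul_fixed_given_exp use the 3-multiplication scheme spelled out in the comments above, (Aa + Ab i)(Ba + Bb i) = (Aa Ba - Ab Bb) + ((Aa + Ab)(Ba + Bb) - Aa Ba - Ab Bb) i, trading the fourth multiplication for a few extra additions, which pays off when multiplications dominate, as they do at high precision. A minimal scalar sanity check, with double standing in for the fixed-point type (all names here are illustrative):

    #include <stdio.h>

    int main(void)
    {
        double a = 1.5, b = -2.0, c = 0.25, d = 3.0;

        double ac = a * c;              /* first product  (AaBa) */
        double bd = b * d;              /* second product (AbBb) */
        double s = (a + b) * (c + d);   /* third product */

        double re = ac - bd;            /* real part */
        double im = s - ac - bd;        /* imaginary part, equals ad + bc */

        printf("re = %g (direct: %g)\n", re, a * c - b * d);   /* 6.375 */
        printf("im = %g (direct: %g)\n", im, a * d + b * c);   /* 4 */
        return 0;
    }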
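
Reviewer note (not part of the patch): the prescaled block strategy gives every entry in row i of A the common bottom exponent A_min[i] and every entry in column j of B the common B_min[j], so that each block product reduces to an exact integer fmpz_mat_mul; the floating-point result is recovered by shifting with e = A_min[i] + B_min[j], exactly as in the patch, and WORD_MIN marks all-zero rows/columns that can be skipped. The underlying identity on scalars (illustrative sketch only; compile with -lm):

    #include <math.h>
    #include <stdio.h>

    int main(void)
    {
        /* x = 3 * 2^-3 = 0.375 and y = 21 * 2^-2 = 5.25 */
        long mx = 3;  int ex = -3;
        long my = 21; int ey = -2;

        long mz = mx * my;   /* exact integer product: 63 */
        int ez = ex + ey;    /* exponents add: -5 */

        /* 63 * 2^-5 = 1.96875 = 0.375 * 5.25, recovered exactly */
        printf("%g\n", ldexp((double) mz, ez));
        return 0;
    }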
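
Reviewer note (not part of the patch): a minimal usage sketch for the new public entry point, assuming a FLINT build that contains this patch; gr_mat_init, gr_mat_one and gr_mat_print are existing gr_mat functions:

    #include "gr.h"
    #include "gr_mat.h"
    #include "nfloat.h"

    int main(void)
    {
        gr_ctx_t ctx;
        gr_mat_t A, B, C;
        int status = GR_SUCCESS;

        nfloat_complex_ctx_init(ctx, 256, 0);   /* 256-bit complex floats */

        gr_mat_init(A, 3, 3, ctx);
        gr_mat_init(B, 3, 3, ctx);
        gr_mat_init(C, 3, 3, ctx);

        status |= gr_mat_one(A, ctx);
        status |= gr_mat_one(B, ctx);

        /* nfloat_complex_mat_mul dispatches between classical, fixed-point,
           block and reordered multiplication depending on the dimensions,
           the precision, and how the real/imaginary parts are scaled */
        status |= nfloat_complex_mat_mul(C, A, B, ctx);

        status |= gr_mat_print(C, ctx);

        gr_mat_clear(A, ctx);
        gr_mat_clear(B, ctx);
        gr_mat_clear(C, ctx);
        gr_ctx_clear(ctx);

        return status == GR_SUCCESS ? 0 : 1;
    }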