Skip to content

Commit

Permalink
Merge pull request #1902 from flintlib/karatsuba
Browse files Browse the repository at this point in the history
Generic Karatsuba multiplication
  • Loading branch information
fredrik-johansson authored Apr 2, 2024
2 parents 8665025 + e8bd59f commit d9b387b
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 6 deletions.
10 changes: 10 additions & 0 deletions doc/source/gr_poly.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@ Arithmetic

.. function:: int gr_poly_mul_scalar(gr_poly_t res, const gr_poly_t poly, gr_srcptr c, gr_ctx_t ctx)

.. function:: int _gr_poly_mul_karatsuba(gr_ptr res, gr_srcptr poly1, slong len1, gr_srcptr poly2, slong len2, gr_ctx_t ctx)
int gr_poly_mul_karatsuba(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx)

Karatsuba multiplication.
Not optimized for unbalanced operands, and not memory-optimized for recursive calls.
The underscore method requires positive lengths and does not support aliasing.
The underscore method computes its three sub-products by calling :func:`_gr_poly_mul`
rather than by calling itself, so to get a fully recursive algorithm with `O(n^{1.6})`
complexity, the ring must overload :func:`_gr_poly_mul` to dispatch
to :func:`_gr_poly_mul_karatsuba` above some cutoff.

Powering
--------------------------------------------------------------------------------

Expand Down
6 changes: 1 addition & 5 deletions src/gr/init_random.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,8 @@ gr_ctx_init_random_ring_composite(gr_ctx_t ctx, flint_rand_t state)
*/
break;
case 3:
gr_ctx_init_gr_poly(ctx, base_ring);
/*
this currently breaks some tests
gr_ctx_init_gr_series_mod(ctx, base_ring, n_randint(state, 6));
gr_ctx_init_series_mod_gr_poly(ctx, base_ring, n_randint(state, 6));
break;
*/
case 4:
gr_ctx_init_vector_space_gr_vec(ctx, base_ring, n_randint(state, 4));
break;
Expand Down
2 changes: 1 addition & 1 deletion src/gr/series_mod.c
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ static void _gr_gr_series_mod_init(gr_poly_t res, gr_ctx_t ctx)

/* Clears the polynomial res, using the context obtained from
   SERIES_MOD_ELEM_CTX(ctx) (presumably the coefficient ring context --
   confirm against the macro definition). */
static void _gr_gr_series_mod_clear(gr_poly_t res, gr_ctx_t ctx)
{
    /* Clear only. Calling gr_poly_init on res first would re-initialise an
       already initialised polynomial and lose track of its existing
       coefficient storage before it could be released. */
    gr_poly_clear(res, SERIES_MOD_ELEM_CTX(ctx));
}

static void _gr_gr_series_mod_swap(gr_poly_t x, gr_poly_t y, gr_ctx_t ctx)
Expand Down
4 changes: 4 additions & 0 deletions src/gr_poly.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ GR_POLY_INLINE WARN_UNUSED_RESULT int _gr_poly_mullow(gr_ptr res, gr_srcptr poly
WARN_UNUSED_RESULT int gr_poly_mullow(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, slong n, gr_ctx_t ctx);
WARN_UNUSED_RESULT int gr_poly_mul_scalar(gr_poly_t res, const gr_poly_t poly, gr_srcptr c, gr_ctx_t ctx);

/* Karatsuba multiplication. The underscore version requires positive
   lengths and does not support aliasing of res with either input. */
WARN_UNUSED_RESULT int _gr_poly_mul_karatsuba(gr_ptr res, gr_srcptr poly1, slong len1, gr_srcptr poly2, slong len2, gr_ctx_t ctx);
WARN_UNUSED_RESULT int gr_poly_mul_karatsuba(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx);


/* powering */

WARN_UNUSED_RESULT int _gr_poly_pow_series_ui_binexp(gr_ptr res, gr_srcptr f, slong flen, ulong exp, slong len, gr_ctx_t ctx);
Expand Down
122 changes: 122 additions & 0 deletions src/gr_poly/mul_karatsuba.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
Copyright (C) 2024 Fredrik Johansson
This file is part of FLINT.
FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 3 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "gr_poly.h"

/*
    Karatsuba multiplication of (f, flen) by (g, glen), writing the
    flen + glen - 1 coefficients of the product to res.

    Preconditions (checked with FLINT_ASSERT only): flen >= 1, glen >= 1,
    and res does not alias f or g.

    The operands are split at X = x^m with m = ceil(min(flen, glen) / 2).
    The three sub-products are computed with _gr_poly_mul, not by calling
    this function again.

    Returns the bitwise OR of the status values of all operations performed.

    TODO: a self-recursive variant that can reuse scratch space.
*/
int
_gr_poly_mul_karatsuba(gr_ptr res, gr_srcptr f, slong flen, gr_srcptr g, slong glen, gr_ctx_t ctx)
{
    slong m, f1len, g1len, tlen, ulen, vlen, alloc;
    slong sz = ctx->sizeof_elem;
    gr_ptr t, u, v;
    gr_srcptr f0, f1, g0, g1;
    int squaring = (f == g) && (flen == glen);
    int status = GR_SUCCESS;

    /* Lengths must be positive: a zero length combined with a length >= 2
       would give m = 0 below and a write to res[2m - 1] = res[-1]. */
    FLINT_ASSERT(flen >= 1);
    FLINT_ASSERT(glen >= 1);
    FLINT_ASSERT(res != f);
    FLINT_ASSERT(res != g);

    /* TODO: should explicitly call basecase mul. */
    if (flen == 1 || glen == 1)
        return _gr_poly_mullow_generic(res, f, flen, g, glen, flen + glen - 1, ctx);

    /* Split at X = x^m */
    /* res = f0 g0 + (f0 g1 + f1 g0) X + f1 g1 X^2
           = f0 g0 + ((f0 + f1) (g0 + g1) - f0 g0 - f1 g1) X + f1 g1 X^2 */
    /* Here flen, glen >= 2, so 1 <= m and f1len, g1len >= 1. */
    m = (FLINT_MIN(flen, glen) + 1) / 2;

    f0 = f;
    g0 = g;
    f1 = GR_ENTRY(f, m, sz);
    g1 = GR_ENTRY(g, m, sz);
    f1len = flen - m;
    g1len = glen - m;

    /* Low part: res[0, ..., 2m-2] = f0 g0. */
    status |= _gr_poly_mul(res, f0, m, g0, m, ctx);
    /* res[2m-1] = 0 (could be avoided if we split the summation later,
       which is more annoying) */
    status |= gr_zero(GR_ENTRY(res, 2 * m - 1, sz), ctx);

    /* High part: res[2m, ..., flen+glen-2] = f1 g1. */
    status |= _gr_poly_mul(GR_ENTRY(res, 2 * m, sz), f1, f1len, g1, g1len, ctx);

    /* Temporary space for the middle part: t (tlen entries), then u (ulen
       entries, omitted when squaring), then v (vlen entries).

       Since tlen >= m, f1len and ulen >= m, g1len, vlen is at least both
       2m - 1 (length of f0 g0) and f1len + g1len - 1 (length of f1 g1),
       so the subtractions below never extend v beyond vlen entries. */
    tlen = FLINT_MAX(m, f1len);
    ulen = FLINT_MAX(m, g1len);
    vlen = tlen + ulen - 1;
    alloc = tlen + (squaring ? 0 : ulen) + vlen;

    GR_TMP_INIT_VEC(t, alloc, ctx);
    u = GR_ENTRY(t, tlen, sz);
    v = squaring ? u : GR_ENTRY(u, ulen, sz);

    /* t = f0 + f1 */
    status |= _gr_poly_add(t, f0, m, f1, f1len, ctx);

    if (squaring)
    {
        /* g0 + g1 equals f0 + f1, so reuse t for both factors. */
        status |= _gr_poly_mul(v, t, tlen, t, tlen, ctx);
    }
    else
    {
        /* u = g0 + g1 */
        status |= _gr_poly_add(u, g0, m, g1, g1len, ctx);
        /* v = (f0 + f1) (g0 + g1) */
        status |= _gr_poly_mul(v, t, tlen, u, ulen, ctx);
    }

    /* v -= f0 g0 */
    status |= _gr_poly_sub(v, v, vlen, res, 2 * m - 1, ctx);
    /* v -= f1 g1 */
    status |= _gr_poly_sub(v, v, vlen, GR_ENTRY(res, 2 * m, sz), f1len + g1len - 1, ctx);

    /* Finally add the middle part. The vlen entries starting at res[m]
       lie inside res[m, ..., flen+glen-2]. */
    status |= _gr_poly_add(GR_ENTRY(res, m, sz), GR_ENTRY(res, m, sz), vlen, v, vlen, ctx);

    GR_TMP_CLEAR_VEC(t, alloc, ctx);

    return status;
}

/*
    Sets res to poly1 * poly2 using _gr_poly_mul_karatsuba.

    If either input has length zero, res is set to zero. res may alias
    poly1 and/or poly2; in that case the product is built in a temporary
    polynomial and swapped into res.

    Returns the status reported by the underscore method (or by
    gr_poly_zero for an empty input).
*/
int
gr_poly_mul_karatsuba(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx)
{
    const slong len1 = poly1->length;
    const slong len2 = poly2->length;
    slong rlen;
    int status;

    if (len1 == 0 || len2 == 0)
        return gr_poly_zero(res, ctx);

    rlen = len1 + len2 - 1;

    if (res != poly1 && res != poly2)
    {
        /* No aliasing: write straight into res. */
        gr_poly_fit_length(res, rlen, ctx);
        status = _gr_poly_mul_karatsuba(res->coeffs, poly1->coeffs, len1, poly2->coeffs, len2, ctx);
    }
    else
    {
        /* The underscore method does not support aliasing, so compute
           into a temporary and swap it into res afterwards. */
        gr_poly_t tmp;

        gr_poly_init2(tmp, rlen, ctx);
        status = _gr_poly_mul_karatsuba(tmp->coeffs, poly1->coeffs, len1, poly2->coeffs, len2, ctx);
        gr_poly_swap(res, tmp, ctx);
        gr_poly_clear(tmp, ctx);
    }

    _gr_poly_set_length(res, rlen, ctx);
    _gr_poly_normalise(res, ctx);

    return status;
}
1 change: 1 addition & 0 deletions src/gr_poly/mullow.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ _gr_poly_mullow_generic(gr_ptr res,
if (n == 1)
return gr_mul(res, poly1, poly2, ctx);

/* todo: noncommutative rings */
if (len1 == 1)
return _gr_vec_mul_scalar(res, poly2, n, poly1, ctx);

Expand Down
2 changes: 2 additions & 0 deletions src/gr_poly/test/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include "t-inv_series.c"
#include "t-log_series.c"
#include "t-make_monic.c"
#include "t-mul_karatsuba.c"
#include "t-nth_derivative.c"
#include "t-pow_series_fmpq.c"
#include "t-pow_series_ui.c"
Expand Down Expand Up @@ -105,6 +106,7 @@ test_struct tests[] =
TEST_FUNCTION(gr_poly_inv_series),
TEST_FUNCTION(gr_poly_log_series),
TEST_FUNCTION(gr_poly_make_monic),
TEST_FUNCTION(gr_poly_mul_karatsuba),
TEST_FUNCTION(gr_poly_nth_derivative),
TEST_FUNCTION(gr_poly_pow_series_fmpq),
TEST_FUNCTION(gr_poly_pow_series_ui),
Expand Down
104 changes: 104 additions & 0 deletions src/gr_poly/test/t-mul_karatsuba.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
Copyright (C) 2023 Fredrik Johansson
This file is part of FLINT.
FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 3 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "test_helpers.h"
#include "ulong_extras.h"
#include "gr_poly.h"

FLINT_DLL extern gr_static_method_table _ca_methods;

/*
    One randomized trial: multiplies random polynomials over a random
    context with gr_poly_mul_karatsuba and compares against gr_poly_mullow.

    which selects the aliasing pattern:
        0: C = A * B          (no aliasing)
        1: C = C * B          (res aliases first operand)
        2: C = A * C          (res aliases second operand)
        3: C = A * A          (squaring, B set to A for the reference)
        4: C = C * C          (squaring with res aliasing both operands)
    Any other value aborts.

    Aborts with diagnostics if every operation succeeded and the results
    compare unequal. Returns the accumulated status.
*/
int
test_mul(flint_rand_t state, int which)
{
    gr_ctx_t ctx;
    gr_poly_t A, B, C, D;
    slong n;
    int status = GR_SUCCESS;

    gr_ctx_init_random(ctx, state);

    gr_poly_init(A, ctx);
    gr_poly_init(B, ctx);
    gr_poly_init(C, ctx);
    gr_poly_init(D, ctx);

    /* Use a smaller length bound when the context uses the _ca_methods table. */
    n = (ctx->methods == _ca_methods) ? 4 : 10;

    GR_MUST_SUCCEED(gr_poly_randtest(A, state, 1 + n_randint(state, n), ctx));
    GR_MUST_SUCCEED(gr_poly_randtest(B, state, 1 + n_randint(state, n), ctx));
    GR_MUST_SUCCEED(gr_poly_randtest(C, state, 1 + n_randint(state, n), ctx));

    if (which == 0)
    {
        status |= gr_poly_mul_karatsuba(C, A, B, ctx);
    }
    else if (which == 1)
    {
        status |= gr_poly_set(C, A, ctx);
        status |= gr_poly_mul_karatsuba(C, C, B, ctx);
    }
    else if (which == 2)
    {
        status |= gr_poly_set(C, B, ctx);
        status |= gr_poly_mul_karatsuba(C, A, C, ctx);
    }
    else if (which == 3)
    {
        status |= gr_poly_set(B, A, ctx);
        status |= gr_poly_mul_karatsuba(C, A, A, ctx);
    }
    else if (which == 4)
    {
        status |= gr_poly_set(B, A, ctx);
        status |= gr_poly_set(C, A, ctx);
        status |= gr_poly_mul_karatsuba(C, C, C, ctx);
    }
    else
    {
        flint_abort();
    }

    /* Reference product. todo: should explicitly call basecase mul */
    status |= gr_poly_mullow(D, A, B, FLINT_MAX(0, A->length + B->length - 1), ctx);

    /* Only compare when every step reported success. */
    if (status == GR_SUCCESS)
    {
        if (gr_poly_equal(C, D, ctx) == T_FALSE)
        {
            flint_printf("FAIL\n\n");
            flint_printf("which = %d, n = %wd\n\n", which, n);
            gr_ctx_println(ctx);

            flint_printf("A = ");
            gr_poly_print(A, ctx);
            flint_printf("\n\n");

            flint_printf("B = ");
            gr_poly_print(B, ctx);
            flint_printf("\n\n");

            flint_printf("C = ");
            gr_poly_print(C, ctx);
            flint_printf("\n\n");

            flint_printf("D = ");
            gr_poly_print(D, ctx);
            flint_printf("\n\n");

            flint_abort();
        }
    }

    gr_poly_clear(A, ctx);
    gr_poly_clear(B, ctx);
    gr_poly_clear(C, ctx);
    gr_poly_clear(D, ctx);

    gr_ctx_clear(ctx);

    return status;
}

/* Test driver: runs 1000 trials of test_mul, each with a case index drawn
   from n_randint(state, 5). The status returned by test_mul is not
   inspected here. */
TEST_FUNCTION_START(gr_poly_mul_karatsuba, state)
{
    slong remaining = 1000;

    while (remaining > 0)
    {
        test_mul(state, n_randint(state, 5));
        remaining--;
    }

    TEST_FUNCTION_END(state);
}

0 comments on commit d9b387b

Please sign in to comment.