Integer arithmetic performance

asked 2021-08-28 23:35:53 +0200 by 8d1h

updated 2021-08-29 13:24:29 +0200

Hi!

I think Python/Sage is really slow at integer arithmetic. The example below illustrates this.

My question is: why is this the case, and is there some way to improve it?

(My guess is that this is probably due to memory allocation, i.e., for large integers the operations are not performed in place. But I tried the mutable integer type xmpz from the gmpy2 library, and I don't see any improvement in performance; a sketch of that attempt is included after the timings below. Also, I'm curious how integers in Sage are handled by default: does it use the GMP or the FLINT library?)


Here is the example.

Given a list of integers, I want to compute the sum of the products of all its k-combinations. The naive solution is sum(prod(c) for c in Combinations(vec, k)), but its performance is not very good. Here is a version that enumerates the combinations manually (a quick check on a small input follows the code).

def sum_c(k, w):
    # Sum of the products of all k-combinations of w, by depth-first search.
    # pp[j] holds the product of the k-j factors chosen so far (pp[k] = 1 is
    # the empty product), and ans[0] accumulates the total.
    def dfs(k, n):
        if k < 1:
            ans[0] += pp[0]            # a full combination: add its product
        else:
            for m in range(k, n+1):    # pick w[m-1] as the next factor, indices decreasing
                pp[k-1] = pp[k] * w[m-1]
                dfs(k-1, m-1)
    ans, pp = [0], [0]*k + [1]
    dfs(k, len(w))
    return ans[0]
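
A quick check on a small input: for [1, 2, 3] with k = 2 the answer should be 1*2 + 1*3 + 2*3 = 11, and sum_c agrees with a plain itertools version:

from itertools import combinations
from math import prod

# small sanity check: sum of products of all 2-subsets of [1, 2, 3]
assert sum_c(2, [1, 2, 3]) == sum(prod(c) for c in combinations([1, 2, 3], 2)) == 11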

from time import time
t = time()
sum_c(8, list(range(1, 31)))
print("sum_c:\t", time() - t)

from sage.all import *
t = time()
sum(prod(c) for c in Combinations(range(1, 31), 8))
print("naive:\t", time() - t)

With sum_c the running time is cut by more than half:

sum_c:   2.283874988555908
naive:   5.798710346221924

But with C this can be computed in less than 0.1s...
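
(For reference, the gmpy2.xmpz attempt mentioned above was essentially the same recursion with a mutable accumulator; a sketch is below. The products are still built with a non-in-place multiplication, so a new object gets allocated on every step anyway, which is probably why it doesn't help.)

# sketch of the xmpz variant: only the accumulator is updated in place
from gmpy2 import xmpz

def sum_c_xmpz(k, w):
    def dfs(k, n):
        if k < 1:
            ans[0] += pp[0]               # in-place add on the mutable xmpz
        else:
            for m in range(k, n + 1):
                pp[k-1] = pp[k] * w[m-1]  # not in place: allocates a new integer
                dfs(k-1, m-1)
    ans, pp = [xmpz(0)], [xmpz(0) for _ in range(k)] + [xmpz(1)]
    dfs(k, len(w))
    return int(ans[0])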


Edit. OK, I tried Cython and gmpy2, and I'm able to get a satisfactory result for my problem.

Still, I wonder whether there is any way to improve the performance without doing something so "manual"...

And here is my "translation" for reference.

# distutils: libraries = gmp
from gmpy2 cimport *
import numpy as np

import_gmpy2()

cdef extern from "gmp.h":
    void mpz_init(mpz_t)
    void mpz_init_set_si(mpz_t, long)
    void mpz_set_si(mpz_t, long)
    void mpz_add(mpz_t, mpz_t, mpz_t)
    void mpz_mul_si(mpz_t, mpz_t, long)

# Same recursion as the Python sum_c, but using the raw GMP functions on
# preallocated gmpy2 mpz objects (x.z is the underlying mpz_t of a gmpy2 mpz x).
cdef dfs(long k, long n, list w, mpz ans, mpz[:] pp):
    cdef long m
    if k < 1:
        mpz_add(ans.z, ans.z, pp[0].z)
    else:
        for m in range(k, n+1):
            mpz_mul_si(pp[k-1].z, pp[k].z, w[m-1])
            dfs(k-1, m-1, w, ans, pp)

def sum_c(long k, list w):
    cdef long i
    cdef mpz ans = GMPy_MPZ_New(NULL)   # accumulator
    cdef mpz[:] pp = np.array([GMPy_MPZ_New(NULL) for i in range(k+1)])   # partial products
    mpz_init(ans.z)
    for i in range(k):
        mpz_init(pp[i].z)
    mpz_init_set_si(pp[k].z, 1)         # pp[k] = 1, the empty product
    dfs(k, len(w), w, ans, pp)
    return int(ans)
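
(In case anyone wants to reproduce this: a minimal build script along the following lines should work. The file name sum_c.pyx and the include paths are my assumptions; they rely on gmpy2 shipping gmpy2.pxd and gmpy2.h inside its package directory, which recent versions do as far as I know.)

# setup.py -- a minimal build sketch (the module/file name sum_c is arbitrary).
# Assumes gmpy2.pxd and gmpy2.h live in the gmpy2 package directory.
import os
import gmpy2
from setuptools import setup, Extension
from Cython.Build import cythonize

gmpy2_dir = os.path.dirname(gmpy2.__file__)

ext = Extension(
    "sum_c",
    sources=["sum_c.pyx"],
    include_dirs=[gmpy2_dir],   # so the C compiler finds gmpy2.h
)
# the "# distutils: libraries = gmp" line in the .pyx takes care of linking GMP
setup(ext_modules=cythonize([ext], include_path=[gmpy2_dir]))  # include_path: so Cython finds gmpy2.pxd

Build with python setup.py build_ext --inplace, then call it as from sum_c import sum_c; sum_c(8, list(range(1, 31))).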

Comments

I wonder what Cython would get you...

Emmanuel Charpentier ( 2021-08-29 08:41:03 +0200 )

Thanks! I just tried out Cython and indeed I get performance comparable to C :) However, integers can overflow in C, so I still need some library for big numbers, which seems complicated...

Also, is there any way to systematically improve integer arithmetic in general, or does one have to manually switch to Cython case by case?

8d1h ( 2021-08-29 11:39:22 +0200 )

If you want to find out where the time is being spent in the code, try %prun, as in for example %prun sum(prod(c) for c in Combinations(range(1, 31), 8)). See https://doc.sagemath.org/html/en/tuto....

John Palmieri ( 2021-08-29 21:04:17 +0200 )

Thanks! I agree that profiling is very useful. Although in this case %prun sum(prod(c) for c in Combinations(range(1, 31), 8)) tells me that most of the time is spent iterating over Combinations, and %prun sum_c(8, list(range(1, 31))) shows that it's the enumeration function dfs. Neither gives much information...

8d1h ( 2021-08-29 21:55:51 +0200 )