Integer arithmetic performance
Hi!
I think Python/Sage is really slow with integer arithmetic. I have an example below that can illustrate this.
My question is why this is the case, and whether there is some way to improve it.
(My guess is that this is due to memory allocation, i.e., for large integers the operations are not performed in place. But I tried the mutable integer type xmpz from the library gmpy2, and I don't see any improvement in performance... Also, I'm curious how integers are handled in Sage by default: does it use the GMP or FLINT library?)
Here is the example.
Given a list of integers, I want to compute the sum of the products of all its k-combinations. The naive solution is sum(prod(c) for c in Combinations(vec, k)). But its performance is not very good. Here is a version that enumerates the combinations manually.
    def sum_c(k, w):
        def dfs(k, n):
            if k < 1:
                ans[0] += pp[0]
            else:
                for m in range(k, n+1):
                    pp[k-1] = pp[k] * w[m-1]
                    dfs(k-1, m-1)
        ans, pp = [0], [0]*k + [1]
        dfs(k, len(w))
        return ans[0]

    from time import time
    t = time()
    sum_c(8, list(range(1, 31)))
    print("sum_c:\t", time() - t)

    from sage.all import *
    t = time()
    sum(prod(c) for c in Combinations(range(1, 31), 8))
    print("naive:\t", time() - t)
With sum_c the computation is about 2.5 times faster than the naive version:
    sum_c:   2.283874988555908
    naive:   5.798710346221924
But with C this can be computed in less than 0.1s...
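A side note on the computation itself: the sum of the products of all k-combinations of a list is the elementary symmetric polynomial e_k of its entries, i.e. the coefficient of x^k in the product of (1 + w_i x). That suggests a way to get the number without enumerating combinations at all. The following is only a minimal sketch in plain Python; the name esp is made up for illustration, and it is a different algorithm from sum_c, not a faster implementation of the same enumeration.

```python
from itertools import combinations
from math import prod


def esp(k, w):
    """Sum of the products of all k-combinations of w (hypothetical helper)."""
    # e[j] holds e_j of the elements processed so far; e_0 is 1 by convention.
    e = [1] + [0] * k
    for x in w:
        # Update from high j to low j so that e[j - 1] does not yet include x.
        for j in range(k, 0, -1):
            e[j] += e[j - 1] * x
    return e[k]


# Cross-check against direct enumeration on a small input.
small = [1, 2, 3, 4, 5]
assert esp(2, small) == sum(prod(c) for c in combinations(small, 2))

print(esp(8, list(range(1, 31))))
```

This does k multiplications per list element instead of one per combination, so it mostly sidesteps the interpreter overhead rather than making each integer operation faster.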
Edit. OK, I tried Cython and gmpy2, and I'm able to get a satisfactory result for my problem.
Still, I wonder whether there is any way to improve the performance without doing something so "manual"...
And here is my "translation" for reference.
    # distutils: libraries = gmp
    from gmpy2 cimport *
    import numpy as np

    import_gmpy2()

    cdef extern from "gmp.h":
        void mpz_init(mpz_t)
        void mpz_init_set_si(mpz_t, long)
        void mpz_set_si(mpz_t, long)
        void mpz_add(mpz_t, mpz_t, mpz_t)
        void mpz_mul_si(mpz_t, mpz_t, long)

    cdef dfs(long k, long n, list w, mpz ans, mpz[:] pp):
        cdef long m
        if k < 1:
            mpz_add(ans.z, ans.z, pp[0].z)
        else:
            for m in range(k, n+1):
                mpz_mul_si(pp[k-1].z, pp[k].z, w[m-1])
                dfs(k-1, m-1, w, ans, pp)

    def sum_c(long k, list w):
        cdef long i
        cdef mpz ans = GMPy_MPZ_New(NULL)
        cdef mpz[:] pp = np.array([GMPy_MPZ_New(NULL) for i in range(k+1)])
        mpz_init(ans.z)
        for i in range(k):
            mpz_init(pp[i].z)
        mpz_init_set_si(pp[k].z, 1)
        dfs(k, len(w), w, ans, pp)
        return int(ans)
I wonder what Cython would get you...
Thanks! I just tried out Cython and indeed I get performance comparable to C :) However, integers can overflow in C, so I still need to use a library for big numbers, which seems complicated...
Also, is there any way to systematically improve integer arithmetic in general, or does one have to switch to Cython manually, case by case?
If you want to find out where the time is being spent in the code, try %prun, as in for example %prun sum(prod(c) for c in Combinations(range(1, 31), 8)). See https://doc.sagemath.org/html/en/tuto...
Thanks! I agree that profiling is very useful. In this case, though, %prun sum(prod(c) for c in Combinations(range(1, 31), 8)) tells me that most of the time is spent iterating Combinations, and %prun sum_c(8, list(range(1, 31))) shows that it's the enumeration function dfs. Neither gives much information...
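One number the profile does expose is how many times dfs is called, which gives a feel for how much interpreter-level work the enumeration involves. Below is a rough sketch using the standard-library cProfile and pstats modules outside of IPython; the helper name count_dfs_calls and the smaller input are assumptions made here to keep the run short, and sum_c is copied from the post. For this recursion the number of dfs calls works out to C(n+1, k), which is about 7.9 million for k = 8 and n = 30, so per-call overhead in the interpreter is a plausible suspect.

```python
import cProfile
import pstats


def sum_c(k, w):
    # Same enumeration as in the post.
    def dfs(k, n):
        if k < 1:
            ans[0] += pp[0]
        else:
            for m in range(k, n + 1):
                pp[k - 1] = pp[k] * w[m - 1]
                dfs(k - 1, m - 1)
    ans, pp = [0], [0] * k + [1]
    dfs(k, len(w))
    return ans[0]


def count_dfs_calls(k, w):
    """Run sum_c under cProfile; return (result, total number of dfs calls)."""
    profiler = cProfile.Profile()
    profiler.enable()
    result = sum_c(k, w)
    profiler.disable()
    stats = pstats.Stats(profiler)
    # stats.stats maps (file, line, function name) to
    # (primitive calls, total calls, tottime, cumtime, callers).
    calls = sum(v[1] for key, v in stats.stats.items() if key[2] == "dfs")
    return result, calls


# Smaller input than in the post, chosen only so the sketch finishes quickly.
result, calls = count_dfs_calls(4, list(range(1, 16)))
print("result:", result, "dfs calls:", calls)
```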