
Issue porting a solve utility from SymPy to SageMath

asked 2024-04-30 10:39:17 +0100

snowch

I'm trying to port a utility from SymPy to SageMath.

Here is my SymPy code:

from sympy import symbols, Eq, solve, Matrix, pprint
x, y, z = symbols('x y z')

def sympy_solve(augmented_matrix):
    # Get the number of variables
    num_variables = augmented_matrix.shape[1] - 1

    # Generate symbols for variables
    variables = symbols('x:' + str(num_variables))

    # Extract coefficients and constants from the augmented matrix
    coefficients = augmented_matrix[:, :-1]
    constants = augmented_matrix[:, -1]

    # Create equations from the coefficients and constants
    equations = []
    for i in range(len(constants)):
        equation = Eq(sum(coefficients[i, j] * variables[j] for j in range(num_variables)), constants[i])
        equations.append(equation)

    # Solve the equations
    solution = solve(equations, variables, dict=True)
    return solution

When I run it, I get the result I expect:

A = Matrix([
    [1,1,1,-1,1],
    [0,1,-1,1,-1],
    [3,0,6,-6,6],
    [0,-1,1,-1,1]
])
print(sympy_solve(A))
# [{
#   x0: -2*x2 + 2*x3 + 2, 
#   x1: x2 - x3 - 1
# }]
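For reference, row-reducing the augmented matrix by hand shows where this answer comes from. Only the first two columns end up with pivots, so `x2` and `x3` are free and `x0`, `x1` are expressed in terms of them. The row operations below were chosen by hand; since the reduced row echelon form is unique, any valid sequence ends at the same matrix.

```latex
\left[\begin{array}{cccc|c}
 1 &  1 &  1 & -1 &  1\\
 0 &  1 & -1 &  1 & -1\\
 3 &  0 &  6 & -6 &  6\\
 0 & -1 &  1 & -1 &  1
\end{array}\right]
\;\xrightarrow{\;R_3 - 3R_1 + 3R_2,\;\; R_4 + R_2,\;\; R_1 - R_2\;}\;
\left[\begin{array}{cccc|c}
 1 & 0 &  2 & -2 &  2\\
 0 & 1 & -1 &  1 & -1\\
 0 & 0 &  0 &  0 &  0\\
 0 & 0 &  0 &  0 &  0
\end{array}\right]
\quad\Longrightarrow\quad
x_0 = 2 - 2x_2 + 2x_3,\qquad x_1 = -1 + x_2 - x_3
```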

My current version in SageMath looks like this:

def sagemath_solve(augmented_matrix):

    A = augmented_matrix[:, :-1]
    Y = augmented_matrix[:, -1]

    m, n = A.dimensions()
    p, q = Y.dimensions()

    if m!=p:
        raise RuntimeError("The matrices have different numbers of rows")
    X = vector([var("x_{}".format(i)) for i in [1..n]])

    sols = []
    for j in range(q):
        system = [A[i]*X==Y[i,j] for i in range(m)]
        sols += solve(system, *X)
    return sols

However, instead of expressing the solution in terms of the original variables, it introduces new parameters `r1` and `r2`:

A = matrix([
    [1,1,1,-1,1],
    [0,1,-1,1,-1],
    [3,0,6,-6,6],
    [0,-1,1,-1,1]
])
print(sagemath_solve(A))

# [[
#   x_1 == 2*r1 - 2*r2 + 2, 
#   x_2 == -r1 + r2 - 1, 
#   x_3 == r2, 
#   x_4 == r1
# ]]
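Sage's answer is not wrong, only parametrized differently: `r1` and `r2` are fresh parameters that `solve` introduces for the free unknowns, and here `x_4 == r1`, `x_3 == r2`. Substituting them back gives the SymPy result (Sage's 1-based `x_1..x_4` correspond to SymPy's `x0..x3`). Below is a minimal check in plain Python (standard library only, not Sage or SymPy; `sage_param_solution`, `sympy_solution` and `satisfies` are hypothetical helper names) that the parametric solution satisfies every row for a grid of exact parameter values and coincides with SymPy's form under that renaming.

```python
from fractions import Fraction
from itertools import product

# Augmented matrix from the question: the last column is the right-hand side.
AUG = [
    [1, 1, 1, -1, 1],
    [0, 1, -1, 1, -1],
    [3, 0, 6, -6, 6],
    [0, -1, 1, -1, 1],
]

def sage_param_solution(r1, r2):
    """Sage's answer: x_1..x_4 written in terms of the parameters r1, r2."""
    return [2*r1 - 2*r2 + 2, -r1 + r2 - 1, r2, r1]

def sympy_solution(x2, x3):
    """SymPy's answer: x0..x3 with x2 and x3 left free."""
    return [-2*x2 + 2*x3 + 2, x2 - x3 - 1, x2, x3]

def satisfies(aug, xs):
    """True if every row of the augmented matrix holds at the point xs."""
    return all(sum(a * x for a, x in zip(row[:-1], xs)) == row[-1] for row in aug)

values = [Fraction(k, 3) for k in range(-3, 4)]
for r1, r2 in product(values, repeat=2):
    point = sage_param_solution(r1, r2)
    assert satisfies(AUG, point)
    # Renaming r1 -> x3 and r2 -> x2 (0-based names) turns one answer into the other.
    assert point == sympy_solution(x2=r2, x3=r1)
print("all checks passed")
```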

Any thoughts on how to get the solution in terms of the original free variables, the way SymPy returns it?


1 Answer


answered 2024-04-30 10:57:57 +0100

snowch

updated 2024-04-30 11:58:41 +0100

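The fix for me was to pass only the pivot variables to `solve`, leaving the free (non-pivot) variables out of the list of unknowns. `solve` then keeps the free variables as they are on the right-hand side instead of replacing them with new parameters.

(I also changed the variables to start at `x_0`, matching the SymPy naming.)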

def sagemath_solve(augmented_matrix):

    A = augmented_matrix[:, :-1]
    Y = augmented_matrix[:, -1]

    m, n = A.dimensions()
    p, q = Y.dimensions()

    if m!=p:
        raise RuntimeError("The matrices have different numbers of rows")
    X = vector([var("x_{}".format(i)) for i in [0..n-1]])

    # solve only for the pivot variables; the free (non-pivot)
    # variables stay on the right-hand side as parameters
    X_pivots = [X[i] for i in A.pivots()]

    sols = []
    for j in range(q):
        system = [A[i]*X==Y[i,j] for i in range(m)]
        sols += solve(system, *X_pivots)
    return sols


A = matrix([
    [1,1,1,-1,1],
    [0,1,-1,1,-1],
    [3,0,6,-6,6],
    [0,-1,1,-1,1]
])
print(sagemath_solve(A))

# [[
#   x_0 == -2*x_2 + 2*x_3 + 2, 
#   x_1 == x_2 - x_3 - 1
# ]]
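The same idea can be sketched outside Sage: row-reduce the augmented matrix, treat the pivot columns (for a consistent system these are the columns `A.pivots()` reports for the coefficient matrix) as the unknowns to solve for, and leave every non-pivot column as a free variable on the right-hand side. This is a minimal standard-library sketch using exact `Fraction` arithmetic; `rref`, `solve_pivots` and `format_solution` are hypothetical helper names, and it raises `ValueError` if the system is inconsistent rather than returning an empty list as `solve` does.

```python
from fractions import Fraction

# Augmented matrix from the question: the last column holds the constants.
AUG = [
    [1, 1, 1, -1, 1],
    [0, 1, -1, 1, -1],
    [3, 0, 6, -6, 6],
    [0, -1, 1, -1, 1],
]

def rref(rows):
    """Return (reduced row echelon form, pivot column indices) using exact Fractions."""
    m = [[Fraction(v) for v in row] for row in rows]
    n_rows, n_cols = len(m), len(m[0])
    pivots, r = [], 0
    for c in range(n_cols):
        # Find a row at or below r with a nonzero entry in column c.
        pivot_row = next((i for i in range(r, n_rows) if m[i][c] != 0), None)
        if pivot_row is None:
            continue  # no pivot here: column c has no leading entry
        m[r], m[pivot_row] = m[pivot_row], m[r]
        lead = m[r][c]
        m[r] = [v / lead for v in m[r]]
        # Clear column c in every other row.
        for i in range(n_rows):
            if i != r and m[i][c] != 0:
                factor = m[i][c]
                m[i] = [a - factor * b for a, b in zip(m[i], m[r])]
        pivots.append(c)
        r += 1
        if r == n_rows:
            break
    return m, pivots

def solve_pivots(augmented):
    """Map each pivot variable index to (constant, {free index: coefficient})."""
    n = len(augmented[0]) - 1  # number of unknowns
    reduced, pivots = rref(augmented)
    if n in pivots:  # some row reduced to 0 == c with c != 0
        raise ValueError("inconsistent system")
    free = [j for j in range(n) if j not in pivots]
    solution = {}
    for row, p in zip(reduced, pivots):
        # This row reads x_p + sum(row[j] * x_j for free j) == row[n].
        coeffs = {j: -row[j] for j in free if row[j] != 0}
        solution[p] = (row[n], coeffs)
    return solution

def format_solution(solution):
    """Render the solution in the same style as Sage's printed output."""
    lines = []
    for p, (const, coeffs) in solution.items():
        parts = []
        for j, c in coeffs.items():
            if c == 1:
                parts.append(f"x_{j}")
            elif c == -1:
                parts.append(f"-x_{j}")
            else:
                parts.append(f"{c}*x_{j}")
        if const != 0 or not parts:
            parts.append(str(const))
        lines.append(f"x_{p} == " + " + ".join(parts).replace("+ -", "- "))
    return lines

print("\n".join(format_solution(solve_pivots(AUG))))
```

For the matrix above this prints `x_0 == -2*x_2 + 2*x_3 + 2` and `x_1 == x_2 - x_3 - 1`, the same relations as the Sage and SymPy output. Working with exact fractions instead of floats keeps the coefficients rational, as the symbolic solvers do.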
