Ask Your Question

Revision history [back]

click to hide/show revision 1
initial version

Decided to just write this myself. Found the expr2tree() function here.

x, x0, dt = var('x, x0, dt')

def expr2tree(expr): 
    if expr.operator() is None: 
        return expr 
    else: 
        return [expr.operator()]+map(expr2tree, expr.operands())

def tree2expr(tree):
    args = []
    for ea in tree[1:]:
        if type(ea) is list:
            args.append(tree2expr(ea))
        else:
            args.append(ea)
    return tree[0](*args)

def remove(tree, var, exp):
    if type(tree) != list:
        return False

    if not hasattr(tree[0], '__name__'):
        return False

    if tree[0].__name__ == 'pow':
        if tree[1] == var and tree[2] >= exp:
            return True

    for ea in tree[1:]:
        if remove(ea, var, exp):
            return True

    return False

def truncate_terms(expr, term):
    orig = expr
    expr = expr.expand()
    tree = expr2tree(expr)

    # check is addition
    if tree[0].__name__ != 'add_vararg':
        return expr

    # expand term (eg. dt^5)
    term_tree = expr2tree(term)
    if term_tree[0].__name__ != 'pow':
        return expr

    var, exp = term_tree[1:3]

    args = []
    for ea in tree[1:]:
        if not remove(ea, var, exp):
            args.append(ea)

    return tree2expr([tree[0]] + args)

Test:

e = x*dt + x*x0*9*dt^2 + x*100*dt^3
truncate_terms(e, dt^3)

Output:

9*dt^2*x*x0 + dt*x