Ask Your Question

Revision history [back]

click to hide/show revision 1
initial version

What about

sage: def softmax_probs(L):
....:     S = sum(map(exp, L))
....:     return map(lambda u:exp(u)/S, L)
....: 
sage: L=var("a, b, c")
sage: list(softmax_probs(L))
[e^a/(e^a + e^b + e^c), e^b/(e^a + e^b + e^c), e^c/(e^a + e^b + e^c)]

?

What about

sage: def softmax_probs(L):
....:  S = sum(map(exp, R=list(map(exp, L))
....:     S=sum(R)
    return map(lambda u:exp(u)/S, L)
....: 
sage: L=var("a, u:u/S, R)
L=list(var("a, b, c")
sage: c"))
list(softmax_probs(L))
[e^a/(e^a + e^b + e^c), e^b/(e^a + e^b + e^c), e^c/(e^a + e^b + e^c)]

?

What about

def softmax_probs(L):
    R=list(map(exp, L))
    S=sum(R)
    return map(lambda u:u/S, R)
L=list(var("a, b, c"))
list(softmax_probs(L))
[e^a/(e^a + e^b + e^c), e^b/(e^a + e^b + e^c), e^c/(e^a + e^b + e^c)]

?

The point of the local variables is to avoid recomputing the same quantities more than once :

sage: %timeit list(map(lambda u:exp(u)/sum(map(exp, L)), L))
70.4 µs ± 942 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
sage: %timeit list(softmax_probs(L))
31.2 µs ± 699 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)