Dilawar's Blog

watch -n 1 /dev/null

12 Nov 2020

argmax in pure Python

Python does not have a built-in argmax function; NumPy does (numpy.argmax).

Here is one implementation.

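import math

def argmax(ls: list) -> int:
    """Return the index of the first maximum element of ls (0 for an empty list)."""
    _m, _mi = -math.inf, 0   # running maximum and its index
    for i, v in enumerate(ls):
        if v > _m:           # strict '>' keeps the first index on ties
            _m = v
            _mi = i
    return _mi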
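Seeding the loop with -math.inf means the function above only works for numbers (comparing a str with a float raises TypeError), and it silently returns 0 for an empty list. A minimal sketch of a stricter variant, assuming you would rather mirror max() and raise ValueError on empty input; argmax_strict is a hypothetical name, not part of any library:

```python
def argmax_strict(ls: list) -> int:
    """Index of the first maximum; raises ValueError on an empty list."""
    if not ls:
        raise ValueError("argmax of an empty sequence")
    best_i, best_v = 0, ls[0]  # seed with the first element, so any comparable type works
    for i in range(1, len(ls)):
        if ls[i] > best_v:     # strict '>' keeps the first index on ties
            best_i, best_v = i, ls[i]
    return best_i

print(argmax_strict([1, 2, 3, 9, 5, 6, 3, 2]))   # 3
print(argmax_strict(["pear", "apple", "zebra"]))  # 2
```

Seeding with ls[0] avoids the need for a type-specific sentinel at the cost of one explicit emptiness check.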

There is also a one-liner I often use: max(zip(ls, range(len(ls))))[1], where ls is the input list. To my surprise, the one-liner is slower than the argmax defined above (Python 3.9 on openSUSE Tumbleweed, Intel).
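The two approaches also disagree on ties. Tuples compare element by element, so the zip one-liner breaks ties on the index and returns the last maximal position, while the loop above returns the first. The sketch below shows this, alongside another common idiom, max(range(len(ls)), key=ls.__getitem__), which relies on max() returning the first maximal item it encounters. The function names are illustrative only.

```python
def argmax_zip(ls: list) -> int:
    # The one-liner from the post: (value, index) tuples compare by value,
    # then by index, so on ties the LARGEST index wins.
    return max(zip(ls, range(len(ls))))[1]

def argmax_key(ls: list) -> int:
    # max() with a key returns the first maximal item, so on ties the
    # SMALLEST index wins (same behaviour as the loop-based argmax).
    return max(range(len(ls)), key=ls.__getitem__)

ties = [1, 9, 3, 9]
print(argmax_zip(ties))  # 3
print(argmax_key(ties))  # 1
```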

In [1]: a = [1,2,3,9,5,6,3,2]
In [2]: max(zip(a, range(len(a))))[1]
Out[2]: 3

In [4]: %timeit max(zip(a, range(len(a))))
1.34 µs ± 28.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [7]: %timeit argmax(a)
837 ns ± 8.65 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

The gap narrows on a larger list of 1000 random integers, but the loop is still faster:

In [11]: a = [random.randint(0, 111111) for x in range(1000)]

In [12]: %timeit argmax(a)
67.7 µs ± 446 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [13]: %timeit max(zip(a, range(len(a))))
86.1 µs ± 739 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
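To rerun the comparison outside IPython, here is a minimal sketch using the standard timeit module. It times the loop, the zip one-liner, and the key-based idiom on the same random list. The helper names one_liner, key_based and bench are hypothetical. Absolute numbers depend on the machine and Python version, so no output is shown.

```python
import math
import random
import timeit

def argmax(ls: list) -> int:
    # Loop-based argmax from the post.
    _m, _mi = -math.inf, 0
    for i, v in enumerate(ls):
        if v > _m:
            _m = v
            _mi = i
    return _mi

def one_liner(ls: list) -> int:
    # Tuple-comparison one-liner from the post.
    return max(zip(ls, range(len(ls))))[1]

def key_based(ls: list) -> int:
    # Alternative idiom: max over indices, keyed by the value at each index.
    return max(range(len(ls)), key=ls.__getitem__)

def bench(n: int = 1000, number: int = 1000) -> dict:
    """Return mean microseconds per call for each variant on one random list."""
    a = [random.randint(0, 111111) for _ in range(n)]
    results = {}
    for name, fn in [("argmax", argmax), ("one_liner", one_liner), ("key_based", key_based)]:
        total = timeit.timeit(lambda: fn(a), number=number)  # seconds for `number` calls
        results[name] = total / number * 1e6
    return results

if __name__ == "__main__":
    for name, us in bench().items():
        print(f"{name:>10}: {us:8.2f} µs per call")
```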
