Compare strings in numba-compiled function
For newer numba versions (0.41.0 and later)
Numba (since version 0.41.0) supports str in nopython mode, and the code as written in the question will "just work". However, for your example, comparing the strings is much slower than the operation itself, so if you want to use strings in numba functions, make sure the overhead is worth it.
import numba as nb

@nb.njit
def foo_string(a, t):
    if t == 'awesome':
        return a**2
    elif t == 'default':
        return a**3
    else:
        return a

@nb.njit
def foo_int(a, t):
    if t == 1:
        return a**2
    elif t == 0:
        return a**3
    else:
        return a

assert foo_string(100, 'default') == foo_int(100, 0)
%timeit foo_string(100, 'default')
# 2.82 µs ± 45.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit foo_int(100, 0)
# 213 ns ± 10.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In your case the code is more than 10 times slower using strings.
Since your function doesn't do much it could be better and faster to do the string comparison in Python instead of numba:
def foo_string2(a, t):
    # Translate the string to an integer code in plain Python,
    # then hand the numeric work to the compiled function.
    if t == 'awesome':
        sec = 1
    elif t == 'default':
        sec = 0
    else:
        sec = -1
    return foo_int(a, sec)

assert foo_string2(100, 'default') == foo_string(100, 'default')
%timeit foo_string2(100, 'default')
# 323 ns ± 10.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
This is still a bit slower than the pure integer version, but it's almost 9 times faster than using the string in the numba function (323 ns vs. 2.82 µs).
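If there are more than a couple of modes, the if/elif chain in the Python wrapper could be replaced by a dictionary lookup. This is only a sketch: the `MODE_CODES` table and the name `foo_string3` are made up for illustration, and the try/except around the import is just there so the snippet also runs where numba isn't installed.

```python
try:
    from numba import njit
except ImportError:
    # Fallback (assumption): run as plain Python if numba is unavailable.
    def njit(func):
        return func


@njit
def foo_int(a, t):
    if t == 1:
        return a**2
    elif t == 0:
        return a**3
    else:
        return a


# Hypothetical lookup table: mode name -> integer code understood by foo_int.
MODE_CODES = {'awesome': 1, 'default': 0}


def foo_string3(a, t):
    # Unknown names map to -1, which ends up in foo_int's final else branch.
    return foo_int(a, MODE_CODES.get(t, -1))
```

The string handling stays entirely in Python, so the compiled function only ever sees integers.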
If you do a lot of numerical work in the numba function, the string comparison overhead won't matter. But simply putting numba.njit on a function, especially one that doesn't do many array operations or much number crunching, won't automatically make it faster!
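To illustrate the first point, here is a rough sketch of a function where the string is compared once and the loop over the data does the real work. The function name `sum_of_powers` is made up for this example, passing a str into a jitted function assumes numba 0.41.0 or later, and the try/except is only there so the snippet runs without numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback (assumption): run as plain Python if numba is unavailable.
    def njit(func):
        return func


@njit
def sum_of_powers(values, t):
    # The string comparison happens once, before the loop.
    if t == 'awesome':
        power = 2
    elif t == 'default':
        power = 3
    else:
        power = 1
    total = 0.0
    # The loop is where the time goes for large inputs.
    for x in values:
        total += x ** power
    return total
```

For an array with many elements, the one-off comparison should be negligible next to the loop, but it's worth timing on your own data rather than taking that on faith.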
For older numba versions (before 0.41.0)
Numba doesn't support strings in nopython mode.
From the documentation:
2.6.2. Built-in types
2.6.2.1. int, bool [...]
2.6.2.2. float, complex [...]
2.6.2.3. tuple [...]
2.6.2.4. list [...]
2.6.2.5. set [...]
2.6.2.7. bytes, bytearray, memoryview
The bytearray type and, on Python 3, the bytes type support indexing, iteration and retrieving the len(). [...]
So strings aren't supported at all and bytes don't support equality checks.
However, you can pass in bytes and iterate over them. That makes it possible to write your own comparison function:
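import numba as nb

@nb.njit
def bytes_equal(a, b):
    # Sequences of different length can never be equal.
    if len(a) != len(b):
        return False
    # Compare element by element and stop at the first mismatch.
    for char1, char2 in zip(a, b):
        if char1 != char2:
            return False
    return True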
Unfortunately, the next problem is that numba cannot "lower" bytes, so you cannot hardcode the bytes in the function directly. But bytes are basically just integers, and the bytes_equal function works for all types that numba supports that have a length and can be iterated over. So you could simply store them as lists:
import numba as nb

@nb.njit
def foo(a, t):
    if bytes_equal(t, [97, 119, 101, 115, 111, 109, 101]):  # b'awesome'
        return a**2
    elif bytes_equal(t, [100, 101, 102, 97, 117, 108, 116]):  # b'default'
        return a**3
    else:
        return a
or as global arrays (thanks @chrisb - see comments):
import numba as nb
import numpy as np

AWESOME = np.frombuffer(b'awesome', dtype='uint8')
DEFAULT = np.frombuffer(b'default', dtype='uint8')

@nb.njit
def foo(a, t):
    if bytes_equal(t, AWESOME):
        return a**2
    elif bytes_equal(t, DEFAULT):
        return a**3
    else:
        return a
Both will work correctly:
>>> foo(10, b'default')
1000
>>> foo(10, b'awesome')
100
>>> foo(10, b'awe')
10
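For reference, the integer lists in the list-based version are simply the byte values of the ASCII strings. A quick way to generate or double-check them in plain Python (the variable names are just for illustration):

```python
# Iterating over a bytes object in Python 3 yields its integer byte values.
awesome_codes = list(b'awesome')
default_codes = list(b'default')

print(awesome_codes)  # [97, 119, 101, 115, 111, 109, 101]
print(default_codes)  # [100, 101, 102, 97, 117, 108, 116]
```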
One drawback is that you cannot specify a bytes object as the default, so you need to pass the t variable explicitly. It also feels hacky to do it that way.
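One possible way around the missing default is a thin Python wrapper that supplies the default and does the encoding before calling the compiled function. This is a sketch under a few assumptions: the wrapper name `foo_with_default` is made up, the mode names are plain ASCII, and the input is converted to a uint8 array so it matches the global arrays. The try/except is only there so the snippet runs without numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback (assumption): run as plain Python if numba is unavailable.
    def njit(func):
        return func


AWESOME = np.frombuffer(b'awesome', dtype=np.uint8)
DEFAULT = np.frombuffer(b'default', dtype=np.uint8)


@njit
def bytes_equal(a, b):
    if len(a) != len(b):
        return False
    for char1, char2 in zip(a, b):
        if char1 != char2:
            return False
    return True


@njit
def foo(a, t):
    if bytes_equal(t, AWESOME):
        return a**2
    elif bytes_equal(t, DEFAULT):
        return a**3
    else:
        return a


def foo_with_default(a, t='default'):
    # Hypothetical wrapper: the default and the str -> bytes conversion
    # live in plain Python, where neither is restricted.
    encoded = np.frombuffer(t.encode('ascii'), dtype=np.uint8)
    return foo(a, encoded)
```

It doesn't make the approach any less hacky, but callers no longer have to pass bytes themselves.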
My opinion: Just do the if t == ... checks in a normal function and call specialized numba functions inside the if branches. String comparisons are really fast in Python, so just wrap the math/array-intensive stuff in a numba function:
import numba as nb

@nb.njit
def awesome_func(a):
    return a**2

@nb.njit
def default_func(a):
    return a**3

@nb.njit
def other_func(a):
    return a

def foo(a, t='default'):
    # Plain Python dispatcher: the string comparison happens here.
    if t == 'awesome':
        return awesome_func(a)
    elif t == 'default':
        return default_func(a)
    else:
        return other_func(a)
But make sure you actually need numba for the functions. Sometimes normal Python/NumPy will be fast enough. Just profile the numba solution and a Python/NumPy solution and see if numba makes it significantly faster. :)
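If you're not in IPython and can't use %timeit, something like the following could serve as a rough timing helper. It's a sketch, not a rigorous benchmark: `best_time` and `cube_python` are names made up for this example, and the warm-up call is there so that a numba function's compilation time isn't counted.

```python
import timeit


def best_time(func, *args, repeat=5, number=10_000):
    """Return the best observed per-call time in seconds for func(*args)."""
    # Warm-up call; for a numba function this triggers compilation.
    func(*args)
    timings = timeit.repeat(lambda: func(*args), repeat=repeat, number=number)
    # The minimum is usually the least noisy estimate.
    return min(timings) / number


def cube_python(a):
    return a**3


print(f'{best_time(cube_python, 100) * 1e9:.0f} ns per call')
```

Pass the jitted version of your function to `best_time` in the same way and compare the two numbers with realistic inputs.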
I'd suggest accepting @MSeifert's answer, but as another option for these types of problems, consider using an enum.
In Python, strings are often used as a sort of enum, and numba has built-in support for enums, so they can be used directly.
import enum
import numba

class FooOptions(enum.Enum):
    AWESOME = 1
    DEFAULT = 2

@numba.njit
def foo(a, t=FooOptions.DEFAULT):
    if t == FooOptions.AWESOME:
        return a**2
    elif t == FooOptions.DEFAULT:
        return a**3  # cube, matching the 'default' branch in the question
    else:
        return a

foo(10, FooOptions.AWESOME)
Out[5]: 100
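If callers still want to pass strings, the name-to-enum conversion could happen in a small Python wrapper before the compiled function is called. This is only a sketch: the wrapper `foo_by_name` is hypothetical, DEFAULT is assumed to cube its input as in the string-based examples, and the try/except is there so the snippet runs without numba.

```python
import enum

try:
    from numba import njit
except ImportError:
    # Fallback (assumption): run as plain Python if numba is unavailable.
    def njit(func):
        return func


class FooOptions(enum.Enum):
    AWESOME = 1
    DEFAULT = 2


@njit
def foo(a, t):
    if t == FooOptions.AWESOME:
        return a**2
    elif t == FooOptions.DEFAULT:
        return a**3
    else:
        return a


def foo_by_name(a, name='default'):
    # Look the member up by name in plain Python; reject unknown names early.
    try:
        option = FooOptions[name.upper()]
    except KeyError:
        raise ValueError(f'unknown option: {name!r}') from None
    return foo(a, option)
```

Unlike the string version, a typo in the option name raises an error here instead of silently falling through to the else branch.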