How can I override comparisons of NumPy ndarray and my type?

In NumPy, you can use the __array_priority__ attribute to control binary operators acting on ndarray and a user-defined type. For instance:

    import numpy as np

    class Foo(object):
        def __radd__(self, lhs):
            return 0
        __array_priority__ = 100

    a = np.random.random((100,100))
    b = Foo()
    a + b  # calls b.__radd__(a) -> 0

The same, however, does not work for comparison operators. For example, if I add the following method to Foo, it is never called from the expression a < b:

    def __rlt__(self, lhs):
        return 0

I understand that __rlt__ is not really a Python special name, but I thought this might work. I tried all of __lt__, __le__, __eq__, __ne__, __ge__ and __gt__, with and without the r prefix, plus __cmp__ too, but I could never get NumPy to call any of them.
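As a side note on why the r-prefixed names have no effect: Python's comparison operators have no reflected variants of their own. When the left operand's __lt__ returns NotImplemented, Python tries the right operand's __gt__ instead. The sketch below uses plain Python and no NumPy; the class names Left and Right are hypothetical and exist only for illustration.

```python
class Left(object):
    """Left operand that declines every '<' comparison."""

    def __lt__(self, other):
        return NotImplemented


class Right(object):
    """Right operand that reports which method Python dispatched to."""

    def __gt__(self, other):
        # The reflection of '<' is '>', so this is what Python falls back to.
        return "Right.__gt__"

    def __rlt__(self, other):
        # Never called: __rlt__ is not a special method name in Python.
        return "Right.__rlt__"


result = Left() < Right()
print(result)
```

So for a < b to reach a user-defined type, the left operand has to return NotImplemented, and the method to define on the right operand is __gt__, not __rlt__.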

Is it possible to override these comparisons?

UPDATE

To avoid confusion, here is a more detailed description of the behavior of NumPy. First, this is what the Guide to NumPy book says:

 If the ufunc has 2 inputs and 1 output and the second input is an Object array then a special-case check is performed so that NotImplemented is returned if the second input is not an ndarray, has the array priority attribute, and has an r<op> special method. 

I think this is the rule that makes + work. Here is an example:

    import numpy as np

    a = np.random.random((2,2))

    class Bar0(object):
        def __add__(self, rhs): return 0
        def __radd__(self, rhs): return 1

    b = Bar0()
    print a + b  # Calls __radd__ four times, returns an array
    # [[1 1]
    #  [1 1]]

    class Bar1(object):
        def __add__(self, rhs): return 0
        def __radd__(self, rhs): return 1
        __array_priority__ = 100

    b = Bar1()
    print a + b  # Calls __radd__ once, returns 1
    # 1

As you can see, without __array_priority__, NumPy interprets the user-defined object as a scalar type and applies the operation at each position in the array. This is not what I want. My type is like an array (but should not be derived from ndarray).
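The two behaviors can be imitated in plain Python to make the dispatch easier to follow. This is only a toy model under stated assumptions: ArrayStandIn is a hypothetical stand-in for ndarray, and its __add__ is a simplified reading of the quoted rule, not NumPy's actual implementation.

```python
class ArrayStandIn(object):
    """Hypothetical stand-in for ndarray; not NumPy's implementation."""

    def __init__(self, data):
        self.data = list(data)

    def __add__(self, other):
        # Toy version of the quoted rule: defer when the other operand
        # advertises an array priority and has a reflected method.
        if hasattr(other, "__array_priority__") and hasattr(other, "__radd__"):
            return NotImplemented
        # Otherwise treat the operand as a scalar and add element by element.
        return ArrayStandIn(x + other for x in self.data)


class Bar0(object):
    def __radd__(self, lhs):
        return 1


class Bar1(object):
    __array_priority__ = 100

    def __radd__(self, lhs):
        return 1


a = ArrayStandIn([0.1, 0.2, 0.3, 0.4])
elementwise = (a + Bar0()).data  # Bar0.__radd__ runs once per element
deferred = a + Bar1()            # Bar1.__radd__ runs once, on the whole object
print(elementwise)
print(deferred)
```

In this model the element-wise result comes from each float declining the addition and Python falling back to Bar0.__radd__, while the deferred result comes from the container itself returning NotImplemented.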

Here is a longer example showing how this happens when all the comparison methods are defined:

    class Foo(object):
        def __cmp__(self, rhs): return 0
        def __lt__(self, rhs): return 1
        def __le__(self, rhs): return 2
        def __eq__(self, rhs): return 3
        def __ne__(self, rhs): return 4
        def __gt__(self, rhs): return 5
        def __ge__(self, rhs): return 6
        __array_priority__ = 100

    b = Foo()
    print a < b  # Calls __cmp__ four times, returns an array
    # [[False False]
    #  [False False]]
2 answers

Looks like I can answer it myself. np.set_numeric_ops can be used as follows:

    import numpy as np

    class Foo(object):
        def __lt__(self, rhs): return 0
        def __le__(self, rhs): return 1
        def __eq__(self, rhs): return 2
        def __ne__(self, rhs): return 3
        def __gt__(self, rhs): return 4
        def __ge__(self, rhs): return 5
        __array_priority__ = 100

    def override(name):
        def ufunc(x, y):
            # Decline the comparison for Foo so that Python falls back
            # to the reflected method on Foo.
            if isinstance(y, Foo):
                return NotImplemented
            # Otherwise delegate to the original NumPy ufunc of that name.
            return getattr(np, name)(x, y)
        return ufunc

    np.set_numeric_ops(
        ** {
            ufunc : override(ufunc) for ufunc in (
                "less", "less_equal", "equal", "not_equal", "greater_equal", "greater"
            )
        }
    )

    a = np.random.random((2,2))
    b = Foo()
    print a < b
    # 4

I cannot reproduce your problem. The correct approach is to use the special __cmp__ method. If I write

    import numpy as np

    class Foo(object):
        def __radd__(self, lhs): return 0
        def __cmp__(self, this): return -1
        __array_prioriy__ = 100

    a = np.random.random((100,100))
    b = Foo()
    print a<b

and set a breakpoint in the debugger, execution stops at return -1.

Btw: __array_prioriy__ doesn't matter here: you have a typo!
