This repository has been archived by the owner on Oct 15, 2019. It is now read-only.

Data dependent branches not working #140

Open
chenyang-tao opened this issue Jan 31, 2017 · 4 comments


@chenyang-tao

chenyang-tao commented Jan 31, 2017

Hi guys, here are some issues with data-dependent branching that I ran into lately.

  1. This exact replica of the data dependent branching code does not work as expected.
import minpy.numpy as np

num = 100
X = np.linspace(-1.0,1.0,num=num)
Y = np.zeros([num,])

if X<Y:
    Z = X + Y
else:
    Z = Y ** 2

# Only Z = X + Y is executed
  2. Things get hairier if I want to take gradients with branching.
import minpy
from minpy.core import grad

def foo(X):

    # This will raise an AttributeError
    if X<=0:
        Y = X**2
    else:
        Y = 0

    # And this will work
    if -X>=0:
        Y = X**2
    else:
        Y = 0
    
    return Y

foo_grad = grad(foo)

print(foo_grad(-1.0))
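For reference, the value the gradient call above should produce can be checked without minpy. The following is only a sketch in plain Python: `foo_reference` and `numeric_grad` are hypothetical helpers written for this illustration (they are not part of minpy), and a central finite difference stands in for automatic differentiation.

```python
def foo_reference(x):
    # Same branch logic as foo above, written for plain Python floats.
    if x <= 0:
        return x ** 2
    return 0.0


def numeric_grad(f, x, eps=1e-6):
    # Central finite difference; an approximation, not autodiff.
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


expected = numeric_grad(foo_reference, -1.0)
print(expected)  # approximately -2.0, since d/dx x**2 = 2x on the x <= 0 branch
```

Both branches in `foo` test the same condition, so either form should give a gradient of about -2.0 at -1.0.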
@jermainewang
Member

Hi,

The first example is ambiguous. If you replace the namespace with numpy, you will get:

Traceback (most recent call last):
  File "t1.py", line 8, in <module>
    if X < Y:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

This is because both X and Y are arrays, so X < Y returns a boolean array of True and False values. So what do you mean by the condition X < Y? Is it np.all(X < Y) or np.any(X < Y)? The reason only X + Y is executed is that minpy uses the first element as the condition value (which is true). So I think the behavior is correct, except that we should also raise the same error as numpy.
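To make the three possible readings of the condition concrete, here is a small sketch using plain NumPy (not minpy) on the same X and Y as in the report. The variable names `mask`, `all_less`, `any_less` and `first_less` are made up for this illustration.

```python
import numpy as np

num = 100
X = np.linspace(-1.0, 1.0, num=num)
Y = np.zeros([num])

mask = X < Y  # elementwise comparison: a boolean array of length 100

# Plain NumPy refuses to use a multi-element array as an `if` condition.
try:
    if mask:
        pass
    ambiguous = False
except ValueError:
    ambiguous = True

all_less = bool(np.all(mask))  # every element of X is below Y?
any_less = bool(np.any(mask))  # at least one element of X is below Y?
first_less = bool(mask[0])     # only the first element, as described above for minpy

print(ambiguous, all_less, any_less, first_less)
```

With this input the first half of X is negative and the second half is positive, so the three readings disagree: `np.all` is false, while `np.any` and the first-element check are both true.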

The second one is the same as #139, where MXNet's operators have poor support for scalar arguments. We will fix this in MXNet.

Thanks for the report!

@chenyang-tao
Author

Thanks for the reply. I am now even more confused about the first example. The following code is given in README.md as an example demonstrating features of minpy, highlighting "... you freely use the if statement anyway you like."

import minpy.numpy as np

x = ... # create x array
y = ... # create y array

if x < y:
    z = x + y
else:
    z = y ** 2

So if only the first element of the returned bool array is used, then this example is misleading.

Looking forward to your fix for the scalar arguments issue; it causes a lot of trouble.

@jermainewang
Member

I see. The example is meant for scalars only and is not meaningful for arrays. To make it appropriate for arrays, you should write:

import minpy.numpy as np

x = ...
y = ...

if x[0] < y[0]:
  z = x + y
else:
  z = y ** 2
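If the intent was instead to pick a branch per element rather than once for the whole array, an elementwise select is one alternative. The sketch below uses plain NumPy's `np.where`; whether minpy supports `where` and its gradient is not confirmed anywhere in this thread, so treat that part as an assumption.

```python
import numpy as np

num = 100
x = np.linspace(-1.0, 1.0, num=num)
y = np.zeros([num])

# Evaluate both branches for every element, then choose per element.
z = np.where(x < y, x + y, y ** 2)

print(z[0], z[-1])  # first element takes x + y, last element takes y ** 2
```

Note that both expressions are computed for the full arrays, so this only suits branches that are cheap and safe to evaluate everywhere.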

@lryta
Member

lryta commented Feb 1, 2017

@jermainewang Can you update the image in web-data? Thanks!
