Difference in the implementation of slicemap and mapcols #14

Open
raphaelchinchilla opened this issue Jul 8, 2022 · 4 comments

@raphaelchinchilla

The implementations of slicemap and mapcols are fundamentally different. Intuitively, this is surprising, because

mapcols(f,M)=slicemap(f,M,dims=1)
and
slicemap(f,M,dims=1)=reshape(mapcols(f,reshape(M,size(M,1),:)),size(M))
(if dims is not equal to 1, then one could just use PermutedDimsArray)

After some (light) testing, I have the impression that using mapcols is about 25% faster than using slicemap. Is that a general result or specific to my application? Would there be any advantage on using either one or the other implementations?
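
The two identities above can be checked with Base functions only. This is a minimal sketch, not SliceMap's implementation: `cols_map`, `slices_via_cols` and `center` are hypothetical names, `mapslices` stands in for slicemap, and the reshape identity is only expected to hold when f returns a vector of the same length as the slice it receives.

```julia
# Hypothetical length-preserving test function: subtract the mean of the slice.
center(v) = v .- sum(v) / length(v)

# Column-wise map written out by hand (a stand-in for mapcols).
cols_map(f, M::AbstractMatrix) = reduce(hcat, [f(col) for col in eachcol(M)])

# Slice map along dims=1 for an N-d array, built from the column-wise map via reshape.
function slices_via_cols(f, A::AbstractArray)
    flat = reshape(A, size(A, 1), :)          # fold all trailing dims into columns
    return reshape(cols_map(f, flat), size(A))
end

M = randn(3, 4)
A = randn(3, 4, 5)

# Compare against Base.mapslices, used here as the reference for dims=1.
same_matrix = cols_map(center, M) ≈ mapslices(center, M, dims=1)
same_array  = slices_via_cols(center, A) ≈ mapslices(center, A, dims=1)
```

Replacing `cols_map` with `mapcols` and `mapslices` with `slicemap` gives the comparison described above; timing both on the same input (for example with BenchmarkTools) would show whether the 25% difference holds for other sizes.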

@mcabbott
Owner

mcabbott commented Jul 9, 2022

Yes, they have quite different paths. mapcols handles everything in-house, as this was the only way I could make things work for Tracker. (As does MapCols, in a different way.)

But slicemap calls JuliennedArrays to handle the slices. The gradient rules for this will only work with Zygote. I have not investigated very closely but there may be some overhead in this.

@raphaelchinchilla
Author

> Would there be any advantage on using either one or the other implementations?

After some more testing, I have concluded that in some situations one is better and in others the other is. I am not sure what the rule is.

> The gradient rules for this will only work with Zygote

It also works with Forward and ReverseDiff. Is that normal?

Also, a curious behavior that I have observed is that mapcols does not take the gradient of the parameters when one uses Zygote. It can be seen in this toy problem:

```julia
using SliceMap, ForwardDiff, ReverseDiff, Zygote

f(x,p)=[p*(x'*x)]

cost_slice(x,p)=sum(slicemap(x->f(x,p),x,dims=1))
cost_each(x,p)=sum(mapcols(x->f(x,p),x))

x=randn(10,100)
p=rand()

# Using slicemap

ForwardDiff.gradient(x) do x
    cost_slice(x,p)
end

ForwardDiff.derivative(p) do p
    cost_slice(x,p)
end

Zygote.gradient(x) do x
    cost_slice(x,p)
end

Zygote.gradient(p) do p
    cost_slice(x,p)
end

# Using mapcols

ForwardDiff.gradient(x) do x
    cost_each(x,p)
end

ForwardDiff.derivative(p) do p
    cost_each(x,p)
end

Zygote.gradient(x) do x
    cost_each(x,p)
end

Zygote.gradient(p) do p
    cost_each(x,p)
end
# This returns (nothing,)
```

@mcabbott
Owner

> It also works with Forward and ReverseDiff. Is that normal?

With these, this package is not involved in derivatives at all. I suspect that this means ReverseDiff is tracking each number, not whole arrays, and will be quite slow, but haven't tested.

> Also, a curious behavior that I have observed is that mapcols does not take the gradient of the parameters

I was confused for a bit, but this is in fact expected. The help says:

  Any arguments after the matrix are passed to f as scalars, i.e. `mapcols(f, m, args...) =
  reduce(hcat, f(col, args...) for col in eachcol(m))`. They do not get sliced/iterated (unlike
  map), nor are their gradients tracked.

  Note that if `f` itself contains parameters, their gradients are also not tracked.

This was enough for what I needed. I don't quite recall whether tracking or accumulating the gradient of f (which contains p) was blocked by something in particular.
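
For reference, the definition quoted from the help can be written out with Base functions only, which also gives a value to compare any AD result against. This is a rough sketch under stated assumptions: `mapcols_ref` and `cost_ref` are hypothetical names, not part of SliceMap, and p is passed as a trailing scalar argument instead of being captured in a closure. Because the toy cost is linear in p, its derivative with respect to p is `sum(abs2, x)`, so that is the number one would expect in place of `nothing`.

```julia
# Base-only reference for the docstring's definition.
# `mapcols_ref` is a hypothetical name and does not reproduce SliceMap's gradient handling.
mapcols_ref(f, m::AbstractMatrix, args...) =
    reduce(hcat, [f(col, args...) for col in eachcol(m)])

f(x, p) = [p * (x' * x)]                     # same toy function as in the thread

cost_ref(x, p) = sum(mapcols_ref(f, x, p))   # p is passed after the matrix, as a scalar

x = randn(10, 100)
p = rand()

# The cost equals p * sum(abs2, x), so d(cost)/dp = sum(abs2, x).
expected_dp = sum(abs2, x)

# Central finite difference as an AD-free cross-check of that value.
h = 1e-3
fd_dp = (cost_ref(x, p + h) - cost_ref(x, p - h)) / (2h)
```

The ForwardDiff.derivative results from the toy problem can be compared against `expected_dp` in the same way.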

@raphaelchinchilla
Author

> I suspect that this means ReverseDiff is tracking each number, not whole arrays, and will be quite slow, but haven't tested.

With some light testing with the functions above, it seems that ReverseDiff is about 5 times slower than Zygote.

> The gradient rules for this will only work with Zygote.

Is there any technical reason not to implement them? I would gladly look into it. Or should we just hope that stack will be merged soon enough and this would be a waste of time?
