Consider simplification or replacement of BlockArray #179

Closed
Michael-T-McCann opened this issue Jan 18, 2022 · 3 comments · Fixed by #259
Labels: discussion required (Some discussion necessary to decide how to address this issue); improvement (Improvement of existing code, including addressing of omissions or inconsistencies)

Comments

@Michael-T-McCann
Contributor

Michael-T-McCann commented Jan 18, 2022

BlockArrays are/have been at the center of several issues (#173 #159 #149 #147 #146), are difficult to maintain (because they touch jax internals), and necessitate a large amount of code bloat in SCICO (scico.numpy, scico.scipy).

Let's use this issue to discuss possible replacements for BlockArray. What functionality would this replacement need to provide? How could we go about it (implementation-wise)?

bwohlberg changed the title from "Consider replacement of BlockArray" to "Consider simplification or replacement of BlockArray" Jan 18, 2022
bwohlberg added the enhancement (New feature or request) and discussion required (Some discussion necessary to decide how to address this issue) labels Jan 18, 2022
@bwohlberg
Collaborator

bwohlberg commented Jan 21, 2022

Current usage examples include:

  1. BlockArray is an integral part of the operator.BiConvolve interface
  2. BlockArray allows linop.FiniteDifference to be simultaneously applied to multiple axes when the output shape is not the same for all axes
  3. We do elementwise arithmetic on BlockArray, e.g.,
    r = b - Ax
  4. We compute norms and inner products on BlockArray, e.g.,
    num = r.ravel().conj().T @ z.ravel()
  5. We make BlockArrays filled with zeros, e.g., snp.zeros(((1, 2), (3, 4)))
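To make item 2 concrete, here is a minimal sketch, assuming simple non-circular forward differences computed with `jnp.diff` (this is not scico's `linop.FiniteDifference` implementation, whose boundary handling may differ). Differencing a 3×4 array along each axis gives outputs of shapes (2, 4) and (3, 3), which cannot be stacked into a single ndarray, so some block container is needed to return them together.

```python
import jax.numpy as jnp

# A small 2D test array.
x = jnp.arange(12.0).reshape(3, 4)

# Forward differences along each axis (illustrative stand-in for a
# multi-axis finite-difference operator).
d0 = jnp.diff(x, axis=0)  # shape (2, 4)
d1 = jnp.diff(x, axis=1)  # shape (3, 3)

# The per-axis outputs have different shapes, so they cannot be combined
# with jnp.stack; a plain tuple is used here as the block container.
grad = (d0, d1)
print([g.shape for g in grad])  # [(2, 4), (3, 3)]
```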

@Michael-T-McCann
Contributor Author

Michael-T-McCann commented Jan 25, 2022

Voting to close this. @crstngc

After consideration (and trying to remove scico.numpy in a branch), I am convinced that (1) whatever BlockArray is, we will want helper functions/wrappers to make use of them, and prewrapping numpy in scico.numpy is a good solution to this; and (2) if we want to support x + y for BlockArrays, then they've got to be a class. (1) and (2) together make the current implementation look like a good solution.
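Point (2) can be seen with plain Python tuples of jax arrays: the built-in tuple `+` concatenates rather than adding elementwise, so without a class that overrides the operator, elementwise math has to go through a helper such as `jax.tree_util.tree_map`. A minimal illustration:

```python
import operator

import jax.numpy as jnp
from jax import tree_util

x = (jnp.ones((2, 2)), jnp.ones(3))
y = (jnp.ones((2, 2)), jnp.ones(3))

# Built-in tuple "+" concatenates: the result has 4 blocks, not 2.
concatenated = x + y
print(len(concatenated))  # 4

# Elementwise addition requires mapping over the leaves explicitly.
summed = tree_util.tree_map(operator.add, x, y)
print(len(summed))  # 2
```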

The only modification that might make sense is to change the internal representation to a Tuple[DeviceArray] (i.e. derive from tuple), which (I think) would allow arrays on different devices and avoid some copying. It may also allow simplifications in the implementation (e.g., indexing) that would reduce coupling to jax internals.
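A rough sketch of what a tuple-derived block array could look like, assuming it is registered as a JAX pytree so that `jit`, `grad`, and `tree_map` see the individual blocks. The class name `TupleBlockArray` and its methods are hypothetical and only cover a few operators; they are not the scico BlockArray API.

```python
import operator

import jax
import jax.numpy as jnp
from jax import tree_util


class TupleBlockArray(tuple):
    """Hypothetical block array: a tuple of jax arrays with elementwise math."""

    def _map2(self, other, op):
        # Pair blocks up for block/block operations; otherwise apply the
        # scalar (or array) operand to every block.
        if isinstance(other, tuple):
            return TupleBlockArray(op(a, b) for a, b in zip(self, other))
        return TupleBlockArray(op(a, other) for a in self)

    # Override tuple's concatenating "+" (and repeating "*") with
    # elementwise arithmetic.
    def __add__(self, other):
        return self._map2(other, operator.add)

    def __sub__(self, other):
        return self._map2(other, operator.sub)

    def __mul__(self, other):
        return self._map2(other, operator.mul)

    @property
    def shape(self):
        # The shape of a block array is the tuple of per-block shapes.
        return tuple(a.shape for a in self)


# Register as a pytree: flatten to the blocks, rebuild from the blocks.
tree_util.register_pytree_node(
    TupleBlockArray,
    lambda ba: (tuple(ba), None),
    lambda _, children: TupleBlockArray(children),
)

# Usage: blocks of different shapes, elementwise math, and jit.
x = TupleBlockArray((jnp.ones((2, 3)), jnp.zeros(4)))
y = x + 1.0
r = jax.jit(lambda a, b: a - b)(y, x)
print(r.shape)  # ((2, 3), (4,))
```

Because each block stays a separate array, no contiguous copy is made when building or operating on the block array, which is the property the comment above is after.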

Notes on a few options that inform the above reasoning:

  1. BlockArray goes away, all functionality comes from jax's pytrees, e.g., Tuple[DeviceArray]. Pros: no magic, no multi-page docs, no confusion about how broadcasting should work. If you know tuples, lists, dicts, you know how these work. These should automatically work with all of jax. Cons: all math becomes cumbersome (probably would want helper functions):
    - to add two together: jax.tree_map(operator.add, x, y).
    - call unary function: jax.tree_map(jnp.sqrt, x) or a wrapped function scico.sqrt(x)
    - reductions: jax.tree_util.tree_reduce or a wrapped function
    - creating one full of zeros: ugly but possible using tree_map's is_leaf argument.
  2. BlockArray is a class, but scico.numpy goes away. Pros: convenient syntax for some operators, users only have jnp and np to get confused about. Cons: some math remains cumbersome unless we add helper functions.
    - to add two together: x + y
    - other things, same as (1)
  3. (current status) BlockArray is a class, all/most of numpy is prewrapped to work in scico.numpy. Pros: already implemented, convenient from a user perspective. Cons: mentioned in the issue. Also, current implementation could possibly be improved (from contiguous array to Tuple[DeviceArray] underneath?).
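A minimal sketch of option 1, with no BlockArray class at all: blocks are plain tuples of jax arrays, and the operations from the earlier usage list become small helpers built on `jax.tree_util` (`tree_map` is the function referred to above as `jax.tree_map`). The helper names `block_add`, `block_sqrt`, `block_vdot`, `block_zeros`, and `is_shape` are hypothetical, chosen for illustration only.

```python
import operator

import jax.numpy as jnp
from jax import tree_util


def is_shape(s):
    # Treat a tuple of ints such as (3, 4) as a single leaf, not a subtree.
    return isinstance(s, tuple) and all(isinstance(d, int) for d in s)


def block_zeros(shapes):
    # Build a tuple of zero arrays from a tuple of shapes via is_leaf.
    return tree_util.tree_map(jnp.zeros, shapes, is_leaf=is_shape)


def block_add(x, y):
    return tree_util.tree_map(operator.add, x, y)


def block_sqrt(x):
    return tree_util.tree_map(jnp.sqrt, x)


def block_vdot(x, y):
    # Sum of per-block inner products; jnp.vdot conjugates its first argument
    # and flattens both, matching x.ravel().conj().T @ y.ravel().
    return tree_util.tree_reduce(operator.add, tree_util.tree_map(jnp.vdot, x, y))


# Usage mirroring snp.zeros(((1, 2), (3, 4))) and the inner-product example.
z = block_zeros(((1, 2), (3, 4)))
x = (jnp.full((1, 2), 4.0), jnp.full((3, 4), 9.0))
s = block_sqrt(block_add(x, z))
num = block_vdot(s, s)
print(float(num))  # 116.0
```

Everything works, but every operation needs a named helper, which is the "cumbersome" cost listed under option 1.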

@crstngc
Contributor

crstngc commented Jan 25, 2022

This is a good summary of where we are and I agree that we can close the issue. I will explore the change to an internal representation based on Tuple[DeviceArray] and we can go from there.

bwohlberg added the improvement (Improvement of existing code, including addressing of omissions or inconsistencies) label and removed the enhancement (New feature or request) label Jan 27, 2022