Highlights
- Arctic Code Vault Contributor
1,516 contributions in the last year
Contribution activity
December 2020
Created 7 commits in 2 repositories
Created a pull request in google/jax that received 2 comments
Ensure lax.convert_element_type preserves concrete inputs
There are several places in the JAX library where it is implicitly assumed that lax.convert_element_type will return concrete results for concrete …
Opened 7 other pull requests in 1 repository
google/jax 6 merged, 1 closed
- Cleanup: remove rng_factory boilerplate in lax_autodiff_test
- [jax2tf] Fix bug in conversion of bfloat16 to tensor
- Cleanup: remove unnecessary rng_factory boilerplate from linalg_test.py
- Cleanup: remove unnecessary rng_factory boilerplate in lax_test.py
- Cleanup: remove unnecessary rng_factory boilerplate in lax_numpy_test
- lax.convert_element_type: always return DeviceArray
- Better error for jsp.special.multigammaln
Reviewed 9 pull requests in 1 repository
google/jax 9 pull requests
- WIP: propagate weak types through unary and binary ops
- Upgrade trigonometric functions to primitives.
- Implement linear_ramp pad mode in jax.numpy.pad
- add jax.device_put_replicated
- device_put_sharded tweaks and docstring detail
- make convert_element_type_p not require old_dtype
- added prepend and append to diff
- Implement statistic pad mode in jax.numpy.pad
- TPUs support sorting complex operands
Created an issue in google/jax that received 1 comment
Infinite recursion in jax2tf with bfloat16 input.
Short repro:
    import jax.numpy as jnp
    from jax.experimental import jax2tf
    jax2tf.convert(lambda x: x)(jnp.bfloat16(1.0))
Traceback (most recent call…