Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Adds Array API support to LinearDiscriminantAnalysis #22554

Open
wants to merge 48 commits into
base: main
Choose a base branch
from

Conversation

thomasjpfan
Copy link
Member

@thomasjpfan thomasjpfan commented Feb 20, 2022

Reference Issues/PRs

Towards #22352

What does this implement/fix? Explain your changes.

This PR adds Array API support to LinearDiscriminantAnalysis. There is around a 14x runtime improvement when using Array API with CuPy on GPU.

The overall design principle is to use the Array API Specification as much as possible. In the short term, there will be an awkward transition as we need to support both NumPy and ArrayAPI. In the far term, the most maintainable position for the code base is to only use the Array API specification.

I extended the Array API spec in _ArrayAPIWrapper where these a feature we must have. In _NumPyApiWrapper, I added functions to the NumPy namespace adopt the functions in the Array API spec.

Any other comments?

There is still the question of how to communicated the feature. For this PR, I only implemented it for solver="svd".

@thomasjpfan thomasjpfan marked this pull request as draft Feb 23, 2022
@thomasjpfan thomasjpfan marked this pull request as ready for review Feb 28, 2022
Copy link
Member

@jjerphan jjerphan left a comment

Thank you, @thomasjpfan.

I can get a similar ×2 speed-up ratio on a machine with one NVIDIA Quadro RTX 6000 using the provided notebook.

This PR, with the Array API dispatch:

CPU times: user 12 s, sys: 858 ms, total: 12.8 s
Wall time: 14 s

This PR, without the Array API dispatch:

CPU times: user 1min 20s, sys: 1min 6s, total: 2min 27s
Wall time: 23.3 s

To me, this PR is clear and does not introduce too much complexity.

Do you think we could (if it's worth it) come up with adaptors for the few API mismatches (e.g add.at)?

doc/whats_new/v1.1.rst Outdated Show resolved Hide resolved
sklearn/utils/validation.py Outdated Show resolved Hide resolved
self.intercept_ = -0.5 * np.sum(coef**2, axis=1) + np.log(self.priors_)
self.coef_ = np.dot(coef, self.scalings_.T)
self.intercept_ -= np.dot(self.xbar_, self.coef_.T)
rank = xp.sum(xp.astype(S > self.tol * S[0], xp.int32))
Copy link
Member

@jjerphan jjerphan Mar 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side-note: What is the intend of specifying xp.int32? Does it makes the sum faster?

Copy link
Member Author

@thomasjpfan thomasjpfan Mar 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArrayAPI is very strict when it comes to bools. S > self.tol returns a boolean array which can not be summed.

I suspect it is because there is no type promotion rules for bools.

Copy link
Contributor

@rgommers rgommers Mar 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that sounds right. There are also no int-float mixed type casting rules for arrays, because erroring out is a valid design choice and something at least TensorFlow does (PyTorch also limits what it wants to allow without being explicit).

There could perhaps be a rule for Python bool to Python int, but there's probably little appetite for array dtype cross-kind casting rules.


X_np = numpy.asarray([[1, 2, 3]])

# Dispatching on
Copy link
Member

@jjerphan jjerphan Mar 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this comment written entirely?

Suggested change
# Dispatching on
# This must dispatch on Numpy,
# regardless of the value of `array_api_dispatch`

@thomasjpfan
Copy link
Member Author

@thomasjpfan thomasjpfan commented Mar 7, 2022

Are your timings reversed? It looks like Array API makes it slower.

@jjerphan
Copy link
Member

@jjerphan jjerphan commented Mar 7, 2022

You are right, I just have corrected it.

Copy link
Member

@ogrisel ogrisel left a comment

Still familiarizing my-self with the Array API standard and the NumPy implementation but here is a first pass of comments for this PR.

sklearn/_config.py Show resolved Hide resolved
sklearn/utils/_array_api.py Show resolved Hide resolved
sklearn/utils/tests/test_array_api.py Show resolved Hide resolved
pytest.importorskip("numpy", minversion="1.22", reason="Requires Array API")

X_np = numpy.asarray([[1, 2, 3]])
xp = pytest.importorskip("numpy.array_api")
Copy link
Member

@ogrisel ogrisel Mar 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can simplify the 2 importorskip into a single one, right? NumPy 1.21 did not expose a numpy.array_api submodule (I checked).

Suggested change
pytest.importorskip("numpy", minversion="1.22", reason="Requires Array API")
X_np = numpy.asarray([[1, 2, 3]])
xp = pytest.importorskip("numpy.array_api")
# This test requires NumPy 1.22 or later for its implementation of the
# Array API specification:
with warnings.catch_warnings():
warnings.simplefilter("ignore") # ignore experimental warning
xp = pytest.importorskip("numpy.array_api")
X_np = numpy.asarray([[1, 2, 3]])

sklearn/utils/tests/test_array_api.py Show resolved Hide resolved
sklearn/utils/_array_api.py Show resolved Hide resolved
sklearn/utils/_array_api.py Show resolved Hide resolved
sklearn/utils/_array_api.py Show resolved Hide resolved
return getattr(self._namespace, name)

def astype(self, x, dtype, *, copy=True, casting="unsafe"):
# support casting for NumPy
Copy link
Member

@ogrisel ogrisel Mar 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# support casting for NumPy
# Extend Array API to support `casting` for NumPy containers

Copy link
Member

@ogrisel ogrisel Mar 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any issue to track the support of custom casting in the spec?

Copy link
Member Author

@thomasjpfan thomasjpfan Mar 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can not find a discussion on astype & casting. Maybe @rgommers has a link?

Copy link
Member Author

@thomasjpfan thomasjpfan Mar 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect it's because other libraries do not really implement casting for astype. For example, cupy.astype does not support casting.

Copy link
Member Author

@thomasjpfan thomasjpfan Mar 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this specific PR, we are using casting in check_array (which was added in #14872):

array = array.astype(dtype, casting="unsafe", copy=False)

I think we do not need to set casting here since the default is unsafe. For reference, the casting behavior of nans and infs are not specified in the ArrayAPI spec

For example:

import numpy as np

np_float_arr = np.asarray([1, 2, np.nan], dtype=np.float32)
print(np_int_float.astype(np.int32))
# On a x86 machine:
# [          1           2 -2147483648]
# But on a M1 mac:
# [1, 2, 0]

# Cupy cast to zeros.
import cupy

cp_float_arr = cupy.asarray([1, 2, cupy.nan], dtype=cupy.float32)
print(cp_float_arr.astype(cupy.int32))
# [1, 2, 0]

Copy link
Contributor

@rgommers rgommers Mar 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reference, the casting behavior of nans and infs are not specified in the ArrayAPI spec

That seems like something to specify - even if just to say it's undefined behavior. Which it is I think, as evidenced by the inconsistent NumPy results here.

I suspect it's because other libraries do not really implement casting for astype.

That is typically the reason. The PR that added astype lists all supported keywords across libraries, and only NumPy and Dask have casting.

A question is if the concept of casting modes is useful enough to include. I'm not sure to be honest (but I didn't think about it very hard yet). The default in numpy.ndarray.astype is unsafe anyway, which is the only reasonable choice probably - because code like astype(x_f64, float32) shouldn't raise as it is very explicit.

Copy link
Member

@ogrisel ogrisel Mar 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity I did a a quick survey of our current use of the casting kwarg.

git grep "casting="
sklearn/preprocessing/_polynomial.py:                        casting="no",
sklearn/tree/_tree.pyx:        return n_classes.astype(expected_dtype, casting="same_kind")
sklearn/tree/_tree.pyx:        return value_ndarray.astype(expected_dtype, casting='equiv')
sklearn/tree/_tree.pyx:    return node_ndarray.astype(expected_dtype, casting="same_kind")
sklearn/tree/tests/test_tree.py:    return node_ndarray.astype(new_dtype, casting="same_kind")
sklearn/tree/tests/test_tree.py:    return node_ndarray.astype(new_dtype, casting="same_kind")
sklearn/tree/tests/test_tree.py:    new_n_classes = n_classes.astype(new_dtype, casting="same_kind")
sklearn/utils/validation.py:                    # inf (numpy#14412). We cannot use casting='safe' because
sklearn/utils/validation.py:                    array = array.astype(dtype, casting="unsafe", copy=False)

We can ignore the tree files because they are written in Cython and will never benefit from Array API compat.

So we just have sklearn/preprocessing/_polynomial.py with casting="no" and sklearn/utils/validation.py with casting="unsafe".

So it's probably indeed not worth exposing the casting argument in our xp wrapper and always use unsafe.

sklearn/utils/_array_api.py Outdated Show resolved Hide resolved
@ogrisel
Copy link
Member

@ogrisel ogrisel commented Mar 8, 2022

I can get a similar speed-up ratio on a machine with one NVIDIA Quadro RTX 6000 using the provided notebook.

You report a ~2x speed-up instead of a 14x speed-up in @thomasjpfan notebook though. I am not sure if this is expected or not.

@thomasjpfan
Copy link
Member Author

@thomasjpfan thomasjpfan commented Mar 8, 2022

It could be because of different hardware. I ran my benchmarks using a Nvidia 3090 and a 5950x (16 core 32 thread) single CPU. It was also in a workstation environment where I can supply the GPU with 400 Watts of power.

@jjerphan
Copy link
Member

@jjerphan jjerphan commented Mar 10, 2022

You report a ~2x speed-up instead of a 14x speed-up in @thomasjpfan notebook though. I am not sure if this is expected or not.

Exact, I did a mistake comparing based on "total". Updated.

Copy link
Member

@adrinjalali adrinjalali left a comment

This doesn't look too complicated, better than what I imagined.

I do think we need to test against something which is not numpy, and figure out the performance implications.

array_api_dispatch : bool, default=None
Use Array API dispatching when inputs follow the Array API standard.
Default is False.
Copy link
Member

@adrinjalali adrinjalali Mar 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should add a note and maybe a link to a place where we have a list of estimators/parameter sets which actually support the array api.

Copy link
Member

@ogrisel ogrisel Mar 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for a new dedicated user guide chapter on this. I think it could be a top level section.

Copy link
Member

@ogrisel ogrisel Mar 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

sklearn/utils/_array_api.py Show resolved Hide resolved
X = xp.exp(X)
else:
np.exp(X, out=X)
Copy link
Member

@adrinjalali adrinjalali Mar 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one of the things that worries me is the performance implications of not having in-place operations.

Copy link
Member Author

@thomasjpfan thomasjpfan Mar 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, there is not a great way around this without extending our own version of ArrayAPI further and special case NumPy.

The reasons for not having out= is in https://github.com/scikit-learn/scikit-learn/pull/22554/files#r825086171

sklearn/discriminant_analysis.py Outdated Show resolved Hide resolved
X = np.sqrt(fac) * (Xc / std)
X = math.sqrt(fac) * (Xc / std)
Copy link
Member

@adrinjalali adrinjalali Mar 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are they strictly equivalent? Shouldn't this be xp.sqrt instead?

Copy link
Member Author

@thomasjpfan thomasjpfan Mar 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xp.sqrt does not work on Python scalars such as frac. We would need to call xp.asarray(fac) before calling xp.sqrt.

Copy link
Member

@adrinjalali adrinjalali Mar 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see. It makes sense, but it also means a developer would need to know when to use xp. and when math., which I guess can be confusing. Should array api actually handle this?

Copy link
Member Author

@thomasjpfan thomasjpfan Mar 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but it also means a developer would need to know when to use xp. and when math., which I guess can be confusing. Should array api actually handle this?

From my understanding, this was by design for ArrayAPI. Only NumPy has the concept of "array scalars", while all other array libraries use a 0d array. (np.sqrt(python_scalar) returns an NumPy scalar, while math.sqrt(python_scalar) is a Python scalar)

In our case, the ArrayAPI forces us to think about using math for Python scalars. From a developer point of view, I agree it is one more think to think about, but I think it's better to be more strict about these types.

A side benefit is that the Python scalar version is faster:

%%timeit
_ = np.log(431.456)
# 315 ns ± 5.46 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

%%timeit
_ = math.log(431.456)
# 68.5 ns ± 0.519 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

which can be a difference for code that run in loops.

Copy link
Member Author

@thomasjpfan thomasjpfan Mar 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this again, I think it's better to do xp.asarray() on the scalar, so we can use xp.sqrt on it.

Copy link
Member

@adrinjalali adrinjalali Mar 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lolol, why? I was convinced with you last comment.

Copy link
Member Author

@thomasjpfan thomasjpfan Mar 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me, the pros and cons of math vs xp.asarray on Python scalars is balanced. The argument for using xp.asarray is that it forces us to be in "array land" and not need to think about any Python scalar + Array interactions. Although, the ArrayAPI spec does state that that python_scalar * array is the same as xp.asarray(python_scalar, dtype=dtype_of_array) * array.

REF: https://discuss.scientific-python.org/t/poll-future-numpy-behavior-when-mixing-arrays-numpy-scalars-and-python-scalars/202

if self.solver == "svd":
X_new = np.dot(X - self.xbar_, self.scalings_)
X_new = (X - self.xbar_) @ self.scalings_
Copy link
Member

@adrinjalali adrinjalali Mar 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if we train using one array type (e.g. cupy), and predict using another type, e.g. numpy ndarray?

The internal state of the estimator now seems to depend on the type of the input, do we want that or should we be consistent regarding the estimator state?

Copy link
Member Author

@thomasjpfan thomasjpfan Mar 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if we train using one array type (e.g. cupy), and predict using another type, e.g. numpy ndarray?

Currently, this fails. Yes, the device the arrays are stored in depends on the input. One would need to transfer the arrays from the GPU to the CPU to be able to run inference on the CPU.

I do not think we should do this automatically, but it feels like a use case that should be supported.

Copy link
Member

@ogrisel ogrisel Mar 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will probably need some explicit API to convert an estimators fitted with a given Array API namespace to be able to predict with another:

from sklearn.utils import convert_to_namespace

est_cupy = Estimator().fit(X_train_cupy, y_train_cupy)

est_numpy = convert_to_namespace(est_cupy, "numpy.array_api")
est_numpy.predict(X_test_numpy)

but we can handle that in a later PR.

Copy link
Member

@adrinjalali adrinjalali Mar 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we need to properly document this. This is one of my main worries with this proposal.

Copy link
Member

@ogrisel ogrisel Mar 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's now documented as part of this PR itself (with a private/experimental function for now).

Copy link
Member

@ogrisel ogrisel Mar 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def test_lda_array_api(X, y):
"""Check that the array_api Array gives the same results as ndarrays."""
pytest.importorskip("numpy", minversion="1.22", reason="Requires Array API")
xp = pytest.importorskip("numpy.array_api")
Copy link
Member

@adrinjalali adrinjalali Mar 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we kinda still have numpy array api specific code. I think we need something here which is not numpy's array api implementation to test.

Copy link
Member

@ogrisel ogrisel Mar 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will be able to do this with pytorch in CPU mode once pytorch's compliance as improved.

Progress is tracked here:

pytorch/pytorch#58743

The other mature enough candidate is CuPy but this one requires maintaining a GPU CI runner. I would rather start with numpy only on our CI in the short term.

But we could improve this test to make it work with CuPy with a non-default parametrization:

@pytest.mark.parametrize("array_api_namespace", ["numpy.array_api", "cupy.array_api"])
@pytest.mark.parametrize("X, y", [(X, y), (X, y3)])
def test_lda_array_api(X, y, array_api_namespace):
    """Check that the array_api Array gives the same results as ndarrays."""
    xp = pytest.importorskip(array_api_namespace)
    ...

and this way it would be easy to run those compliance tests manually on a cuda enabled host.

Copy link
Member

@adrinjalali adrinjalali Mar 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would make sense for the array api effort to include a reference implementation. It can use numpy under the hoods, but it should be a minimal implementation, and it'll be different from numpy's own implementation since numpy has certain considerations which make its implementation not minimal. Then all libraries could test against that implementation instead of some random other library's implementation.

Copy link
Member Author

@thomasjpfan thomasjpfan Mar 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These is an Array API Compliance test suite here: https://github.com/data-apis/array-api-tests which test that an Array API implementation follows the spec. For a subset of operators, the test suite also test for correctness.

I see the numpy.array_api as the minimal implementation backed by NumPy. The idea around testing on another with another library's implementation is that the numerical operations can return different results depending on hardware. For us to trust that our algorithms are correct using CuPy's or PyTorch's Array API implementation, we would still need to test it ourselves.

sum_prob = np.sum(X, axis=1).reshape((-1, 1))

if is_array_api:
# array_api does not have `out=`
Copy link
Member

@ogrisel ogrisel Mar 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any plan or discussion for allowing this? Maybe as an optional API extension?

Copy link
Contributor

@rgommers rgommers Mar 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No there isn't. The reason is twofold:

  1. out= doesn't make sense for all libraries - for example, JAX and TensorFlow have immutable data structures.
  2. Even for libraries that do have mutable arrays, out= is not a very nice API pattern. It lets users do manual optimizations that a compiler may be able to do better. There was also another argument and an alternative design presented in data-apis/consortium-feedback#5.

And maybe (3), IIRC NumPy and PyTorch semantics for out= aren't identical.

Copy link
Member

@ogrisel ogrisel Mar 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback. I guess we will have to keep on using those if is_array_api conditions to protect our use of numpy's out= arguments for now.

I don't necessarily see that as a block for the adoption of Array API in scikit-learn, but it does make the code looks uglier... I don't really see a potential longterm fix for this.

Copy link
Member

@adrinjalali adrinjalali Mar 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It lets users do manual optimizations that a compiler may be able to do better.

@rgommers I'm quite confused. Does python actually do such optimizations?

Copy link
Member

@adrinjalali adrinjalali Mar 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's what I get on my machine:

In [19]: def f2():
    ...:     a = np.random.rand(10000, 100000)
    ...:     a = np.exp(a)
    ...:     return a
    ...: 

In [20]: def f1():
    ...:     a = np.random.rand(10000, 100000)
    ...:     np.exp(a, out=a)
    ...:     return a
    ...: 
In [23]: %timeit f1()
15.6 s ± 2.53 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [24]: %timeit f2()
[1]    210906 killed     ipython

so the difference is quite significant (one of them gets killed 😁 )

Copy link
Contributor

@rgommers rgommers Mar 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does python actually do such optimizations?

No it doesn't - but Python doesn't have a compiler? I meant a JIT or AOT compiler like JAX's JIT or PyTorch's Torchscript. It is not strange to say that in principle X = xp.exp(X) can be rewritten to an inplace update (i.e., exp(X, out=X)) by a compiler transformation if and only if the memory backing X isn't used elsewhere, right?

This code looks fishy by the way for copy=False; the docs say nothing about inplace updating of the input array, which I'd consider a bug in library code if it were public. And to avoid this footgun, it defaults to copy=True which is always slow?

(one of them gets killed grin )

Ouch, that doesn't look like the expected result.

Copy link
Member

@adrinjalali adrinjalali Mar 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No it doesn't - but Python doesn't have a compiler? I meant a JIT or AOT compiler like JAX's JIT or PyTorch's Torchscript. It is not strange to say that in principle X = xp.exp(X) can be rewritten to an inplace update (i.e., exp(X, out=X)) by a compiler transformation if and only if the memory backing X isn't used elsewhere, right?

Ok now that makes sense. But here's the issue. One benefit of using array api is that developers can learn that instead of numpy's API and develop their estimators. It would also make it really easy to support array api even if they don't explicitly do so. But that raises the issue that in order to write efficient numpy code, one needs to do numpy, and for all other backends do array api, like here. I think in an ideal world, we wouldn't want to have these separate branches for the two APIs, do we?

Copy link
Contributor

@rgommers rgommers Mar 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in an ideal world, we wouldn't want to have these separate branches for the two APIs, do we?

I think we indeed want to have separate branches in as few places as possible. A lot of the places that are being identified are functions that can be added to either the array API standard (e.g., take, moveaxis) or NumPy (e.g., unique_*). There's a few things that will remain though, because there's an inherent tension between portability and performance. out= and order= are the two that come to mind here. And some forms of (advanced) indexing.

We reviewed the use of out= and order=, and they're used less commonly then one would expect based on the amount of discussion around them. The amount of branching that will remain in the end is quite limited, and seems like a reasonable price to pay.

Copy link
Member

@adrinjalali adrinjalali Mar 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, I'm happy then.

Copy link
Member

@ogrisel ogrisel Mar 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR now has a private helper (_asarray_with_order) to explicitly deal with the case where we want to enforce a specific order for numpy (e.g. when used in conjunction with Cython code) vs array API passthrough.

Copy link
Member

@ogrisel ogrisel left a comment

I think it's good to not make the order handling part of the wrapper by default to avoid implying that a future version of the Array API spec will support explicit memory layout constraints. But on the other hand we still factorize the redundant code in the input validation code.

sklearn/utils/validation.py Outdated Show resolved Hide resolved
sklearn/utils/validation.py Outdated Show resolved Hide resolved
Copy link
Member

@ogrisel ogrisel left a comment

For information, I just ran on a machine with an Nvidia GPU:

mamba create -n cuda cudatoolkit cupy numpy scipy cython joblib threadpoolctl pytest
conda activate cuda
pip install -e . --no-build-isolation

and I confirm that the LDA test passes successfully against cupy:

$  pytest -k test_lda_array_api -v -x sklearn/tests/test_discriminant_analysis.py
================================================================================================================================================ test session starts ================================================================================================================================================
platform linux -- Python 3.10.4, pytest-7.1.1, pluggy-1.0.0 -- /storage/store/work/ogrisel/mambaforge/envs/cuda/bin/python
cachedir: .pytest_cache
rootdir: /home/parietal/ogrisel/code/scikit-learn, configfile: setup.cfg
collected 44 items / 42 deselected / 2 selected                                                                                                                                                                                                                                                                     

sklearn/tests/test_discriminant_analysis.py::test_lda_array_api[numpy.array_api] PASSED                                                                                                                                                                                                                       [ 50%]
sklearn/tests/test_discriminant_analysis.py::test_lda_array_api[cupy.array_api] PASSED                                                                                                                                                                                                                        [100%]

========================================================================================================================================= 2 passed, 42 deselected in 2.24s ==========================================================================================================================================

@ogrisel
Copy link
Member

@ogrisel ogrisel commented Mar 28, 2022

The linter complains about:

sklearn/utils/tests/test_array_api.py:8:1: F401 'sklearn.utils._array_api._asarray_with_order' imported but unused

but instead of deleting this import statement, I think it might be worth to write a unit test for this helper.

@ogrisel
Copy link
Member

@ogrisel ogrisel commented Mar 29, 2022

@thomasjpfan I pushed some minimal doc for the experimental Array API support introduced in this PR. Feel free to edit and improve.

In the future we might want to automate the list of supported estimators with an estimator tag and a script but I think for the time being it's fine to have a static list.

@ogrisel
Copy link
Member

@ogrisel ogrisel commented Mar 29, 2022

Good idea to rename to "dispatching" to get more future proof URLs :)

@thomasjpfan
Copy link
Member Author

@thomasjpfan thomasjpfan commented Mar 29, 2022

Looks like we were working on the user guide at the same time. I added a "dispatching" section to the user guide where Array API is one of the methods to dispatch on. I'm thinking that the computational routines will also belong in "dispatching".

I also added a section on how to transfer from a cupy.array_api to numpy.ndarray. It's kind of hacky, so I do not know if we should include it. Although, I think it is a common use case. (Train on GPU, predict on CPU)

>>> # transform on a CPU
>>> X_trans = lda.transform(X_np)
>>> type(X_trans)
<class 'numpy.ndarray'>
Copy link
Member

@ogrisel ogrisel Mar 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to provide an experimental helper, e.g. _convert_to_xp(model, target_namespace) to achieve this instead of putting an ugly code snippet in the doc. Here target_namespace can be an xp module or a namespace path (str).

Copy link
Member

@ogrisel ogrisel Mar 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually here we would rather need _convert_to_numpy or something similar.

Copy link
Member Author

@thomasjpfan thomasjpfan Mar 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Array API specification pushes for dlpack to do data exchange between different Array API implementation. I think it mostly works if the array is on the same device. For arrays on different devices, there is no public API that "just works".

With that in mind, I added a _convert_estimator_to_ndarray that converts from cupy.array_api or numpy.array_api into a numpy.ndarray.

Copy link
Contributor

@rgommers rgommers Mar 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit of context: there was a common understanding that automatic/implicit transfer of data between different devices is not a good idea and leads to hard to detect performance issues. It should be an explicit action by the user. That is why there is nothing that "just works".

Copy link
Member Author

@thomasjpfan thomasjpfan Mar 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, we want to transfer. The use case is to train on a GPU and deploy/predict on a CPU. Is there an API for a user to explicitly convert?

For from_dlpack to work, we need to use cupy.ndarray API to do the device transfer first, and then NumPy can convert:

import cupy.array_api as cu_xp
import numpy.array_api as np_xp

x_cu = cu_xp.asarray([1, 2, 3.4])

x_np = np_xp.from_dlpack(x_cu._array.get())

I can envision extending from_dlpack to do the transfer. For example:

X_np = np_xp.from_dlpack(x_cu, device="cpu")

where device maps to one of the DLDevice from dlpack.h.

Copy link
Contributor

@rgommers rgommers Mar 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, there isn't such an API, and no plan for it. Certainly shouldn't be mixed with DLPack I think, which is explicitly a zero-copy protocol (just like the buffer protocol, __array_interface__, __cuda_array_interface__, etc.). There are only library-specific methods for device transfers. xref pytorch/pytorch#36560 also for a relevant discussion in PyTorch, where a force= keyword was requested (makes some sense, but not implemented).

Before saying more, I should probably look at the code to understand why you want this, rather than leaving it to the user.

Copy link
Member

@ogrisel ogrisel Mar 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The usecase is:

  • fit a scikit-learn estimator on a GPU machine: the internal state are CuPy arrays backed by a GPU device managed by a CUDA runtime;
  • convert the estimator object to use numpy arrays instead;
  • pickle the trained model to a file, model store, docker container image, whatever;
  • load the pickle in a Python process on a machine without CUDA to compute predictions on new data using only the CPU and no extra dependencies beyond numpy and scipy.

Copy link
Contributor

@rgommers rgommers Mar 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. In general, it's not supportable at least in the short term to have forcing (regular copy, device transfer, detach from autograd graph, etc.) conversions between any two array types from different libraries. Some libraries currently don't even offer this for "own kind to numpy", and a numpy-specific idiom won't be very attractive for maintainers of other libraries.

At the moment, the right thing to use is cupy.asnumpy() or cupy.ndarray.get() (see https://docs.cupy.dev/en/stable/reference/ndarray.html#conversion-to-from-numpy-arrays). You know you want a "to numpy" call, which is an easier ask than a "any to any other" generic API. There may be appetite to align the names of cupy.asnumpy()/torch.Tensor.numpy()/tf.Tensor.numpy() names, that seems like a logical next step. For now I'd say you want to special-case the libraries in some helper routine.

Copy link
Member

@ogrisel ogrisel left a comment

LGTM. Since it is private / experimental API with explicit activation, I think this is fine to merge this PR without a SLEP but I would be in favor of letting the time to other @scikit-learn/core-devs to express their feelings about this PR before merging (assuming we get a +2 first).

sklearn/tests/test_discriminant_analysis.py Outdated Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants