Skip to content

Implement predict_proba for RandomForestClassifier #288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: development
Choose a base branch
from
Prev Previous commit
Next Next commit
fix test
  • Loading branch information
Mec-iS committed Jan 20, 2025
commit bb356e6a289209ad586eaf7af20217c12125a2ae
12 changes: 7 additions & 5 deletions src/ensemble/random_forest_classifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -833,8 +833,9 @@ mod tests {
)]
#[test]
fn test_random_forest_predict_proba() {
use num_traits::FromPrimitive;
// Iris-like dataset (subset)
let x = DenseMatrix::from_2d_array(&[
let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
Expand Down Expand Up @@ -881,21 +882,22 @@ mod tests {
// These values are approximate and based on typical random forest behavior
for i in 0..5 {
assert!(
*probas.get((i, 0)) > 0.6,
*probas.get((i, 0)) > f64::from_f32(0.6).unwrap(),
"Class 0 samples should have high probability for class 0"
);
assert!(
*probas.get((i, 1)) < 0.4,
*probas.get((i, 1)) < f64::from_f32(0.4).unwrap(),
"Class 0 samples should have low probability for class 1"
);
}

for i in 5..10 {
assert!(
*probas.get((i, 1)) > 0.6,
*probas.get((i, 1)) > f64::from_f32(0.6).unwrap(),
"Class 1 samples should have high probability for class 1"
);
assert!(
*probas.get((i, 0)) < 0.4,
*probas.get((i, 0)) < f64::from_f32(0.4).unwrap(),
"Class 1 samples should have low probability for class 0"
);
}
Expand Down
Loading