@@ -16,6 +16,8 @@
 
 import numpy as np
 import numpy.typing as npt
+import pymc as pm
+import pytensor.tensor as pt
 from numba import njit
 from pymc.initial_point import PointType
 from pymc.model import Model, modelcontext
@@ -120,15 +122,15 @@ class PGBART(ArrayStepShared):
         "tune": (bool, []),
     }
 
-    def __init__(  # noqa: PLR0915
+    def __init__(  # noqa: PLR0912, PLR0915
         self,
-        vars=None,  # pylint: disable=redefined-builtin
+        vars: list[pm.Distribution] | None = None,
         num_particles: int = 10,
         batch: tuple[float, float] = (0.1, 0.1),
         model: Optional[Model] = None,
         initial_point: PointType | None = None,
-        compile_kwargs: dict | None = None,  # pylint: disable=unused-argument
-    ):
+        compile_kwargs: dict | None = None,
+    ) -> None:
         model = modelcontext(model)
         if initial_point is None:
             initial_point = model.initial_point()
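
The new `vars` annotation documents what the sampler expects: the BART variables of a PyMC model. As a hedged sketch of typical usage (assuming the `pymc-bart` package and hypothetical data `X`, `y`):

    import numpy as np
    import pymc as pm
    import pymc_bart as pmb

    X = np.random.uniform(size=(100, 3))  # hypothetical covariates
    y = np.random.normal(size=100)        # hypothetical response

    with pm.Model() as model:
        mu = pmb.BART("mu", X, y, m=20)
        pm.Normal("obs", mu=mu, sigma=1.0, observed=y)
        # pm.sample assigns PGBART to `mu` automatically; it can also be
        # constructed explicitly:
        step = pmb.PGBART([mu], num_particles=10, model=model)
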
@@ -137,6 +139,10 @@ def __init__(  # noqa: PLR0915
         else:
             vars = [model.rvs_to_values.get(var, var) for var in vars]
             vars = inputvars(vars)
+
+        if vars is None:
+            raise ValueError("Unable to find variables to sample")
+
         value_bart = vars[0]
         self.bart = model.values_to_rvs[value_bart].owner.op
 
@@ -325,7 +331,7 @@ def normalize(self, particles: list[ParticleTree]) -> float:
         return wei / wei.sum()
 
     def resample(
-        self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
+        self, particles: list[ParticleTree], normalized_weights: npt.NDArray
     ) -> list[ParticleTree]:
         """
         Use systematic resample for all but the first particle
@@ -347,7 +353,7 @@ def resample(
         return particles
 
     def get_particle_tree(
-        self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
+        self, particles: list[ParticleTree], normalized_weights: npt.NDArray
     ) -> tuple[ParticleTree, Tree]:
         """
         Sample a new particle and associated tree
@@ -359,7 +365,7 @@ def get_particle_tree(
 
         return new_particle, new_particle.tree
 
-    def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
+    def systematic(self, normalized_weights: npt.NDArray) -> npt.NDArray[np.int_]:
         """
         Systematic resampling.
 
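
Systematic resampling spreads a single uniform draw into evenly spaced ordered points, which gives lower-variance particle selection than drawing each index independently. A minimal NumPy sketch of the idea (function name hypothetical):

    import numpy as np

    def systematic_resample_sketch(normalized_weights, rng=np.random.default_rng()):
        n = len(normalized_weights)
        # one uniform draw, spread into n evenly spaced ordered points in [0, 1)
        single_uniform = (rng.random() + np.arange(n)) / n
        # invert the CDF of the weights at each ordered point
        return np.searchsorted(np.cumsum(normalized_weights), single_uniform)
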
@@ -395,7 +401,7 @@ def update_weight(self, particle: ParticleTree, odim: int) -> None:
         particle.log_weight = new_likelihood
 
     @staticmethod
-    def competence(var, has_grad):
+    def competence(var: pm.Distribution, has_grad: bool) -> Competence:
         """PGBART is only suitable for BART distributions."""
         dist = getattr(var.owner, "op", None)
         if isinstance(dist, BARTRV):
@@ -406,12 +412,12 @@ def competence(var, has_grad):
 class RunningSd:
     """Welford's online algorithm for computing the variance/standard deviation"""
 
-    def __init__(self, shape: tuple) -> None:
+    def __init__(self, shape: tuple[int, ...]) -> None:
         self.count = 0  # number of data points
         self.mean = np.zeros(shape)  # running mean
         self.m_2 = np.zeros(shape)  # running second moment
 
-    def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
+    def update(self, new_value: npt.NDArray) -> Union[float, npt.NDArray]:
         self.count = self.count + 1
         self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
         return fast_mean(std)
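
Welford's recurrence updates the mean and the centered second moment in a single pass, avoiding the cancellation that plagues the naive E[x²] − E[x]² formula. A scalar sketch of one step (the class above vectorizes it over `shape`):

    def welford_step(count, mean, m_2, new_value):
        # one online update; returns the new state plus the current std estimate
        count += 1
        delta = new_value - mean
        mean += delta / count
        delta2 = new_value - mean  # recomputed with the updated mean
        m_2 += delta * delta2
        return count, mean, m_2, (m_2 / count) ** 0.5
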
@@ -420,10 +426,10 @@ def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray
 @njit
 def _update(
     count: int,
-    mean: npt.NDArray[np.float64],
-    m_2: npt.NDArray[np.float64],
-    new_value: npt.NDArray[np.float64],
-) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
+    mean: npt.NDArray,
+    m_2: npt.NDArray,
+    new_value: npt.NDArray,
+) -> tuple[npt.NDArray, npt.NDArray, Union[float, npt.NDArray]]:
     delta = new_value - mean
     mean += delta / count
     delta2 = new_value - mean
@@ -434,7 +440,7 @@ def _update(
 
 
 class SampleSplittingVariable:
-    def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
+    def __init__(self, alpha_vec: npt.NDArray) -> None:
         """
         Sample splitting variables proportional to `alpha_vec`.
 
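
Sampling a splitting variable proportional to `alpha_vec` is inverse-CDF sampling over the normalized weights; conceptually (function name hypothetical):

    import numpy as np

    def sample_splitting_variable_sketch(alpha_vec, rng=np.random.default_rng()):
        cumsum = np.cumsum(alpha_vec) / np.sum(alpha_vec)  # normalized CDF
        return int(np.searchsorted(cumsum, rng.random()))
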
@@ -547,16 +553,16 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d
 
 
 def draw_leaf_value(
-    y_mu_pred: npt.NDArray[np.float64],
-    x_mu: npt.NDArray[np.float64],
+    y_mu_pred: npt.NDArray,
+    x_mu: npt.NDArray,
     m: int,
-    norm: npt.NDArray[np.float64],
+    norm: npt.NDArray,
     shape: int,
     response: str,
-) -> tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
+) -> tuple[npt.NDArray, Optional[npt.NDArray]]:
     """Draw Gaussian distributed leaf values."""
     linear_params = None
-    mu_mean = np.empty(shape)
+    mu_mean: npt.NDArray
     if y_mu_pred.size == 0:
         return np.zeros(shape), linear_params
 
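
The `np.empty(shape)` preallocation was dead code: `mu_mean` is assigned on every later branch, so only the type annotation is needed. In isolation, the bare-annotation pattern looks like this (hypothetical example):

    import numpy as np
    import numpy.typing as npt

    def pick(flag: bool) -> npt.NDArray:
        out: npt.NDArray  # declares the type for checkers; allocates nothing
        if flag:
            out = np.zeros(3)
        else:
            out = np.ones(3)
        return out
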
@@ -571,7 +577,7 @@ def draw_leaf_value(
 
 
 @njit
-def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
+def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]:
     """Use Numba to speed up the computation of the mean."""
     if ari.ndim == 1:
         count = ari.shape[0]
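
Under `@njit`, an explicit loop compiles to machine code, so the mean of the small arrays used here is computed without NumPy's dispatch overhead. A sketch of the 1-D case (function name hypothetical):

    from numba import njit

    @njit
    def fast_mean_1d(ari):
        count = ari.shape[0]
        total = 0.0
        for i in range(count):  # plain loop: fast once compiled
            total += ari[i]
        return total / count
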
@@ -590,11 +596,11 @@ def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float
 
 @njit
 def fast_linear_fit(
-    x: npt.NDArray[np.float64],
-    y: npt.NDArray[np.float64],
+    x: npt.NDArray,
+    y: npt.NDArray,
     m: int,
-    norm: npt.NDArray[np.float64],
-) -> tuple[npt.NDArray[np.float64], list[npt.NDArray[np.float64]]]:
+    norm: npt.NDArray,
+) -> tuple[npt.NDArray, list[npt.NDArray]]:
     n = len(x)
     y = y / m + np.expand_dims(norm, axis=1)
 
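
`fast_linear_fit` fits a least-squares line per leaf. The closed-form estimates it is built on are the standard ones; a plain NumPy sketch for 1-D data (names hypothetical):

    import numpy as np

    def linear_fit_sketch(x, y):
        # OLS for y ~ a + b * x:  b = cov(x, y) / var(x),  a = mean(y) - b * mean(x)
        x_bar, y_bar = x.mean(), y.mean()
        b = ((x - x_bar) * (y - y_bar)).sum() / ((x - x_bar) ** 2).sum()
        a = y_bar - b * x_bar
        return a + b * x, [a, b]
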
@@ -678,17 +684,17 @@ def update(self):
 
 @njit
 def inverse_cdf(
-    single_uniform: npt.NDArray[np.float64], normalized_weights: npt.NDArray[np.float64]
+    single_uniform: npt.NDArray, normalized_weights: npt.NDArray
 ) -> npt.NDArray[np.int_]:
     """
     Inverse CDF algorithm for a finite distribution.
 
     Parameters
     ----------
-    single_uniform: npt.NDArray[np.float64]
+    single_uniform: npt.NDArray
         Ordered points in [0,1]
 
-    normalized_weights: npt.NDArray[np.float64])
+    normalized_weights: npt.NDArray)
         Normalized weights
 
     Returns
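
Because `single_uniform` arrives sorted, the weights' CDF can be traversed once with a single moving pointer, making the whole resample O(n). A sketch of that scan (function name hypothetical):

    import numpy as np
    from numba import njit

    @njit
    def inverse_cdf_sketch(single_uniform, normalized_weights):
        idx = np.zeros(len(single_uniform), dtype=np.int_)
        cumulative = normalized_weights[0]
        j = 0
        for i in range(len(single_uniform)):
            while single_uniform[i] > cumulative:  # advance the CDF pointer
                j += 1
                cumulative += normalized_weights[j]
            idx[i] = j
        return idx
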
@@ -711,7 +717,7 @@ def inverse_cdf(
 
 
 @njit
-def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]:
+def jitter_duplicated(array: npt.NDArray, std: float) -> npt.NDArray:
     """
     Jitter duplicated values.
     """
@@ -727,12 +733,17 @@ def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray
 
 
 @njit
-def are_whole_number(array: npt.NDArray[np.float64]) -> np.bool_:
+def are_whole_number(array: npt.NDArray) -> np.bool_:
     """Check if all values in array are whole numbers"""
     return np.all(np.mod(array[~np.isnan(array)], 1) == 0)
 
 
-def logp(point, out_vars, vars, shared):  # pylint: disable=redefined-builtin
+def logp(
+    point,
+    out_vars: list[pm.Distribution],
+    vars: list[pm.Distribution],
+    shared: list[pt.TensorVariable],
+):
     """Compile PyTensor function of the model and the input and output variables.
 
     Parameters
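
For orientation, compiling a log-density down to a callable with raw PyTensor follows this general shape; a hedged, self-contained sketch (not the body of `logp` itself):

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")                            # plays the role of `vars`
    sigma = pytensor.shared(1.0, name="sigma")    # plays the role of `shared`
    logp_graph = -0.5 * pt.sum((x / sigma) ** 2)  # plays the role of `out_vars`
    logp_fn = pytensor.function([x], logp_graph)

    print(logp_fn([1.0, 2.0]))  # evaluate the compiled log-density
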