Skip to content

Commit cd5dfbe

Browse files
authored
refactor rng_fn method (#212)
1 parent 064457e commit cd5dfbe

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pymc_bart/bart.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ def rng_fn( # pylint: disable=W0237
5555
if not size:
5656
size = None
5757

58-
if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)):
59-
Y = cls.Y.eval()
60-
else:
61-
Y = cls.Y
62-
6358
if not cls.all_trees:
59+
if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)):
60+
Y = cls.Y.eval()
61+
else:
62+
Y = cls.Y
63+
6464
if size is not None:
6565
return np.full((size[0], Y.shape[0]), Y.mean())
6666
else:

0 commit comments

Comments
 (0)