-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_sub.rs
51 lines (38 loc) · 1.21 KB
/
test_sub.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
use custos::{Buffer, CPU};
use sliced::BinaryOpsMayGrad;
#[cfg(feature = "cpu")]
#[test]
/// Element-wise `lhs - rhs` on the CPU backend.
///
/// Every pair differs by 5, so the forward result is all `-5`. With the
/// `autograd` feature enabled, back-propagating from the output must give a
/// gradient of `1` for the minuend and `-1` for the subtrahend.
fn test_sub() {
    let cpu = CPU::<custos::Autograd<custos::Base>>::new();

    let minuend = Buffer::from((&cpu, [1, 2, 3, 4, 5]));
    let subtrahend = Buffer::from((&cpu, [6, 7, 8, 9, 10]));

    let difference = cpu.sub(&minuend, &subtrahend);
    assert_eq!(difference.read(), [-5; 5]);

    #[cfg(feature = "autograd")]
    {
        difference.backward();
        // d(a - b)/da = 1 and d(a - b)/db = -1 for every element.
        assert_eq!(minuend.grad().read(), [1; 5]);
        assert_eq!(subtrahend.grad().read(), [-1; 5]);
    }
}
#[cfg(feature = "opencl")]
#[test]
/// Element-wise `lhs - rhs` on the OpenCL backend (device index `0`).
///
/// Any error from creating the OpenCL device is returned from the test via
/// `?`. The forward result must be all `-5`. With the `autograd` feature
/// enabled, the minuend's gradient must be all `1` and the subtrahend's all `-1`.
fn test_sub_cl() -> custos::Result<()> {
    use custos::OpenCL;

    let cl_device = OpenCL::<custos::Autograd<custos::Base>>::new(0)?;

    let minuend = Buffer::from((&cl_device, [1, 2, 3, 4, 5]));
    let subtrahend = Buffer::from((&cl_device, [6, 7, 8, 9, 10]));

    let difference = cl_device.sub(&minuend, &subtrahend);
    assert_eq!(difference.read(), [-5; 5]);

    #[cfg(feature = "autograd")]
    {
        difference.backward();
        // d(a - b)/da = 1 and d(a - b)/db = -1 for every element.
        assert_eq!(minuend.grad().read(), [1; 5]);
        assert_eq!(subtrahend.grad().read(), [-1; 5]);
    }

    Ok(())
}