1
+ #ifndef THC_GENERIC_FILE
2
+ #define THC_GENERIC_FILE " generic/GatedLinearUnit.cu"
3
+ #else
4
+
5
+ void THNN_ (GatedLinear_updateOutput)(
6
+ THCState *state,
7
+ THCTensor *input,
8
+ THCTensor *output,
9
+ int dim)
10
+ {
11
+ THCUNN_assertSameGPU (state, 2 , input, output);
12
+
13
+ // size output to half of input
14
+ dim = dim - 1 ;
15
+ const long nIn = THCTensor_ (size)(state, input, dim);
16
+ THArgCheck (nIn % 2 == 0 , 2 , " Halving dimension must be even. Dim %d is size %ld" , dim+1 , nIn);
17
+ const long inputSize = THCTensor_ (size)(state, input, dim) / 2 ;
18
+ THLongStorage *newSizes = THCTensor_ (newSizeOf)(state, input);
19
+ THLongStorage_set (newSizes, dim, inputSize);
20
+ THCTensor_ (resize)(state, output, newSizes, NULL );
21
+
22
+ // halve tensor
23
+ THCTensor *firstHalf = THCTensor_ (newNarrow)(state, input, dim, 0 , inputSize);
24
+ THCTensor *secondHalf = THCTensor_ (newNarrow)(state, input, dim, inputSize, inputSize);
25
+
26
+ // x = x1:cmul( sigmoid(x2) )
27
+ THC_pointwiseApply3 (state, output, secondHalf, firstHalf, gatedLinearCSigMul_functor<real, accreal>());
28
+
29
+ THLongStorage_free (newSizes);
30
+ THCTensor_ (free )(state, firstHalf);
31
+ THCTensor_ (free )(state, secondHalf);
32
+ }
33
+
34
+ void THNN_ (GatedLinear_updateGradInput)(
35
+ THCState *state,
36
+ THCTensor *input,
37
+ THCTensor *gradOutput,
38
+ THCTensor *gradInput,
39
+ int dim)
40
+ {
41
+ THCUNN_assertSameGPU (state, 2 , gradOutput, gradInput);
42
+ dim = dim - 1 ;
43
+ const long nIn = THCTensor_ (size)(state, input, dim);
44
+ THArgCheck (nIn % 2 == 0 , 2 , " Halving dimension must be even. Dim %d is size %ld" , dim+1 , nIn);
45
+
46
+ THCTensor_ (resizeAs)(state, gradInput, input);
47
+ const long inputSize = THCTensor_ (size)(state, input, dim) / 2 ;
48
+ THCTensor *firstHalf = THCTensor_ (newNarrow)(state, input, dim, 0 , inputSize);
49
+ THCTensor *secondHalf = THCTensor_ (newNarrow)(state, input, dim, inputSize, inputSize);
50
+ THCTensor *gradInputfirstHalf = THCTensor_ (newNarrow)(state, gradInput, dim, 0 , inputSize);
51
+ THCTensor *gradInputsecondHalf = THCTensor_ (newNarrow)(state, gradInput, dim, inputSize, inputSize);
52
+ // first half of derivative
53
+ THC_pointwiseApply3 (state, gradInputfirstHalf, secondHalf, gradOutput, gatedLinearCSigMul_functor<real, accreal>());
54
+ // second half of derivative
55
+ THCTensor_ (copy)(state, gradInputsecondHalf, firstHalf);
56
+ THC_pointwiseApply3 (state, gradInputsecondHalf, secondHalf, gradOutput, gatedLinearDerivativeSecondHalf_functor<real, accreal>());
57
+
58
+ THCTensor_ (free )(state, firstHalf);
59
+ THCTensor_ (free )(state, secondHalf);
60
+ THCTensor_ (free )(state, gradInputfirstHalf);
61
+ THCTensor_ (free )(state, gradInputsecondHalf);
62
+ }
63
+
64
+ #endif
0 commit comments