
Commit 126e77d

Merge commit 'e9b05c71b4acf210fad719f4da8bb58a425dd00b'
2 parents 53eec78 + e9b05c7

File tree

3 files changed: +107 -0 lines changed

torch/lib/THCUNN/GatedLinearUnit.cu
torch/lib/THCUNN/generic/GatedLinearUnit.cu
torch/lib/THCUNN/generic/THCUNN.h

torch/lib/THCUNN/GatedLinearUnit.cu

+30
@@ -0,0 +1,30 @@
+#include "THCUNN.h"
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+#include <THC/THCApply.cuh>
+#include "common.h"
+
+template <typename Dtype, typename Acctype>
+struct gatedLinearCSigMul_functor
+{
+  __device__ void operator()(Dtype *target, const Dtype *sigTensor, const Dtype *mulTensor) const
+  {
+    const Acctype sigNum = Acctype(1)/(Acctype(1)+ exp(ScalarConvert<Dtype, Acctype>::to(-*sigTensor)));
+    const Dtype mulNum = *mulTensor;
+    *target = ScalarConvert<Acctype, Dtype>::to(sigNum * mulNum);
+  }
+};
+
+template <typename Dtype, typename Acctype>
+struct gatedLinearDerivativeSecondHalf_functor
+{
+  __device__ void operator()(Dtype *target, const Dtype *sigTensor, const Dtype *mulTensor) const
+  {
+    const Acctype sigNum = Acctype(1)/(Acctype(1)+ exp(ScalarConvert<Dtype, Acctype>::to(-*sigTensor)));
+    const Dtype mulNum = *mulTensor;
+    *target *= ScalarConvert<Acctype, Dtype>::to((Acctype(1) - sigNum) * sigNum * mulNum);
+  }
+};
+
+#include "generic/GatedLinearUnit.cu"
+#include "THCGenerateFloatTypes.h"

torch/lib/THCUNN/generic/GatedLinearUnit.cu

+64
@@ -0,0 +1,64 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/GatedLinearUnit.cu"
+#else
+
+void THNN_(GatedLinear_updateOutput)(
+          THCState *state,
+          THCTensor *input,
+          THCTensor *output,
+          int dim)
+{
+  THCUNN_assertSameGPU(state, 2, input, output);
+
+  // size output to half of input
+  dim = dim - 1;
+  const long nIn = THCTensor_(size)(state, input, dim);
+  THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", dim+1, nIn);
+  const long inputSize = THCTensor_(size)(state, input, dim) / 2;
+  THLongStorage *newSizes = THCTensor_(newSizeOf)(state, input);
+  THLongStorage_set(newSizes, dim, inputSize);
+  THCTensor_(resize)(state, output, newSizes, NULL);
+
+  // halve tensor
+  THCTensor *firstHalf = THCTensor_(newNarrow)(state, input, dim, 0, inputSize);
+  THCTensor *secondHalf = THCTensor_(newNarrow)(state, input, dim, inputSize, inputSize);
+
+  // x = x1:cmul( sigmoid(x2) )
+  THC_pointwiseApply3(state, output, secondHalf, firstHalf, gatedLinearCSigMul_functor<real, accreal>());
+
+  THLongStorage_free(newSizes);
+  THCTensor_(free)(state, firstHalf);
+  THCTensor_(free)(state, secondHalf);
+}
+
+void THNN_(GatedLinear_updateGradInput)(
+          THCState *state,
+          THCTensor *input,
+          THCTensor *gradOutput,
+          THCTensor *gradInput,
+          int dim)
+{
+  THCUNN_assertSameGPU(state, 2, gradOutput, gradInput);
+  dim = dim - 1;
+  const long nIn = THCTensor_(size)(state, input, dim);
+  THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", dim+1, nIn);
+
+  THCTensor_(resizeAs)(state, gradInput, input);
+  const long inputSize = THCTensor_(size)(state, input, dim) / 2;
+  THCTensor *firstHalf = THCTensor_(newNarrow)(state, input, dim, 0, inputSize);
+  THCTensor *secondHalf = THCTensor_(newNarrow)(state, input, dim, inputSize, inputSize);
+  THCTensor *gradInputfirstHalf = THCTensor_(newNarrow)(state, gradInput, dim, 0, inputSize);
+  THCTensor *gradInputsecondHalf = THCTensor_(newNarrow)(state, gradInput, dim, inputSize, inputSize);
+  // first half of derivative
+  THC_pointwiseApply3(state, gradInputfirstHalf, secondHalf, gradOutput, gatedLinearCSigMul_functor<real, accreal>());
+  // second half of derivative
+  THCTensor_(copy)(state, gradInputsecondHalf, firstHalf);
+  THC_pointwiseApply3(state, gradInputsecondHalf, secondHalf, gradOutput, gatedLinearDerivativeSecondHalf_functor<real, accreal>());
+
+  THCTensor_(free)(state, firstHalf);
+  THCTensor_(free)(state, secondHalf);
+  THCTensor_(free)(state, gradInputfirstHalf);
+  THCTensor_(free)(state, gradInputsecondHalf);
+}
+
+#endif
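
For reference, the split above matches the standard chain rule (a derivation added here, not part of the diff): writing x_1 for the first half of input, x_2 for the second half, and g for gradOutput, the forward pass is

    y = x_1 \odot \sigma(x_2)

and therefore

    \frac{\partial L}{\partial x_1} = g \odot \sigma(x_2), \qquad
    \frac{\partial L}{\partial x_2} = g \odot x_1 \odot \sigma(x_2)\,\bigl(1 - \sigma(x_2)\bigr)

The first half of gradInput is produced directly by gatedLinearCSigMul_functor applied to (gradInputfirstHalf, secondHalf, gradOutput); the second half starts as a copy of firstHalf and is then scaled in place by gatedLinearDerivativeSecondHalf_functor.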

torch/lib/THCUNN/generic/THCUNN.h

+13
@@ -138,6 +138,19 @@ TH_API void THNN_(HardTanh_updateGradInput)(
           real max_val,
           bool inplace);
 
+TH_API void THNN_(GatedLinear_updateOutput)(
+          THCState *state,
+          THCTensor *input,
+          THCTensor *output,
+          int dim);
+
+TH_API void THNN_(GatedLinear_updateGradInput)(
+          THCState *state,
+          THCTensor *input,
+          THCTensor *gradOutput,
+          THCTensor *gradInput,
+          int dim);
+
 TH_API void THNN_(LeakyReLU_updateOutput)(
           THCState *state,
           THCTensor *input,
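
As a rough standalone illustration of the forward computation behind THNN_(GatedLinear_updateOutput) (this sketch is not part of the commit and skips the THC dispatch machinery; the kernel name, launch configuration, and contiguous float-only layout are assumptions made for brevity), the per-element work is out[i] = first[i] * sigmoid(second[i]) over the two halves of the input:

#include <cuda_runtime.h>
#include <math.h>

// Illustrative reference kernel, not the library implementation:
// out[i] = first[i] * sigmoid(second[i]) for the two contiguous halves
// of the input along the halving dimension (n elements per half).
__global__ void gluForwardReference(const float *first, const float *second,
                                    float *out, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float sig = 1.0f / (1.0f + expf(-second[i]));
    out[i] = first[i] * sig;
  }
}

// Example launch, assuming device pointers dFirst, dSecond, dOut:
//   gluForwardReference<<<(n + 255) / 256, 256>>>(dFirst, dSecond, dOut, n);

The committed code instead goes through THC_pointwiseApply3, which handles non-contiguous tensors, and THCGenerateFloatTypes.h, which instantiates the functors for float, double, and (where enabled) half.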
