Skip to content

Commit f24e492

Browse files
author
dujw
committed
add warp_perspective kernel
1 parent d577fba commit f24e492

File tree

3 files changed

+142
-1
lines changed

3 files changed

+142
-1
lines changed

.vscode/settings.json

+3-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@
7474
"cinttypes": "cpp",
7575
"typeindex": "cpp",
7676
"valarray": "cpp",
77-
"bit": "cpp"
77+
"bit": "cpp",
78+
"__functional_base": "cpp",
79+
"locale": "cpp"
7880
},
7981
"workbench.tree.expandMode": "doubleClick"
8082
}

src/tensorRT/common/preprocess_kernel.cu

+109
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,99 @@ namespace CUDAKernel{
105105
*pdst_c2 = c2;
106106
}
107107

108+
// Perspective-warps a packed 3-channel 8-bit image into a planar float image,
// using bilinear sampling followed by an optional channel swap and normalization.
//
// Launch layout: 1-D grid of 1-D blocks, one thread per destination pixel.
// Threads whose flat index is >= edge return immediately. No shared memory is used.
//
// src                    : source pixels, 3 interleaved bytes per pixel
// src_line_size          : byte stride between consecutive source rows
// src_width, src_height  : source size in pixels
// dst                    : output, three consecutive planes of dst_width * dst_height floats
// dst_width, dst_height  : destination size in pixels
// const_value_st         : value used for every channel of a sample that falls outside the source
// warp_affine_matrix_3_3 : 9 floats, row-major 3x3 matrix mapping a destination (x, y) to a source (x, y)
// norm                   : channel-order and normalization settings applied to the sampled value
// edge                   : number of destination pixels to process
__global__ void warp_perspective_kernel(uint8_t* src, int src_line_size, int src_width, int src_height, float* dst, int dst_width, int dst_height,
    uint8_t const_value_st, float* warp_affine_matrix_3_3, Norm norm, int edge){

    int position = blockDim.x * blockIdx.x + threadIdx.x;
    if (position >= edge) return;

    // Row 1 produces the x numerator, row 2 the y numerator, row 3 the shared divisor.
    float m_x1 = warp_affine_matrix_3_3[0];
    float m_y1 = warp_affine_matrix_3_3[1];
    float m_z1 = warp_affine_matrix_3_3[2];

    float m_x2 = warp_affine_matrix_3_3[3];
    float m_y2 = warp_affine_matrix_3_3[4];
    float m_z2 = warp_affine_matrix_3_3[5];

    float m_x3 = warp_affine_matrix_3_3[6];
    float m_y3 = warp_affine_matrix_3_3[7];
    float m_z3 = warp_affine_matrix_3_3[8];

    int dx = position % dst_width;
    int dy = position / dst_width;

    // Position in the source image: homogeneous transform followed by the perspective divide.
    // Both coordinates share the same divisor, so it is computed once.
    float w     = m_x3 * dx + m_y3 * dy + m_z3;
    float src_x = (m_x1 * dx + m_y1 * dy + m_z1) / w;
    float src_y = (m_x2 * dx + m_y2 * dy + m_z2) / w;
    float c0, c1, c2;

    // The test is written positively on purpose: a degenerate divide (0/0 or inf/inf) gives NaN,
    // every comparison with NaN is false, and the pixel is then treated as out of range
    // instead of being interpolated into a NaN output value.
    bool inside = src_x > -1 && src_x < src_width && src_y > -1 && src_y < src_height;

    if(!inside){
        // out of range
        c0 = const_value_st;
        c1 = const_value_st;
        c2 = const_value_st;
    }else{
        int y_low  = floorf(src_y);
        int x_low  = floorf(src_x);
        int y_high = y_low + 1;
        int x_high = x_low + 1;

        // Stand-in pixel used for any of the four neighbours that lies outside the source.
        uint8_t const_value[] = {const_value_st, const_value_st, const_value_st};
        float ly = src_y - y_low;
        float lx = src_x - x_low;
        float hy = 1.0f - ly;
        float hx = 1.0f - lx;
        float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
        const uint8_t* v1 = const_value;
        const uint8_t* v2 = const_value;
        const uint8_t* v3 = const_value;
        const uint8_t* v4 = const_value;
        if(y_low >= 0){
            if (x_low >= 0)
                v1 = src + y_low * src_line_size + x_low * 3;

            if (x_high < src_width)
                v2 = src + y_low * src_line_size + x_high * 3;
        }

        if(y_high < src_height){
            if (x_low >= 0)
                v3 = src + y_high * src_line_size + x_low * 3;

            if (x_high < src_width)
                v4 = src + y_high * src_line_size + x_high * 3;
        }

        // same to opencv
        c0 = floorf(w1 * v1[0] + w2 * v2[0] + w3 * v3[0] + w4 * v4[0] + 0.5f);
        c1 = floorf(w1 * v1[1] + w2 * v2[1] + w3 * v3[1] + w4 * v4[1] + 0.5f);
        c2 = floorf(w1 * v1[2] + w2 * v2[2] + w3 * v3[2] + w4 * v4[2] + 0.5f);
    }

    // Swap the first and third channel when the channel order is inverted.
    if(norm.channel_type == ChannelType::Invert){
        float t = c2;
        c2 = c0;  c0 = t;
    }

    if(norm.type == NormType::MeanStd){
        c0 = (c0 * norm.alpha - norm.mean[0]) / norm.std[0];
        c1 = (c1 * norm.alpha - norm.mean[1]) / norm.std[1];
        c2 = (c2 * norm.alpha - norm.mean[2]) / norm.std[2];
    }else if(norm.type == NormType::AlphaBeta){
        c0 = c0 * norm.alpha + norm.beta;
        c1 = c1 * norm.alpha + norm.beta;
        c2 = c2 * norm.alpha + norm.beta;
    }

    // Planar output: channel planes are laid out one after another, each `area` floats long.
    int area = dst_width * dst_height;
    float* pdst_c0 = dst + dy * dst_width + dx;
    float* pdst_c1 = pdst_c0 + area;
    float* pdst_c2 = pdst_c1 + area;
    *pdst_c0 = c0;
    *pdst_c1 = c1;
    *pdst_c2 = c2;
}
200+
108201
__global__ void warp_affine_bilinear_and_normalize_plane_kernel(uint8_t* src, int src_line_size, int src_width, int src_height, float* dst, int dst_width, int dst_height,
109202
uint8_t const_value_st, float* warp_affine_matrix_2_3, Norm norm, int edge){
110203

@@ -394,6 +487,22 @@ namespace CUDAKernel{
394487
));
395488
}
396489

490+
// Host-side launcher for warp_perspective_kernel: schedules one thread per destination pixel
// on `stream`. The launch is asynchronous; this function does not synchronize the stream.
//
// src, dst and matrix_3_3 are dereferenced by the kernel, so they must be device-accessible.
// matrix_3_3 holds the 9 floats of the 3x3 matrix that maps destination to source coordinates.
// const_value is the fill value for samples that fall outside the source image.
void warp_perspective(
    uint8_t* src, int src_line_size, int src_width, int src_height, float* dst, int dst_width, int dst_height,
    float* matrix_3_3, uint8_t const_value, const Norm& norm, cudaStream_t stream
)
{
    // An empty destination has nothing to write, and a kernel launch with a zero-sized
    // grid or block is an invalid configuration, so skip the launch entirely.
    if (dst_width <= 0 || dst_height <= 0) return;

    int jobs   = dst_width * dst_height;
    auto grid  = CUDATools::grid_dims(jobs);
    auto block = CUDATools::block_dims(jobs);

    checkCudaKernel(warp_perspective_kernel<<<grid, block, 0, stream>>>(
        src, src_line_size,
        src_width, src_height, dst,
        dst_width, dst_height, const_value, matrix_3_3, norm, jobs
    ));
}
505+
397506
void resize_bilinear_and_normalize(
398507
uint8_t* src, int src_line_size, int src_width, int src_height, float* dst, int dst_width, int dst_height,
399508
const Norm& norm,

src/tensorRT/common/preprocess_kernel.cuh

+30
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,36 @@ namespace CUDAKernel{
5050
float* matrix_2_3, uint8_t const_value, const Norm& norm,
5151
cudaStream_t stream);
5252

53+
// Can be used for image rectification, image rotation, etc. (measured to be more than 10x faster than the CPU)
54+
// Usage example:
55+
// float* matrix_3_3 = nullptr;
56+
// size_t matrix_bytes = 3 * 3 * sizeof(float);
57+
// checkCudaRuntime(cudaMalloc(&matrix_3_3, matrix_bytes));
58+
// checkCudaRuntime(cudaMemset(matrix_3_3, 0, matrix_bytes));
59+
//
60+
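// Coordinates of the four points in the source image: top-left, top-right, bottom-right, bottom-left
// NOTE(review): cv::getPerspectiveTransform pairs points by index, so this order must match dst_points below (top-left, top-right, bottom-left, bottom-right) - confirm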
61+
// cv::Point2f src_points[] = {
62+
// vctvctPoints[nImageIdx][0],
63+
// vctvctPoints[nImageIdx][1],
64+
// vctvctPoints[nImageIdx][2],
65+
// vctvctPoints[nImageIdx][3]};
66+
//
67+
// Coordinates of the four points in the destination image: top-left, top-right, bottom-left, bottom-right (Z-shaped order)
68+
// cv::Point2f dst_points[] = {
69+
// cv::Point2f(0, 0),
70+
// cv::Point2f(nw-1, 0),
71+
// cv::Point2f(0, nh-1),
72+
// cv::Point2f(nw-1, nh-1) };
73+
// Use OpenCV to compute the transform matrix that maps dst -> src
74+
// cv::Mat Perspect_Matrix = cv::getPerspectiveTransform(dst_points, src_points);
75+
// Perspect_Matrix.convertTo(Perspect_Matrix, CV_32FC1);
76+
// Copy to the GPU
77+
// checkCudaRuntime(cudaMemcpy(matrix_3_3, Perspect_Matrix.data, matrix_bytes, cudaMemcpyHostToDevice));
78+
void warp_perspective(
79+
uint8_t* src, int src_line_size, int src_width, int src_height, float* dst, int dst_width, int dst_height,
80+
float* matrix_3_3, uint8_t const_value, const Norm& norm, cudaStream_t stream
81+
);
82+
5383
void norm_feature(
5484
float* feature_array, int num_feature, int feature_length,
5585
cudaStream_t stream

0 commit comments

Comments
 (0)