
Commit aa2e0b4

Add files via upload
add objectbox inference code
1 parent 0df5fe0 commit aa2e0b4


1 file changed: +335 -0 lines


src/application/objectbox.cpp

@@ -0,0 +1,335 @@
#include "objectbox.hpp"
#include <atomic>
#include <mutex>
#include <queue>
#include <condition_variable>
#include <infer/trt_infer.hpp>
#include <common/ilogger.hpp>
#include <common/infer_controller.hpp>
#include <common/preprocess_kernel.cuh>
#include <common/monopoly_allocator.hpp>
#include <common/cuda_tools.hpp>

namespace Objdetectbox {

    using namespace cv;
    using namespace std;

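    // Letterbox-style affine transform. i2d scales the source image by
    // scale = min(net_w / img_w, net_h / img_h), preserving aspect ratio, and
    // centers it on the network canvas; for a source point (x, y):
    //     x' = scale * x + i2d[2],   y' = scale * y + i2d[5]
    // The "+ scale * 0.5 - 0.5" terms align pixel centers rather than corners.
    // d2i is the inverse transform, used later to map detected boxes from
    // network coordinates back to the original image.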
    struct AffineMatrix {
        float i2d[6];   // image to dst(network), 2x3 matrix
        float d2i[6];   // dst to image, 2x3 matrix

        void compute(const cv::Size& from, const cv::Size& to) {
            float scale_x = to.width  / (float)from.width;
            float scale_y = to.height / (float)from.height;

            float scale = std::min(scale_x, scale_y);

            i2d[0] = scale; i2d[1] = 0; i2d[2] = -scale * from.width  * 0.5 + to.width  * 0.5 + scale * 0.5 - 0.5;
            i2d[3] = 0; i2d[4] = scale; i2d[5] = -scale * from.height * 0.5 + to.height * 0.5 + scale * 0.5 - 0.5;

            cv::Mat m2x3_i2d(2, 3, CV_32F, i2d);
            cv::Mat m2x3_d2i(2, 3, CV_32F, d2i);
            cv::invertAffineTransform(m2x3_i2d, m2x3_d2i);
        }

        cv::Mat i2d_mat() {
            return cv::Mat(2, 3, CV_32F, i2d);
        }
    };

    using ControllerImpl = InferController
    <
        Mat,                // input
        BoxArray,           // output
        tuple<string, int>, // start param
        AffineMatrix        // additional
    >;

    class InferImpl : public Infer, public ControllerImpl {
    public:
        /** stop() must be called here in InferImpl, not in the base class **/
        virtual ~InferImpl() {
            stop();
        }

        // Returns {best class index relative to start_ind, best score} over the
        // class scores data[start_ind .. start_ind + length - 1].
        pair<int, float> getBestClass(float* data, int start_ind, int length) {
            int max_ind   = -1;
            float max_val = 0.0f;
            for (int i = start_ind; i < start_ind + length; i++) {
                if (data[i] > max_val) {
                    max_val = data[i];
                    max_ind = i;
                }
            }
            return pair<int, float>(max_ind - start_ind, max_val);
        }

        static tuple<float, float> affine_project(float x, float y, float* pmatrix) {
            float newx = x * pmatrix[0] + y * pmatrix[1] + pmatrix[2];
            float newy = x * pmatrix[3] + y * pmatrix[4] + pmatrix[5];
            return make_tuple(newx, newy);
        }

        static float iou(const Box& a, const Box& b) {
            float cleft   = max(a.left,   b.left);
            float ctop    = max(a.top,    b.top);
            float cright  = min(a.right,  b.right);
            float cbottom = min(a.bottom, b.bottom);

            float c_area = max(cright - cleft, 0.0f) * max(cbottom - ctop, 0.0f);
            if (c_area == 0.0f)
                return 0.0f;

            float a_area = max(0.0f, a.right - a.left) * max(0.0f, a.bottom - a.top);
            float b_area = max(0.0f, b.right - b.left) * max(0.0f, b.bottom - b.top);
            return c_area / (a_area + b_area - c_area);
        }

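        // Greedy class-wise NMS on the CPU: sort by confidence (descending),
        // keep each surviving box, and suppress later boxes of the same class
        // whose IoU with it reaches the threshold.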
        static BoxArray cpu_nms(BoxArray& boxes, float threshold) {

            std::sort(boxes.begin(), boxes.end(), [](BoxArray::const_reference a, BoxArray::const_reference b) {
                return a.confidence > b.confidence;
            });

            BoxArray output;
            output.reserve(boxes.size());

            vector<bool> remove_flags(boxes.size());
            for (int i = 0; i < boxes.size(); ++i) {

                if (remove_flags[i]) continue;

                auto& a = boxes[i];
                output.emplace_back(a);

                for (int j = i + 1; j < boxes.size(); ++j) {
                    if (remove_flags[j]) continue;

                    auto& b = boxes[j];
                    if (b.class_label == a.class_label) {
                        if (iou(a, b) >= threshold)
                            remove_flags[j] = true;
                    }
                }
            }
            return output;
        }

        virtual bool startup(const string& file, int gpuid, float confidence_threshold, float nms_threshold) {
            normalize_            = CUDAKernel::Norm::alpha_beta(1 / 255.0f, 0.0f, CUDAKernel::ChannelType::Invert);
            confidence_threshold_ = confidence_threshold;
            nms_threshold_        = nms_threshold;
            return ControllerImpl::startup(make_tuple(file, gpuid));
        }

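        // Consumer thread: load the engine on the requested GPU and bind its
        // input/output tensors, then loop on get_jobs_and_wait(). For each batch,
        // copy every job's preprocessed tensor (and its d2i matrix) into the
        // batched input, run forward(), decode the raw output, apply NMS, and
        // fulfill each job's promise with the final boxes.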
        virtual void worker(promise<bool>& result) override {

            string file = get<0>(start_param_);
            int gpuid   = get<1>(start_param_);

            TRT::set_device(gpuid);
            auto engine = TRT::load_infer(file);
            if (engine == nullptr) {
                INFOE("Engine %s load failed", file.c_str());
                result.set_value(false);
                return;
            }

            engine->print();

            const int MAX_IMAGE_BBOX  = 1024;
            const int NUM_BOX_ELEMENT = 7;   // left, top, right, bottom, confidence, class, keepflag
            TRT::Tensor affin_matrix_device(TRT::DataType::Float);
            int max_batch_size = engine->get_max_batch_size();
            auto input  = engine->tensor("input.1");   // binding names of the engine's input/output tensors
            auto output = engine->tensor("1428");
            int num_classes = output->size(2) - 5;

            input_width_      = input->size(3);
            input_height_     = input->size(2);
            tensor_allocator_ = make_shared<MonopolyAllocator<TRT::Tensor>>(max_batch_size * 2);
            stream_           = engine->get_stream();
            gpu_              = gpuid;
            result.set_value(true);

            input->resize_single_dim(0, max_batch_size).to_gpu();
            affin_matrix_device.set_stream(stream_);

            // the number 8 here means 8 * sizeof(float) % 32 == 0
            affin_matrix_device.resize(max_batch_size, 8).to_gpu();

            vector<Job> fetch_jobs;
            while (get_jobs_and_wait(fetch_jobs, max_batch_size)) {

                int infer_batch_size = fetch_jobs.size();
                input->resize_single_dim(0, infer_batch_size);

                // gather each job's preprocessed tensor and its d2i matrix into the batched buffers
                for (int ibatch = 0; ibatch < infer_batch_size; ++ibatch) {
                    auto& job  = fetch_jobs[ibatch];
                    auto& mono = job.mono_tensor->data();

                    if (mono->get_stream() != stream_) {
                        checkCudaRuntime(cudaStreamSynchronize(mono->get_stream()));
                    }

                    affin_matrix_device.copy_from_gpu(affin_matrix_device.offset(ibatch), mono->get_workspace()->gpu(), 6);
                    input->copy_from_gpu(input->offset(ibatch), mono->gpu(), mono->count());
                    job.mono_tensor->release();
                }
                engine->forward(false);

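                // Decode: each row of the output tensor is
                // [cx, cy, w, h, objectness, class_score_0 .. class_score_{num_classes-1}].
                // Rows below the objectness threshold are skipped, class scores are
                // scaled by objectness, the best class is picked, the center/size box
                // is converted to corners, and d2i maps it back to image coordinates.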
                for (int ibatch = 0; ibatch < infer_batch_size; ++ibatch) {
                    auto& job = fetch_jobs[ibatch];
                    float* image_based_output = output->cpu<float>(ibatch);
                    auto& image_based_boxes   = job.output;
                    auto& affine_matrix       = job.additional;

                    for (int i = 0; i < output->size(1); ++i) {
                        float* boxinfo = output->cpu<float>(ibatch, i);
                        if (boxinfo[4] <= confidence_threshold_)
                            continue;

                        for (int j = 5; j < output->size(2); j++)
                            boxinfo[j] *= boxinfo[4];

                        auto out_result = getBestClass(boxinfo, 5, num_classes);
                        if (out_result.second < confidence_threshold_)
                            continue;

                        float box_x = boxinfo[0] - boxinfo[2] / 2.;
                        float box_y = boxinfo[1] - boxinfo[3] / 2.;
                        float box_r = boxinfo[0] + boxinfo[2] / 2.;
                        float box_b = boxinfo[1] + boxinfo[3] / 2.;

                        // Point2f keeps sub-pixel precision when mapping back to the image
                        Point2f box_lt, box_rb;
                        tie(box_lt.x, box_lt.y) = affine_project(box_x, box_y, job.additional.d2i);
                        tie(box_rb.x, box_rb.y) = affine_project(box_r, box_b, job.additional.d2i);
                        image_based_boxes.emplace_back(box_lt.x, box_lt.y, box_rb.x, box_rb.y, boxinfo[4], out_result.first);
                    }
                    image_based_boxes = cpu_nms(image_based_boxes, nms_threshold_);
                    job.pro->set_value(job.output);
                }
                fetch_jobs.clear();
            }
            stream_ = nullptr;
            tensor_allocator_.reset();
            INFO("Engine destroy.");
        }

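        // Producer-side preprocess: acquire a tensor from the monopoly allocator,
        // stage [d2i matrix padded to 32 bytes | raw BGR image bytes] in the host
        // workspace, async-copy both to the device workspace, then let the CUDA
        // warp-affine kernel resize/pad (fill value 114) and normalize (x / 255,
        // channel order inverted) directly into the job's input tensor.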
        virtual bool preprocess(Job& job, const Mat& image) override {

            if (tensor_allocator_ == nullptr) {
                INFOE("tensor_allocator_ is nullptr");
                return false;
            }

            job.mono_tensor = tensor_allocator_->query();
            if (job.mono_tensor == nullptr) {
                INFOE("Tensor allocator query failed.");
                return false;
            }

            CUDATools::AutoDevice auto_device(gpu_);
            auto& tensor = job.mono_tensor->data();
            if (tensor == nullptr) {
                // not initialized yet
                tensor = make_shared<TRT::Tensor>();
                tensor->set_workspace(make_shared<TRT::MixMemory>());
            }

            Size input_size(input_width_, input_height_);
            job.additional.compute(image.size(), input_size);

            tensor->set_stream(stream_);
            tensor->resize(1, 3, input_height_, input_width_);

            size_t size_image  = image.cols * image.rows * 3;
            size_t size_matrix = iLogger::upbound(sizeof(job.additional.d2i), 32);
            auto workspace = tensor->get_workspace();
            uint8_t* gpu_workspace      = (uint8_t*)workspace->gpu(size_matrix + size_image);
            float* affine_matrix_device = (float*)gpu_workspace;
            uint8_t* image_device       = size_matrix + gpu_workspace;

            uint8_t* cpu_workspace      = (uint8_t*)workspace->cpu(size_matrix + size_image);
            float* affine_matrix_host   = (float*)cpu_workspace;
            uint8_t* image_host         = size_matrix + cpu_workspace;

            memcpy(image_host, image.data, size_image);
            memcpy(affine_matrix_host, job.additional.d2i, sizeof(job.additional.d2i));
            checkCudaRuntime(cudaMemcpyAsync(image_device, image_host, size_image, cudaMemcpyHostToDevice, stream_));
            checkCudaRuntime(cudaMemcpyAsync(affine_matrix_device, affine_matrix_host, sizeof(job.additional.d2i), cudaMemcpyHostToDevice, stream_));

            CUDAKernel::warp_affine_bilinear_and_normalize_plane(
                image_device, image.cols * 3, image.cols, image.rows,
                tensor->gpu<float>(), input_width_, input_height_,
                affine_matrix_device, 114,   // 114: constant fill value for the letterbox padding
                normalize_, stream_
            );
            return true;
        }

        virtual vector<shared_future<BoxArray>> commits(const vector<Mat>& images) override {
            return ControllerImpl::commits(images);
        }

        virtual std::shared_future<BoxArray> commit(const Mat& image) override {
            return ControllerImpl::commit(image);
        }

    private:
        int input_width_            = 0;
        int input_height_           = 0;
        int gpu_                    = 0;
        float confidence_threshold_ = 0;
        float nms_threshold_        = 0;
        TRT::CUStream stream_       = nullptr;
        CUDAKernel::Norm normalize_;
    };

    shared_ptr<Infer> create_infer(const string& engine_file, int gpuid, float confidence_threshold, float nms_threshold) {
        shared_ptr<InferImpl> instance(new InferImpl());
        if (!instance->startup(engine_file, gpuid, confidence_threshold, nms_threshold)) {
            instance.reset();
        }
        return instance;
    }

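    // Usage sketch (illustrative only; the engine path, image, and thresholds
    // below are hypothetical and not part of this commit):
    //
    //     auto detector = Objdetectbox::create_infer("objectbox.trtmodel", 0, 0.25f, 0.45f);
    //     if (detector != nullptr) {
    //         auto boxes = detector->commit(cv::imread("demo.jpg")).get();
    //         for (auto& box : boxes)
    //             INFO("class = %d, conf = %.3f, rect = (%.1f, %.1f, %.1f, %.1f)",
    //                  box.class_label, box.confidence, box.left, box.top, box.right, box.bottom);
    //     }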
    // Standalone preprocessing helper: performs the same warp-affine + normalize
    // path as preprocess(), but writes into batch slot `ibatch` of a caller-provided
    // tensor and synchronizes before returning.
    void image_to_tensor(const cv::Mat& image, shared_ptr<TRT::Tensor>& tensor, int ibatch) {

        auto normalize = CUDAKernel::Norm::alpha_beta(1 / 255.0f, 0.0f, CUDAKernel::ChannelType::Invert);
        Size input_size(tensor->size(3), tensor->size(2));
        AffineMatrix affine;
        affine.compute(image.size(), input_size);

        size_t size_image  = image.cols * image.rows * 3;
        size_t size_matrix = iLogger::upbound(sizeof(affine.d2i), 32);
        auto workspace = tensor->get_workspace();
        uint8_t* gpu_workspace      = (uint8_t*)workspace->gpu(size_matrix + size_image);
        float* affine_matrix_device = (float*)gpu_workspace;
        uint8_t* image_device       = size_matrix + gpu_workspace;

        uint8_t* cpu_workspace      = (uint8_t*)workspace->cpu(size_matrix + size_image);
        float* affine_matrix_host   = (float*)cpu_workspace;
        uint8_t* image_host         = size_matrix + cpu_workspace;
        auto stream = tensor->get_stream();

        memcpy(image_host, image.data, size_image);
        memcpy(affine_matrix_host, affine.d2i, sizeof(affine.d2i));
        checkCudaRuntime(cudaMemcpyAsync(image_device, image_host, size_image, cudaMemcpyHostToDevice, stream));
        checkCudaRuntime(cudaMemcpyAsync(affine_matrix_device, affine_matrix_host, sizeof(affine.d2i), cudaMemcpyHostToDevice, stream));

        CUDAKernel::warp_affine_bilinear_and_normalize_plane(
            image_device, image.cols * 3, image.cols, image.rows,
            tensor->gpu<float>(ibatch), input_size.width, input_size.height,
            affine_matrix_device, 114,
            normalize, stream
        );
        tensor->synchronize();
    }
} // namespace Objdetectbox
