Skip to content
Snippets Groups Projects
Commit bf281645 authored by Nicolas Pope's avatar Nicolas Pope
Browse files

Allow selection of CT shape

parent 4afea678
No related branches found
No related tags found
1 merge request: !351 "Circle census transform or SGM"
Pipeline #33077 failed
This commit is part of merge request !351. Comments created here will be created in the context of that merge request.
...@@ -28,6 +28,7 @@ limitations under the License. ...@@ -28,6 +28,7 @@ limitations under the License.
#include <stdint.h> #include <stdint.h>
#include "libsgm_config.h" #include "libsgm_config.h"
#include "libsgm_parameters.hpp"
#include <cuda_runtime.h> #include <cuda_runtime.h>
#if defined(LIBSGM_SHARED) #if defined(LIBSGM_SHARED)
...@@ -71,7 +72,8 @@ namespace sgm { ...@@ -71,7 +72,8 @@ namespace sgm {
int P2; int P2;
float uniqueness; float uniqueness;
bool subpixel; bool subpixel;
Parameters(int P1 = 10, int P2 = 120, float uniqueness = 0.95f, bool subpixel = false) : P1(P1), P2(P2), uniqueness(uniqueness), subpixel(subpixel) {} CensusShape ct_shape;
Parameters(int P1 = 10, int P2 = 120, float uniqueness = 0.95f, bool subpixel = false, CensusShape ct_shape = CensusShape::CIRCLE_4_2) : P1(P1), P2(P2), uniqueness(uniqueness), subpixel(subpixel), ct_shape(ct_shape) {}
}; };
/** /**
......
#pragma once
namespace sgm {
/**
 * Selects which census-transform window shape is used when computing
 * matching features. The value is carried in sgm::Parameters::ct_shape and
 * passed down to the census transform enqueue call.
 */
enum class CensusShape {
// Standard census transform; dispatched with a 5x5 window.
CT_5X5,
// Centre-symmetric census transform; dispatched with a 9x7 window.
CS_CT_9X7,
// Default shape in sgm::Parameters.
// NOTE(review): presumably a circular sampling pattern; the meaning of
// "4_2" and the kernel that implements it are not visible here - confirm.
CIRCLE_4_2
};
}
...@@ -21,14 +21,11 @@ namespace sgm { ...@@ -21,14 +21,11 @@ namespace sgm {
namespace { namespace {
static constexpr int WINDOW_WIDTH = 5;
static constexpr int WINDOW_HEIGHT = 5;
static constexpr int BLOCK_SIZE = 128; static constexpr int BLOCK_SIZE = 128;
static constexpr int LINES_PER_BLOCK = 16; static constexpr int LINES_PER_BLOCK = 16;
/* Centre symmetric census */ /* Centre symmetric census */
template <typename T> template <typename T, int WINDOW_WIDTH, int WINDOW_HEIGHT>
__global__ void cs_census_transform_kernel( __global__ void cs_census_transform_kernel(
feature_type *dest, feature_type *dest,
const T *src, const T *src,
...@@ -104,7 +101,7 @@ __global__ void cs_census_transform_kernel( ...@@ -104,7 +101,7 @@ __global__ void cs_census_transform_kernel(
} }
} }
template <typename T> template <typename T, int WINDOW_WIDTH, int WINDOW_HEIGHT>
__global__ void census_transform_kernel( __global__ void census_transform_kernel(
feature_type* __restrict__ dest, feature_type* __restrict__ dest,
const T* __restrict__ src, const T* __restrict__ src,
...@@ -145,24 +142,27 @@ void enqueue_census_transform( ...@@ -145,24 +142,27 @@ void enqueue_census_transform(
int width, int width,
int height, int height,
int pitch, int pitch,
sgm::CensusShape ct_shape,
cudaStream_t stream) cudaStream_t stream)
{ {
/* Disable the original center symmetric algorithm */ /* Disable the original center symmetric algorithm */
if (false) { if (ct_shape == sgm::CensusShape::CS_CT_9X7) {
const int width_per_block = BLOCK_SIZE - WINDOW_WIDTH + 1; const int width_per_block = BLOCK_SIZE - 9 + 1;
const int height_per_block = LINES_PER_BLOCK; const int height_per_block = LINES_PER_BLOCK;
const dim3 gdim( const dim3 gdim(
(width + width_per_block - 1) / width_per_block, (width + width_per_block - 1) / width_per_block,
(height + height_per_block - 1) / height_per_block); (height + height_per_block - 1) / height_per_block);
const dim3 bdim(BLOCK_SIZE); const dim3 bdim(BLOCK_SIZE);
cs_census_transform_kernel<<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch); cs_census_transform_kernel<T, 9, 7><<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
} else { } else if (ct_shape == sgm::CensusShape::CT_5X5) {
static constexpr int THREADS_X = 16; static constexpr int THREADS_X = 16;
static constexpr int THREADS_Y = 16; static constexpr int THREADS_Y = 16;
const dim3 gdim((width + THREADS_X - 1)/THREADS_X, (height + THREADS_Y - 1)/THREADS_Y); const dim3 gdim((width + THREADS_X - 1)/THREADS_X, (height + THREADS_Y - 1)/THREADS_Y);
const dim3 bdim(THREADS_X, THREADS_Y); const dim3 bdim(THREADS_X, THREADS_Y);
census_transform_kernel<<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch); census_transform_kernel<T, 5, 5><<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
} else if (ct_shape == sgm::CensusShape::CIRCLE_4_2) {
} }
} }
...@@ -180,13 +180,14 @@ void CensusTransform<T>::enqueue( ...@@ -180,13 +180,14 @@ void CensusTransform<T>::enqueue(
int width, int width,
int height, int height,
int pitch, int pitch,
sgm::CensusShape ct_shape,
cudaStream_t stream) cudaStream_t stream)
{ {
if(m_feature_buffer.size() < static_cast<size_t>(width * height)){ if(m_feature_buffer.size() < static_cast<size_t>(width * height)){
m_feature_buffer = DeviceBuffer<feature_type>(width * height); m_feature_buffer = DeviceBuffer<feature_type>(width * height);
} }
enqueue_census_transform( enqueue_census_transform(
m_feature_buffer.data(), src, width, height, pitch, stream); m_feature_buffer.data(), src, width, height, pitch, ct_shape, stream);
} }
template class CensusTransform<uint8_t>; template class CensusTransform<uint8_t>;
......
...@@ -19,6 +19,7 @@ limitations under the License. ...@@ -19,6 +19,7 @@ limitations under the License.
#include "device_buffer.hpp" #include "device_buffer.hpp"
#include "types.hpp" #include "types.hpp"
#include "libsgm_parameters.hpp"
namespace sgm { namespace sgm {
...@@ -43,6 +44,7 @@ public: ...@@ -43,6 +44,7 @@ public:
int width, int width,
int height, int height,
int pitch, int pitch,
sgm::CensusShape ct_shape,
cudaStream_t stream); cudaStream_t stream);
}; };
......
...@@ -58,12 +58,13 @@ public: ...@@ -58,12 +58,13 @@ public:
float uniqueness, float uniqueness,
bool subpixel, bool subpixel,
int min_disp, int min_disp,
sgm::CensusShape ct_shape,
cudaStream_t stream) cudaStream_t stream)
{ {
m_census_left.enqueue( m_census_left.enqueue(
src_left, width, height, src_pitch, stream); src_left, width, height, src_pitch, ct_shape, stream);
m_census_right.enqueue( m_census_right.enqueue(
src_right, width, height, src_pitch, stream); src_right, width, height, src_pitch, ct_shape, stream);
m_path_aggregation.enqueue( m_path_aggregation.enqueue(
m_census_left.get_output(), m_census_left.get_output(),
m_census_right.get_output(), m_census_right.get_output(),
...@@ -109,6 +110,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::execute( ...@@ -109,6 +110,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::execute(
float uniqueness, float uniqueness,
bool subpixel, bool subpixel,
int min_disp, int min_disp,
sgm::CensusShape ct_shape,
cudaStream_t stream) cudaStream_t stream)
{ {
m_impl->enqueue( m_impl->enqueue(
...@@ -119,7 +121,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::execute( ...@@ -119,7 +121,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::execute(
penalty1, penalty2, penalty1, penalty2,
weights, weights_pitch, weights, weights_pitch,
uniqueness, subpixel, uniqueness, subpixel,
min_disp, min_disp, ct_shape,
stream); stream);
//cudaStreamSynchronize(0); //cudaStreamSynchronize(0);
} }
...@@ -141,6 +143,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::enqueue( ...@@ -141,6 +143,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::enqueue(
float uniqueness, float uniqueness,
bool subpixel, bool subpixel,
int min_disp, int min_disp,
sgm::CensusShape ct_shape,
cudaStream_t stream) cudaStream_t stream)
{ {
m_impl->enqueue( m_impl->enqueue(
...@@ -151,7 +154,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::enqueue( ...@@ -151,7 +154,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::enqueue(
penalty1, penalty2, penalty1, penalty2,
weights, weights_pitch, weights, weights_pitch,
uniqueness, subpixel, uniqueness, subpixel,
min_disp, min_disp, ct_shape,
stream); stream);
} }
......
...@@ -20,6 +20,7 @@ limitations under the License. ...@@ -20,6 +20,7 @@ limitations under the License.
#include <memory> #include <memory>
#include <cstdint> #include <cstdint>
#include "types.hpp" #include "types.hpp"
#include "libsgm_parameters.hpp"
namespace sgm { namespace sgm {
...@@ -54,6 +55,7 @@ public: ...@@ -54,6 +55,7 @@ public:
float uniqueness, float uniqueness,
bool subpixel, bool subpixel,
int min_disp, int min_disp,
sgm::CensusShape ct_shape,
cudaStream_t stream); cudaStream_t stream);
void enqueue( void enqueue(
...@@ -72,6 +74,7 @@ public: ...@@ -72,6 +74,7 @@ public:
float uniqueness, float uniqueness,
bool subpixel, bool subpixel,
int min_disp, int min_disp,
sgm::CensusShape ct_shape,
cudaStream_t stream); cudaStream_t stream);
}; };
......
...@@ -29,7 +29,7 @@ namespace sgm { ...@@ -29,7 +29,7 @@ namespace sgm {
public: public:
using output_type = sgm::output_type; using output_type = sgm::output_type;
virtual void execute(output_type* dst_L, output_type* dst_R, const void* src_L, const void* src_R, virtual void execute(output_type* dst_L, output_type* dst_R, const void* src_L, const void* src_R,
int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, cudaStream_t stream) = 0; int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, sgm::CensusShape ct_shape, cudaStream_t stream) = 0;
virtual ~SemiGlobalMatchingBase() {} virtual ~SemiGlobalMatchingBase() {}
}; };
...@@ -38,9 +38,9 @@ namespace sgm { ...@@ -38,9 +38,9 @@ namespace sgm {
class SemiGlobalMatchingImpl : public SemiGlobalMatchingBase { class SemiGlobalMatchingImpl : public SemiGlobalMatchingBase {
public: public:
void execute(output_type* dst_L, output_type* dst_R, const void* src_L, const void* src_R, void execute(output_type* dst_L, output_type* dst_R, const void* src_L, const void* src_R,
int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, cudaStream_t stream) override int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, sgm::CensusShape ct_shape, cudaStream_t stream) override
{ {
sgm_engine_.execute(dst_L, dst_R, (const input_type*)src_L, (const input_type*)src_R, w, h, sp, dp, P1, P2, weights, weights_pitch, uniqueness, subpixel, min_disp, stream); sgm_engine_.execute(dst_L, dst_R, (const input_type*)src_L, (const input_type*)src_R, w, h, sp, dp, P1, P2, weights, weights_pitch, uniqueness, subpixel, min_disp, ct_shape, stream);
} }
private: private:
SemiGlobalMatching<input_type, DISP_SIZE> sgm_engine_; SemiGlobalMatching<input_type, DISP_SIZE> sgm_engine_;
...@@ -176,7 +176,7 @@ namespace sgm { ...@@ -176,7 +176,7 @@ namespace sgm {
d_left_disp = dst; // when there is no device-host copy or type conversion, use passed buffer d_left_disp = dst; // when there is no device-host copy or type conversion, use passed buffer
cu_res_->sgm_engine->execute((uint16_t*)d_tmp_left_disp, (uint16_t*)d_tmp_right_disp, cu_res_->sgm_engine->execute((uint16_t*)d_tmp_left_disp, (uint16_t*)d_tmp_right_disp,
d_input_left, d_input_right, width, height, src_pitch, dst_pitch, param_.P1, P2, weights, weights_pitch, param_.uniqueness, param_.subpixel, min_disp, stream); d_input_left, d_input_right, width, height, src_pitch, dst_pitch, param_.P1, P2, weights, weights_pitch, param_.uniqueness, param_.subpixel, min_disp, param_.ct_shape, stream);
sgm::details::median_filter((uint16_t*)d_tmp_left_disp, (uint16_t*)d_left_disp, width, height, dst_pitch, stream); sgm::details::median_filter((uint16_t*)d_tmp_left_disp, (uint16_t*)d_left_disp, width, height, dst_pitch, stream);
sgm::details::median_filter((uint16_t*)d_tmp_right_disp, (uint16_t*)d_right_disp, width, height, dst_pitch, stream); sgm::details::median_filter((uint16_t*)d_tmp_right_disp, (uint16_t*)d_right_disp, width, height, dst_pitch, stream);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment