From bf281645b62f3faf89d89d188b063036950e50cc Mon Sep 17 00:00:00 2001 From: Nicolas Pope <nwpope@utu.fi> Date: Tue, 27 Oct 2020 14:15:06 +0200 Subject: [PATCH] Allow selection of CT shape --- lib/libsgm/include/libsgm.h | 4 +++- lib/libsgm/include/libsgm_parameters.hpp | 9 +++++++++ lib/libsgm/src/census_transform.cu | 23 ++++++++++++----------- lib/libsgm/src/census_transform.hpp | 2 ++ lib/libsgm/src/sgm.cu | 11 +++++++---- lib/libsgm/src/sgm.hpp | 3 +++ lib/libsgm/src/stereo_sgm.cpp | 8 ++++---- 7 files changed, 40 insertions(+), 20 deletions(-) create mode 100644 lib/libsgm/include/libsgm_parameters.hpp diff --git a/lib/libsgm/include/libsgm.h b/lib/libsgm/include/libsgm.h index bf2c58ea7..7f8218ea0 100644 --- a/lib/libsgm/include/libsgm.h +++ b/lib/libsgm/include/libsgm.h @@ -28,6 +28,7 @@ limitations under the License. #include <stdint.h> #include "libsgm_config.h" +#include "libsgm_parameters.hpp" #include <cuda_runtime.h> #if defined(LIBSGM_SHARED) @@ -71,7 +72,8 @@ namespace sgm { int P2; float uniqueness; bool subpixel; - Parameters(int P1 = 10, int P2 = 120, float uniqueness = 0.95f, bool subpixel = false) : P1(P1), P2(P2), uniqueness(uniqueness), subpixel(subpixel) {} + CensusShape ct_shape; + Parameters(int P1 = 10, int P2 = 120, float uniqueness = 0.95f, bool subpixel = false, CensusShape ct_shape = CensusShape::CIRCLE_4_2) : P1(P1), P2(P2), uniqueness(uniqueness), subpixel(subpixel), ct_shape(ct_shape) {} }; /** diff --git a/lib/libsgm/include/libsgm_parameters.hpp b/lib/libsgm/include/libsgm_parameters.hpp new file mode 100644 index 000000000..9c61db728 --- /dev/null +++ b/lib/libsgm/include/libsgm_parameters.hpp @@ -0,0 +1,9 @@ +#pragma once + +namespace sgm { + enum class CensusShape { + CT_5X5, + CS_CT_9X7, + CIRCLE_4_2 + }; +} diff --git a/lib/libsgm/src/census_transform.cu b/lib/libsgm/src/census_transform.cu index d437f978e..90d839743 100644 --- a/lib/libsgm/src/census_transform.cu +++ b/lib/libsgm/src/census_transform.cu @@ -21,14 +21,11 @@ namespace sgm { namespace { -static constexpr int WINDOW_WIDTH = 5; -static constexpr int WINDOW_HEIGHT = 5; - static constexpr int BLOCK_SIZE = 128; static constexpr int LINES_PER_BLOCK = 16; /* Centre symmetric census */ -template <typename T> +template <typename T, int WINDOW_WIDTH, int WINDOW_HEIGHT> __global__ void cs_census_transform_kernel( feature_type *dest, const T *src, @@ -104,7 +101,7 @@ __global__ void cs_census_transform_kernel( } } -template <typename T> +template <typename T, int WINDOW_WIDTH, int WINDOW_HEIGHT> __global__ void census_transform_kernel( feature_type* __restrict__ dest, const T* __restrict__ src, @@ -145,24 +142,27 @@ void enqueue_census_transform( int width, int height, int pitch, + sgm::CensusShape ct_shape, cudaStream_t stream) { /* Disable the original center symmetric algorithm */ - if (false) { - const int width_per_block = BLOCK_SIZE - WINDOW_WIDTH + 1; + if (ct_shape == sgm::CensusShape::CS_CT_9X7) { + const int width_per_block = BLOCK_SIZE - 9 + 1; const int height_per_block = LINES_PER_BLOCK; const dim3 gdim( (width + width_per_block - 1) / width_per_block, (height + height_per_block - 1) / height_per_block); const dim3 bdim(BLOCK_SIZE); - cs_census_transform_kernel<<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch); - } else { + cs_census_transform_kernel<T, 9, 7><<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch); + } else if (ct_shape == sgm::CensusShape::CT_5X5) { static constexpr int THREADS_X = 16; static constexpr int THREADS_Y = 16; const dim3 gdim((width + THREADS_X - 1)/THREADS_X, (height + THREADS_Y - 1)/THREADS_Y); const dim3 bdim(THREADS_X, THREADS_Y); - census_transform_kernel<<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch); + census_transform_kernel<T, 5, 5><<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch); + } else if (ct_shape == sgm::CensusShape::CIRCLE_4_2) { + } } @@ -180,13 +180,14 @@ void CensusTransform<T>::enqueue( int width, int height, int pitch, + sgm::CensusShape ct_shape, cudaStream_t stream) { if(m_feature_buffer.size() < static_cast<size_t>(width * height)){ m_feature_buffer = DeviceBuffer<feature_type>(width * height); } enqueue_census_transform( - m_feature_buffer.data(), src, width, height, pitch, stream); + m_feature_buffer.data(), src, width, height, pitch, ct_shape, stream); } template class CensusTransform<uint8_t>; diff --git a/lib/libsgm/src/census_transform.hpp b/lib/libsgm/src/census_transform.hpp index 8a80b903b..23c1ebd3f 100644 --- a/lib/libsgm/src/census_transform.hpp +++ b/lib/libsgm/src/census_transform.hpp @@ -19,6 +19,7 @@ limitations under the License. #include "device_buffer.hpp" #include "types.hpp" +#include "libsgm_parameters.hpp" namespace sgm { @@ -43,6 +44,7 @@ public: int width, int height, int pitch, + sgm::CensusShape ct_shape, cudaStream_t stream); }; diff --git a/lib/libsgm/src/sgm.cu b/lib/libsgm/src/sgm.cu index eb5d4179e..0adf22301 100644 --- a/lib/libsgm/src/sgm.cu +++ b/lib/libsgm/src/sgm.cu @@ -58,12 +58,13 @@ public: float uniqueness, bool subpixel, int min_disp, + sgm::CensusShape ct_shape, cudaStream_t stream) { m_census_left.enqueue( - src_left, width, height, src_pitch, stream); + src_left, width, height, src_pitch, ct_shape, stream); m_census_right.enqueue( - src_right, width, height, src_pitch, stream); + src_right, width, height, src_pitch, ct_shape, stream); m_path_aggregation.enqueue( m_census_left.get_output(), m_census_right.get_output(), @@ -109,6 +110,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::execute( float uniqueness, bool subpixel, int min_disp, + sgm::CensusShape ct_shape, cudaStream_t stream) { m_impl->enqueue( @@ -119,7 +121,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::execute( penalty1, penalty2, weights, weights_pitch, uniqueness, subpixel, - min_disp, + min_disp, ct_shape, stream); //cudaStreamSynchronize(0); } @@ -141,6 +143,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::enqueue( float uniqueness, bool subpixel, int min_disp, + sgm::CensusShape ct_shape, cudaStream_t stream) { m_impl->enqueue( @@ -151,7 +154,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::enqueue( penalty1, penalty2, weights, weights_pitch, uniqueness, subpixel, - min_disp, + min_disp, ct_shape, stream); } diff --git a/lib/libsgm/src/sgm.hpp b/lib/libsgm/src/sgm.hpp index 9aa2cd387..c792c4a6e 100644 --- a/lib/libsgm/src/sgm.hpp +++ b/lib/libsgm/src/sgm.hpp @@ -20,6 +20,7 @@ limitations under the License. #include <memory> #include <cstdint> #include "types.hpp" +#include "libsgm_parameters.hpp" namespace sgm { @@ -54,6 +55,7 @@ public: float uniqueness, bool subpixel, int min_disp, + sgm::CensusShape ct_shape, cudaStream_t stream); void enqueue( @@ -72,6 +74,7 @@ public: float uniqueness, bool subpixel, int min_disp, + sgm::CensusShape ct_shape, cudaStream_t stream); }; diff --git a/lib/libsgm/src/stereo_sgm.cpp b/lib/libsgm/src/stereo_sgm.cpp index 70e126314..6e7c0bf18 100644 --- a/lib/libsgm/src/stereo_sgm.cpp +++ b/lib/libsgm/src/stereo_sgm.cpp @@ -29,7 +29,7 @@ namespace sgm { public: using output_type = sgm::output_type; virtual void execute(output_type* dst_L, output_type* dst_R, const void* src_L, const void* src_R, - int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, cudaStream_t stream) = 0; + int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, sgm::CensusShape ct_shape, cudaStream_t stream) = 0; virtual ~SemiGlobalMatchingBase() {} }; @@ -38,9 +38,9 @@ namespace sgm { class SemiGlobalMatchingImpl : public SemiGlobalMatchingBase { public: void execute(output_type* dst_L, output_type* dst_R, const void* src_L, const void* src_R, - int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, cudaStream_t stream) override + int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, sgm::CensusShape ct_shape, cudaStream_t stream) override { - sgm_engine_.execute(dst_L, dst_R, (const input_type*)src_L, (const input_type*)src_R, w, h, sp, dp, P1, P2, weights, weights_pitch, uniqueness, subpixel, min_disp, stream); + sgm_engine_.execute(dst_L, dst_R, (const input_type*)src_L, (const input_type*)src_R, w, h, sp, dp, P1, P2, weights, weights_pitch, uniqueness, subpixel, min_disp, ct_shape, stream); } private: SemiGlobalMatching<input_type, DISP_SIZE> sgm_engine_; @@ -176,7 +176,7 @@ namespace sgm { d_left_disp = dst; // when threre is no device-host copy or type conversion, use passed buffer cu_res_->sgm_engine->execute((uint16_t*)d_tmp_left_disp, (uint16_t*)d_tmp_right_disp, - d_input_left, d_input_right, width, height, src_pitch, dst_pitch, param_.P1, P2, weights, weights_pitch, param_.uniqueness, param_.subpixel, min_disp, stream); + d_input_left, d_input_right, width, height, src_pitch, dst_pitch, param_.P1, P2, weights, weights_pitch, param_.uniqueness, param_.subpixel, min_disp, param_.ct_shape, stream); sgm::details::median_filter((uint16_t*)d_tmp_left_disp, (uint16_t*)d_left_disp, width, height, dst_pitch, stream); sgm::details::median_filter((uint16_t*)d_tmp_right_disp, (uint16_t*)d_right_disp, width, height, dst_pitch, stream); -- GitLab