From bf281645b62f3faf89d89d188b063036950e50cc Mon Sep 17 00:00:00 2001
From: Nicolas Pope <nwpope@utu.fi>
Date: Tue, 27 Oct 2020 14:15:06 +0200
Subject: [PATCH] Allow selection of CT shape

---
 lib/libsgm/include/libsgm.h              |  4 +++-
 lib/libsgm/include/libsgm_parameters.hpp |  9 +++++++++
 lib/libsgm/src/census_transform.cu       | 23 ++++++++++++-----------
 lib/libsgm/src/census_transform.hpp      |  2 ++
 lib/libsgm/src/sgm.cu                    | 11 +++++++----
 lib/libsgm/src/sgm.hpp                   |  3 +++
 lib/libsgm/src/stereo_sgm.cpp            |  8 ++++----
 7 files changed, 40 insertions(+), 20 deletions(-)
 create mode 100644 lib/libsgm/include/libsgm_parameters.hpp

diff --git a/lib/libsgm/include/libsgm.h b/lib/libsgm/include/libsgm.h
index bf2c58ea7..7f8218ea0 100644
--- a/lib/libsgm/include/libsgm.h
+++ b/lib/libsgm/include/libsgm.h
@@ -28,6 +28,7 @@ limitations under the License.
 
 #include <stdint.h>
 #include "libsgm_config.h"
+#include "libsgm_parameters.hpp"
 #include <cuda_runtime.h>
 
 #if defined(LIBSGM_SHARED)
@@ -71,7 +72,8 @@ namespace sgm {
 			int P2;
 			float uniqueness;
 			bool subpixel;
-			Parameters(int P1 = 10, int P2 = 120, float uniqueness = 0.95f, bool subpixel = false) : P1(P1), P2(P2), uniqueness(uniqueness), subpixel(subpixel) {}
+			CensusShape ct_shape;
+			Parameters(int P1 = 10, int P2 = 120, float uniqueness = 0.95f, bool subpixel = false, CensusShape ct_shape = CensusShape::CIRCLE_4_2) : P1(P1), P2(P2), uniqueness(uniqueness), subpixel(subpixel), ct_shape(ct_shape) {}
 		};
 
 		/**
diff --git a/lib/libsgm/include/libsgm_parameters.hpp b/lib/libsgm/include/libsgm_parameters.hpp
new file mode 100644
index 000000000..9c61db728
--- /dev/null
+++ b/lib/libsgm/include/libsgm_parameters.hpp
@@ -0,0 +1,9 @@
+#pragma once
+
+namespace sgm {
+	enum class CensusShape {
+		CT_5X5,
+		CS_CT_9X7,
+		CIRCLE_4_2
+	};
+}
diff --git a/lib/libsgm/src/census_transform.cu b/lib/libsgm/src/census_transform.cu
index d437f978e..90d839743 100644
--- a/lib/libsgm/src/census_transform.cu
+++ b/lib/libsgm/src/census_transform.cu
@@ -21,14 +21,11 @@ namespace sgm {
 
 namespace {
 
-static constexpr int WINDOW_WIDTH  = 5;
-static constexpr int WINDOW_HEIGHT = 5;
-
 static constexpr int BLOCK_SIZE = 128;
 static constexpr int LINES_PER_BLOCK = 16;
 
 /* Centre symmetric census */
-template <typename T>
+template <typename T, int WINDOW_WIDTH, int WINDOW_HEIGHT>
 __global__ void cs_census_transform_kernel(
 	feature_type *dest,
 	const T *src,
@@ -104,7 +101,7 @@ __global__ void cs_census_transform_kernel(
 	}
 }
 
-template <typename T>
+template <typename T, int WINDOW_WIDTH, int WINDOW_HEIGHT>
 __global__ void census_transform_kernel(
 	feature_type* __restrict__ dest,
 	const T* __restrict__ src,
@@ -145,24 +142,27 @@ void enqueue_census_transform(
 	int width,
 	int height,
 	int pitch,
+	sgm::CensusShape ct_shape,
 	cudaStream_t stream)
 {
 	/* Disable the original center symmetric algorithm */
-	if (false) {
-		const int width_per_block = BLOCK_SIZE - WINDOW_WIDTH + 1;
+	if (ct_shape == sgm::CensusShape::CS_CT_9X7) {
+		const int width_per_block = BLOCK_SIZE - 9 + 1;
 		const int height_per_block = LINES_PER_BLOCK;
 		const dim3 gdim(
 			(width  + width_per_block  - 1) / width_per_block,
 			(height + height_per_block - 1) / height_per_block);
 		const dim3 bdim(BLOCK_SIZE);
-		cs_census_transform_kernel<<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
-	} else {
+		cs_census_transform_kernel<T, 9, 7><<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
+	} else if (ct_shape == sgm::CensusShape::CT_5X5) {
 		static constexpr int THREADS_X = 16;
 		static constexpr int THREADS_Y = 16;
 
 		const dim3 gdim((width + THREADS_X - 1)/THREADS_X, (height + THREADS_Y - 1)/THREADS_Y);
 		const dim3 bdim(THREADS_X, THREADS_Y);
-		census_transform_kernel<<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
+		census_transform_kernel<T, 5, 5><<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
+	} else if (ct_shape == sgm::CensusShape::CIRCLE_4_2) {
+
 	}
 }
 
@@ -180,13 +180,14 @@ void CensusTransform<T>::enqueue(
 	int width,
 	int height,
 	int pitch,
+	sgm::CensusShape ct_shape,
 	cudaStream_t stream)
 {
 	if(m_feature_buffer.size() < static_cast<size_t>(width * height)){
 		m_feature_buffer = DeviceBuffer<feature_type>(width * height);
 	}
 	enqueue_census_transform(
-		m_feature_buffer.data(), src, width, height, pitch, stream);
+		m_feature_buffer.data(), src, width, height, pitch, ct_shape, stream);
 }
 
 template class CensusTransform<uint8_t>;
diff --git a/lib/libsgm/src/census_transform.hpp b/lib/libsgm/src/census_transform.hpp
index 8a80b903b..23c1ebd3f 100644
--- a/lib/libsgm/src/census_transform.hpp
+++ b/lib/libsgm/src/census_transform.hpp
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "device_buffer.hpp"
 #include "types.hpp"
+#include "libsgm_parameters.hpp"
 
 namespace sgm {
 
@@ -43,6 +44,7 @@ public:
 		int width,
 		int height,
 		int pitch,
+		sgm::CensusShape ct_shape,
 		cudaStream_t stream);
 
 };
diff --git a/lib/libsgm/src/sgm.cu b/lib/libsgm/src/sgm.cu
index eb5d4179e..0adf22301 100644
--- a/lib/libsgm/src/sgm.cu
+++ b/lib/libsgm/src/sgm.cu
@@ -58,12 +58,13 @@ public:
 		float uniqueness,
 		bool subpixel,
 		int min_disp,
+		sgm::CensusShape ct_shape,
 		cudaStream_t stream)
 	{
 		m_census_left.enqueue(
-			src_left, width, height, src_pitch, stream);
+			src_left, width, height, src_pitch, ct_shape, stream);
 		m_census_right.enqueue(
-			src_right, width, height, src_pitch, stream);
+			src_right, width, height, src_pitch, ct_shape, stream);
 		m_path_aggregation.enqueue(
 			m_census_left.get_output(),
 			m_census_right.get_output(),
@@ -109,6 +110,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::execute(
 	float uniqueness,
 	bool subpixel,
 	int min_disp,
+	sgm::CensusShape ct_shape,
 	cudaStream_t stream)
 {
 	m_impl->enqueue(
@@ -119,7 +121,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::execute(
 		penalty1, penalty2,
 		weights, weights_pitch,
 		uniqueness, subpixel,
-		min_disp,
+		min_disp, ct_shape,
 		stream);
 	//cudaStreamSynchronize(0);
 }
@@ -141,6 +143,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::enqueue(
 	float uniqueness,
 	bool subpixel,
 	int min_disp,
+	sgm::CensusShape ct_shape,
 	cudaStream_t stream)
 {
 	m_impl->enqueue(
@@ -151,7 +154,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::enqueue(
 		penalty1, penalty2,
 		weights, weights_pitch,
 		uniqueness, subpixel,
-		min_disp,
+		min_disp, ct_shape,
 		stream);
 }
 
diff --git a/lib/libsgm/src/sgm.hpp b/lib/libsgm/src/sgm.hpp
index 9aa2cd387..c792c4a6e 100644
--- a/lib/libsgm/src/sgm.hpp
+++ b/lib/libsgm/src/sgm.hpp
@@ -20,6 +20,7 @@ limitations under the License.
 #include <memory>
 #include <cstdint>
 #include "types.hpp"
+#include "libsgm_parameters.hpp"
 
 namespace sgm {
 
@@ -54,6 +55,7 @@ public:
 		float uniqueness,
 		bool subpixel,
 		int min_disp,
+		sgm::CensusShape ct_shape,
 		cudaStream_t stream);
 
 	void enqueue(
@@ -72,6 +74,7 @@ public:
 		float uniqueness,
 		bool subpixel,
 		int min_disp,
+		sgm::CensusShape ct_shape,
 		cudaStream_t stream);
 
 };
diff --git a/lib/libsgm/src/stereo_sgm.cpp b/lib/libsgm/src/stereo_sgm.cpp
index 70e126314..6e7c0bf18 100644
--- a/lib/libsgm/src/stereo_sgm.cpp
+++ b/lib/libsgm/src/stereo_sgm.cpp
@@ -29,7 +29,7 @@ namespace sgm {
 	public:
 		using output_type = sgm::output_type;
 		virtual void execute(output_type* dst_L, output_type* dst_R, const void* src_L, const void* src_R,
-			int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, cudaStream_t stream) = 0;
+			int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, sgm::CensusShape ct_shape, cudaStream_t stream) = 0;
 
 		virtual ~SemiGlobalMatchingBase() {}
 	};
@@ -38,9 +38,9 @@ namespace sgm {
 	class SemiGlobalMatchingImpl : public SemiGlobalMatchingBase {
 	public:
 		void execute(output_type* dst_L, output_type* dst_R, const void* src_L, const void* src_R,
-			int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, cudaStream_t stream) override
+			int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, sgm::CensusShape ct_shape, cudaStream_t stream) override
 		{
-			sgm_engine_.execute(dst_L, dst_R, (const input_type*)src_L, (const input_type*)src_R, w, h, sp, dp, P1, P2, weights, weights_pitch, uniqueness, subpixel, min_disp, stream);
+			sgm_engine_.execute(dst_L, dst_R, (const input_type*)src_L, (const input_type*)src_R, w, h, sp, dp, P1, P2, weights, weights_pitch, uniqueness, subpixel, min_disp, ct_shape, stream);
 		}
 	private:
 		SemiGlobalMatching<input_type, DISP_SIZE> sgm_engine_;
@@ -176,7 +176,7 @@ namespace sgm {
 			d_left_disp = dst; // when threre is no device-host copy or type conversion, use passed buffer
 
 		cu_res_->sgm_engine->execute((uint16_t*)d_tmp_left_disp, (uint16_t*)d_tmp_right_disp,
-			d_input_left, d_input_right, width, height, src_pitch, dst_pitch, param_.P1, P2, weights, weights_pitch, param_.uniqueness, param_.subpixel, min_disp, stream);
+			d_input_left, d_input_right, width, height, src_pitch, dst_pitch, param_.P1, P2, weights, weights_pitch, param_.uniqueness, param_.subpixel, min_disp, param_.ct_shape, stream);
 
 		sgm::details::median_filter((uint16_t*)d_tmp_left_disp, (uint16_t*)d_left_disp, width, height, dst_pitch, stream);
 		sgm::details::median_filter((uint16_t*)d_tmp_right_disp, (uint16_t*)d_right_disp, width, height, dst_pitch, stream);
-- 
GitLab