From 317b2094dc209bbe796bfa1b99918667bd167ffe Mon Sep 17 00:00:00 2001
From: Nicolas Pope <nicolas.pope@utu.fi>
Date: Thu, 22 Oct 2020 21:28:41 +0300
Subject: [PATCH] Use 5x5 proper census for SGM

---
 lib/libsgm/src/census_transform.cu | 64 +++++++++++++++++++++++++-----
 1 file changed, 54 insertions(+), 10 deletions(-)

diff --git a/lib/libsgm/src/census_transform.cu b/lib/libsgm/src/census_transform.cu
index 70f67bcff..df3888618 100644
--- a/lib/libsgm/src/census_transform.cu
+++ b/lib/libsgm/src/census_transform.cu
@@ -21,14 +21,15 @@ namespace sgm {
 
 namespace {
 
-static constexpr int WINDOW_WIDTH  = 9;
-static constexpr int WINDOW_HEIGHT = 7;
+static constexpr int WINDOW_WIDTH  = 5;
+static constexpr int WINDOW_HEIGHT = 5;
 
 static constexpr int BLOCK_SIZE = 128;
 static constexpr int LINES_PER_BLOCK = 16;
 
+/* Centre symmetric census */
 template <typename T>
-__global__ void census_transform_kernel(
+__global__ void cs_census_transform_kernel(
 	feature_type *dest,
 	const T *src,
 	int width,
@@ -103,6 +104,39 @@ __global__ void census_transform_kernel(
 	}
 }
 
+/* Full (non centre-symmetric) census, one thread per pixel: each window sample adds one bit
+ * (1 if greater than the centre), packed row by row. Pixels near the border are written as 0. */
+template <typename T>
+__global__ void census_transform_kernel(
+	feature_type* __restrict__ dest,
+	const T* __restrict__ src,
+	int width,
+	int height,
+	int pitch)
+{
+	static constexpr int RADIUS_X = WINDOW_WIDTH/2;
+	static constexpr int RADIUS_Y = WINDOW_HEIGHT/2;
+
+	const int x = (blockIdx.x*blockDim.x + threadIdx.x);
+	const int y = blockIdx.y*blockDim.y + threadIdx.y;
+	feature_type res = 0;
+
+	// Encode only pixels whose whole window lies inside the image; src rows are `pitch` elements apart.
+	if (x >= RADIUS_X && y >= RADIUS_Y && x < width-RADIUS_X && y < height-RADIUS_Y) {
+		const T center = src[y*pitch+x];
+		#pragma unroll
+		for (int wy = -RADIUS_Y; wy <= RADIUS_Y; ++wy) {
+			const int i = (y + wy) * pitch + x;
+			#pragma unroll
+			for (int wx = -RADIUS_X; wx <= RADIUS_X; ++wx) {
+				res = (res << 1) | (center < (src[i+wx]) ? 1 : 0);
+			}
+		}
+	}
+	// The grid can cover more threads than pixels: threads outside the image must not store,
+	if (x < width && y < height) dest[x+y*width] = res;  // else they clobber the next row or run off dest.
+}
+
 template <typename T>
 void enqueue_census_transform(
 	feature_type *dest,
@@ -112,13 +146,23 @@ void enqueue_census_transform(
 	int pitch,
 	cudaStream_t stream)
 {
-	const int width_per_block = BLOCK_SIZE - WINDOW_WIDTH + 1;
-	const int height_per_block = LINES_PER_BLOCK;
-	const dim3 gdim(
-		(width  + width_per_block  - 1) / width_per_block,
-		(height + height_per_block - 1) / height_per_block);
-	const dim3 bdim(BLOCK_SIZE);
-	census_transform_kernel<<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
+	/* Disabled: the `if (false)` branch keeps the original centre-symmetric launch compiled but never runs it */
+	if (false) {
+		const int width_per_block = BLOCK_SIZE - WINDOW_WIDTH + 1;
+		const int height_per_block = LINES_PER_BLOCK;
+		const dim3 gdim(
+			(width  + width_per_block  - 1) / width_per_block,
+			(height + height_per_block - 1) / height_per_block);
+		const dim3 bdim(BLOCK_SIZE);
+		cs_census_transform_kernel<<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
+	} else {
+		static constexpr int THREADS_X = 16;  // block width in threads
+		static constexpr int THREADS_Y = 16;  // block height in threads
+		// Ceil-divide: the grid covers more threads than pixels when width/height are not multiples of 16.
+		const dim3 gdim((width + THREADS_X - 1)/THREADS_X, (height + THREADS_Y - 1)/THREADS_Y);
+		const dim3 bdim(THREADS_X, THREADS_Y);
+		census_transform_kernel<<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
+	}
 }
 
 }
-- 
GitLab