From 9bcd8d9e5e749e272ab3c9303d5aaae75ee91c2d Mon Sep 17 00:00:00 2001
From: Nicolas Pope <nwpope@utu.fi>
Date: Tue, 27 Oct 2020 14:42:26 +0200
Subject: [PATCH] Implement circle 3 CT

---
 lib/libsgm/include/libsgm.h              |  2 +-
 lib/libsgm/include/libsgm_parameters.hpp |  2 +-
 lib/libsgm/src/census_transform.cu       | 73 +++++++++++++++++++++++-
 3 files changed, 74 insertions(+), 3 deletions(-)

diff --git a/lib/libsgm/include/libsgm.h b/lib/libsgm/include/libsgm.h
index 7f8218ea0..653a2a5df 100644
--- a/lib/libsgm/include/libsgm.h
+++ b/lib/libsgm/include/libsgm.h
@@ -73,7 +73,7 @@ namespace sgm {
 			float uniqueness;
 			bool subpixel;
 			CensusShape ct_shape;
-			Parameters(int P1 = 10, int P2 = 120, float uniqueness = 0.95f, bool subpixel = false, CensusShape ct_shape = CensusShape::CIRCLE_4_2) : P1(P1), P2(P2), uniqueness(uniqueness), subpixel(subpixel), ct_shape(ct_shape) {}
+			Parameters(int P1 = 10, int P2 = 120, float uniqueness = 0.95f, bool subpixel = false, CensusShape ct_shape = CensusShape::CIRCLE_3) : P1(P1), P2(P2), uniqueness(uniqueness), subpixel(subpixel), ct_shape(ct_shape) {}
 		};
 
 		/**
diff --git a/lib/libsgm/include/libsgm_parameters.hpp b/lib/libsgm/include/libsgm_parameters.hpp
index 9c61db728..60c003247 100644
--- a/lib/libsgm/include/libsgm_parameters.hpp
+++ b/lib/libsgm/include/libsgm_parameters.hpp
@@ -4,6 +4,6 @@ namespace sgm {
 	enum class CensusShape {
 		CT_5X5,
 		CS_CT_9X7,
-		CIRCLE_4_2
+		CIRCLE_3  // circular neighbourhood of radius 3 (see circle_ct_3_kernel)
 	};
 }
diff --git a/lib/libsgm/src/census_transform.cu b/lib/libsgm/src/census_transform.cu
index 90d839743..f46a98557 100644
--- a/lib/libsgm/src/census_transform.cu
+++ b/lib/libsgm/src/census_transform.cu
@@ -135,6 +135,72 @@ __global__ void census_transform_kernel(
 	if (x < width && y < height) dest[x+y*width] = res;
 }
 
+/**
+ * Census transform over a circular neighbourhood of radius 3 pixels.
+ *
+ * Expects a 2D launch with one thread per pixel. Each interior thread
+ * compares its centre pixel with 28 neighbours and packs one bit per
+ * neighbour into a feature word; a bit is 1 when the neighbour is strictly
+ * greater than the centre.
+ *
+ * Bit order (first pushed = most significant):
+ *   row y   : dx = -3,-2,-1,+1,+2,+3
+ *   rows y-1, y-2 : dx = -2..+2 each
+ *   row y-3 : dx = 0
+ *   rows y+1, y+2, y+3 : same layout as the rows above
+ *
+ * Pixels closer than 3 to any image edge are written as 0.
+ *
+ * @param dest   output features, indexed as x + y*width
+ * @param src    input image, indexed as x + y*pitch
+ * @param width  image width in pixels
+ * @param height image height in pixels
+ * @param pitch  row stride of src, in elements of T
+ */
+template <typename T>
+__global__ void circle_ct_3_kernel(
+	feature_type* __restrict__ dest,
+	const T* __restrict__ src,
+	int width,
+	int height,
+	int pitch)
+{
+	static constexpr int RADIUS = 3;  // outermost extent of the circle
+	static constexpr int INNER = 2;   // rows +-1..+-INNER span dx = -INNER..+INNER
+
+	const int x = blockIdx.x*blockDim.x + threadIdx.x;
+	const int y = blockIdx.y*blockDim.y + threadIdx.y;
+
+	feature_type res = 0;
+
+	const bool interior = x >= RADIUS && y >= RADIUS && x < width-RADIUS && y < height-RADIUS;
+	if (interior) {
+		const int cix = y*pitch+x;
+		const T center = src[cix];
+
+		// Centre row: every column within the radius except the centre itself.
+		for (int dx = -RADIUS; dx <= RADIUS; ++dx) {
+			if (dx == 0) continue;
+			res = (res << 1) | (center < src[cix+dx] ? 1 : 0);
+		}
+
+		// All rows above the centre first (side = -1), then all rows below.
+		for (int side = -1; side <= 1; side += 2) {
+			for (int dy = 1; dy <= INNER; ++dy) {
+				const int rix = (y + side*dy)*pitch+x;
+				for (int dx = -INNER; dx <= INNER; ++dx) {
+					res = (res << 1) | (center < src[rix+dx] ? 1 : 0);
+				}
+			}
+			// Tip of the circle: only the pixel directly above/below the centre.
+			res = (res << 1) | (center < src[(y + side*RADIUS)*pitch+x] ? 1 : 0);
+		}
+	}
+
+	// FIXME: Should use feature pitch, not width.
+	if (x < width && y < height) dest[x+y*width] = res;
+}
+
 template <typename T>
 void enqueue_census_transform(
 	feature_type *dest,
@@ -161,8 +227,13 @@ void enqueue_census_transform(
 		const dim3 gdim((width + THREADS_X - 1)/THREADS_X, (height + THREADS_Y - 1)/THREADS_Y);
 		const dim3 bdim(THREADS_X, THREADS_Y);
 		census_transform_kernel<T, 5, 5><<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
-	} else if (ct_shape == sgm::CensusShape::CIRCLE_4_2) {
+	} else if (ct_shape == sgm::CensusShape::CIRCLE_3) {
+		static constexpr int THREADS_X = 16;
+		static constexpr int THREADS_Y = 16;
 
+		const dim3 gdim((width + THREADS_X - 1)/THREADS_X, (height + THREADS_Y - 1)/THREADS_Y);
+		const dim3 bdim(THREADS_X, THREADS_Y);
+		circle_ct_3_kernel<<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
 	}
 }
 
-- 
GitLab