From c1d58062a025c4499a2506b1fbb06c2fe19aca3f Mon Sep 17 00:00:00 2001
From: Sebastian Hahta <joseha@utu.fi>
Date: Fri, 3 Jul 2020 16:33:30 +0300
Subject: [PATCH] fix data race and normalized ct cost

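There is a bug in algorithms::CensusTransform::operator()
and/or in parallel2D<>: some pixels are processed more than
once. This fix removes the read from *out by accumulating
the bits in a local variable, so repeated calls write
identical values. Also adds a normalized (float) census
matching cost.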
---
 lib/libstereo/middlebury/main.cpp  |  2 +-
 lib/libstereo/src/costs/census.cu  | 19 +++++++++++----
 lib/libstereo/src/costs/census.hpp | 39 ++++++++++++++++++++++++++++++
 3 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/lib/libstereo/middlebury/main.cpp b/lib/libstereo/middlebury/main.cpp
index 75e68a99c..d0bf19665 100644
--- a/lib/libstereo/middlebury/main.cpp
+++ b/lib/libstereo/middlebury/main.cpp
@@ -180,7 +180,7 @@ void main_default(const std::vector<std::string> &paths,
 			}
 			std::cout << "Saved\n";
 		}
-		else if (k == 27) {
+		else if (k == 27 || k == 255) {
 			return;
 		}
 	}
diff --git a/lib/libstereo/src/costs/census.cu b/lib/libstereo/src/costs/census.cu
index f2262e985..bd3e01357 100644
--- a/lib/libstereo/src/costs/census.cu
+++ b/lib/libstereo/src/costs/census.cu
@@ -17,6 +17,10 @@ namespace algorithms {
 		__host__ __device__ inline void window(const int y, const int x, uint64_t* __restrict__ out) {
 			short center = im(y, x);
 			uint8_t i = 0; // bit counter for *out
+			// possible BUG in operator(): window() gets called more than once per
+			// pixel. Accumulate the bits in a local variable so there is no read
+			// dependency on *out (avoids the data race; repeated writes are identical)
+			uint64_t res = 0;
 
 			for (int wy = -WINY/2; wy <= WINY/2; wy++) {
 				for (int wx = -WINX/2; wx <= WINX/2; wx++) {
@@ -24,15 +28,20 @@ namespace algorithms {
 					const int x_ = x + wx;
 
 					// zero if first value, otherwise shift to left
-					if (i % 64 == 0) { *out = 0; }
-					else             { *out = (*out << 1); }
-					*out |= (center < (im(y_,x_)) ? 1 : 0);
+					res = (res << 1);
+					res |= (center < (im(y_,x_)) ? 1 : 0);
 
-					i += 1;
 					// if all bits set, continue to next element
-					if (i % 64 == 0) { out++; }
+					if (++i % 64 == 0) {
+						*out = res;
+						out++;
+					}
 				}
 			}
+			if (i % 64 != 0) { // i = bits shifted in; full words were stored in the loop
+				// write remaining bits of the partial last word
+				*out = res;
+			}
 		}
 
 		__host__ __device__  void operator()(ushort2 thread, ushort2 stride, ushort2 size) {
diff --git a/lib/libstereo/src/costs/census.hpp b/lib/libstereo/src/costs/census.hpp
index 4a74e2e82..2690ddfa7 100644
--- a/lib/libstereo/src/costs/census.hpp
+++ b/lib/libstereo/src/costs/census.hpp
@@ -52,6 +52,45 @@ namespace impl {
 	template<uint8_t WW, uint8_t WH, int BPP=1>
 	using CensusMatchingCost = HammingCost<WW*WH*BPP>;
 
+	/**
+	 * Normalized Hamming cost, same as above except float type and scaled by
+	 * `normalize` (user set). Cost is within [0, 1] when normalize = 1/bits.
+	 */
+	template<int SIZE>
+	struct NormalizedHammingCost : DSImplBase<float> {
+		static_assert(SIZE%64 == 0, "SIZE must be a multiple of 64");
+
+		typedef float Type;
+
+		NormalizedHammingCost(ushort w, ushort h, ushort dmin, ushort dmax) : DSImplBase<Type>({w,h,dmin,dmax}) {}
+		NormalizedHammingCost() : DSImplBase<Type>({0,0,0,0}) {}
+
+		__host__ __device__ inline Type operator()(const int y, const int x, const int d) const {
+			if ((x-d) < 0) { return COST_MAX; } // (x-d) indexes r; negative is out of range
+			float c = 0; // number of differing bits between l at x and r at x-d
+
+			#pragma unroll
+			for (int i = 0; i < WSTEP; i++) {
+				c+= popcount(l(y, x*WSTEP+i) ^ r(y, (x-d)*WSTEP+i));
+			}
+			return c*normalize;
+		}
+
+		// number of uint64_t values for each window (SIZE bits, rounded up to 64)
+		static constexpr int WSTEP = (SIZE - 1)/(sizeof(uint64_t)*8) + 1;
+		static constexpr Type COST_MAX = 1.0f;
+
+		Array2D<uint64_t>::Data l;
+		Array2D<uint64_t>::Data r;
+		float normalize = 1.0f; // set to 1.0f/(number of bits used)
+	};
+
+	template<uint8_t WW, uint8_t WH, int BPP=1>
+	using NormalizedCensusMatchingCost = NormalizedHammingCost<WW*WH*BPP>;
+
+	/**
+	 * WeightedCensusMatchingCost
+	 */
 	template<uint8_t R, uint8_t NBINS>
 	struct WeightedCensusMatchingCost : DSImplBase<unsigned short> {
 		static_assert(R % 2 == 1, "R must be odd");
-- 
GitLab