From a2aa192e65c82e3e60b7a075e760c9114217a6cb Mon Sep 17 00:00:00 2001
From: Nicolas Pope <nwpope@utu.fi>
Date: Wed, 6 May 2020 12:52:25 +0300
Subject: [PATCH] Add mini census window for hcensus

---
 lib/libstereo/middlebury/algorithms.hpp    |  4 +-
 lib/libstereo/src/algorithms/hcensussgm.cu |  8 ++--
 lib/libstereo/src/costs/census.cu          | 45 ++++++++++++++++++++++
 lib/libstereo/src/costs/census.hpp         | 27 +++++++++++++
 4 files changed, 78 insertions(+), 6 deletions(-)

diff --git a/lib/libstereo/middlebury/algorithms.hpp b/lib/libstereo/middlebury/algorithms.hpp
index b5bbca407..b89f798a9 100644
--- a/lib/libstereo/middlebury/algorithms.hpp
+++ b/lib/libstereo/middlebury/algorithms.hpp
@@ -120,7 +120,7 @@ namespace Impl {
 	};
 
 	struct HCensusSGM : public Algorithm {
-		HCensusSGM() { P1 = 36.0f; P2 = 96.0f; }
+		HCensusSGM() { P1 = 6.0f; P2 = 60.0f; }
 
 		virtual void run(const MiddleburyData &data, cv::Mat &disparity) override {
 			StereoHierCensusSgm stereo;
@@ -128,7 +128,7 @@ namespace Impl {
 			stereo.params.P2 = P2;
 			stereo.params.subpixel = subpixel;
 			stereo.params.lr_consistency = lr_consistency;
-
+			stereo.params.var_window = 5;
 			stereo.params.d_min = data.calib.vmin;
 			stereo.params.d_max = data.calib.vmax;
 			stereo.params.debug = false;
diff --git a/lib/libstereo/src/algorithms/hcensussgm.cu b/lib/libstereo/src/algorithms/hcensussgm.cu
index 2a7c7586c..751147678 100644
--- a/lib/libstereo/src/algorithms/hcensussgm.cu
+++ b/lib/libstereo/src/algorithms/hcensussgm.cu
@@ -6,7 +6,7 @@
 #include <opencv2/cudafilters.hpp>
 #include <opencv2/highgui.hpp>
 
-typedef MultiCostsWeighted<CensusMatchingCost,3> MatchingCost;
+typedef MultiCostsWeighted<MiniCensusMatchingCost,3> MatchingCost;
 
 static void variance_mask(cv::InputArray in, cv::OutputArray out, int wsize=3) {
 	if (in.isGpuMat() && out.isGpuMat()) {
@@ -40,9 +40,9 @@ static void variance_mask(cv::InputArray in, cv::OutputArray out, int wsize=3) {
 }
 
 struct StereoHierCensusSgm::Impl : public StereoSgm<MatchingCost, StereoHierCensusSgm::Parameters> {
-    CensusMatchingCost cost_fine;
-    CensusMatchingCost cost_medium;
-    CensusMatchingCost cost_coarse;
+    MiniCensusMatchingCost cost_fine;
+    MiniCensusMatchingCost cost_medium;
+    MiniCensusMatchingCost cost_coarse;
 	Array2D<uchar> l;
     Array2D<uchar> r;
     Array2D<float> var_fine;
diff --git a/lib/libstereo/src/costs/census.cu b/lib/libstereo/src/costs/census.cu
index 60cdfbaff..6c499c618 100644
--- a/lib/libstereo/src/costs/census.cu
+++ b/lib/libstereo/src/costs/census.cu
@@ -224,6 +224,51 @@ void CensusMatchingCost::set(cv::InputArray l, cv::InputArray r) {
 
 ////////////////////////////////////////////////////////////////////////////////
 
+void MiniCensusMatchingCost::set(const Array2D<uchar> &l, const Array2D<uchar> &r) {
+	if (pattern_ == CensusPattern::STANDARD) {
+		parallel2D<algorithms::CensusTransformRowMajor<5,3>>({l.data(), ct_l_.data()}, l.width, l.height);
+		parallel2D<algorithms::CensusTransformRowMajor<5,3>>({r.data(), ct_r_.data()}, r.width, r.height);
+	} else if (pattern_ == CensusPattern::GENERALISED) {
+		parallel2D<algorithms::GCensusTransformRowMajor<5,3>>({l.data(), ct_l_.data()}, l.width, l.height);
+		parallel2D<algorithms::GCensusTransformRowMajor<5,3>>({r.data(), ct_r_.data()}, r.width, r.height);
+	} else {
+		// TODO: 
+	}
+}
+
+void MiniCensusMatchingCost::set(const Array2D<uchar> &l, const Array2D<uchar> &r, const Array2D<uchar> &hl, const Array2D<uchar> &hr) {
+	if (pattern_ == CensusPattern::STANDARD) {
+		parallel2D<algorithms::HCensusTransformRowMajor<5,3>>({l.data(), hl.data(), ct_l_.data(), {ushort(hl.width), ushort(hl.height)}}, l.width, l.height);
+		parallel2D<algorithms::HCensusTransformRowMajor<5,3>>({r.data(), hr.data(), ct_r_.data(), {ushort(hr.width), ushort(hr.height)}}, r.width, r.height);
+	} else if (pattern_ == CensusPattern::GENERALISED) {
+		parallel2D<algorithms::HGCensusTransformRowMajor<5,3>>({l.data(), hl.data(), ct_l_.data(), {ushort(hl.width), ushort(hl.height)}}, l.width, l.height);
+		parallel2D<algorithms::HGCensusTransformRowMajor<5,3>>({r.data(), hr.data(), ct_r_.data(), {ushort(hr.width), ushort(hr.height)}}, r.width, r.height);
+	}
+}
+
+void MiniCensusMatchingCost::set(cv::InputArray l, cv::InputArray r) {
+	if (l.type() != CV_8UC1 || r.type() != CV_8UC1) { throw std::exception(); }
+	if (l.rows() != r.rows() || l.cols() != r.cols() || l.rows() != height() || l.cols() != width()) {
+		throw std::exception();
+	}
+
+	if (l.isGpuMat() && r.isGpuMat()) {
+		auto ml = l.getGpuMat();
+		auto mr = r.getGpuMat();
+		set(Array2D<uchar>(ml), Array2D<uchar>(mr));
+	}
+	else if (l.isMat() && r.isMat()) {
+		auto ml = l.getMat();
+		auto mr = r.getMat();
+		set(Array2D<uchar>(ml), Array2D<uchar>(mr));
+	}
+	else {
+		throw std::exception();
+	}
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
 void WeightedCensusMatchingCost::set(const Array2D<uchar> &l, const Array2D<uchar> &r) {
 	parallel2D<algorithms::CensusTransformRowMajor<11,11>>({l.data(), ct_l_.data()}, l.width, l.height);
 	parallel2D<algorithms::CensusTransformRowMajor<11,11>>({r.data(), ct_r_.data()}, r.width, r.height);
diff --git a/lib/libstereo/src/costs/census.hpp b/lib/libstereo/src/costs/census.hpp
index 2b8119231..3b05156a4 100644
--- a/lib/libstereo/src/costs/census.hpp
+++ b/lib/libstereo/src/costs/census.hpp
@@ -176,6 +176,33 @@ protected:
 	CensusPattern pattern_ = CensusPattern::STANDARD;
 };
 
+class MiniCensusMatchingCost : public DSBase<impl::CensusMatchingCost<5,3,1>> {
+public:
+	typedef impl::CensusMatchingCost<5,3,1> DataType;
+	typedef unsigned short Type;
+
+	MiniCensusMatchingCost() : DSBase<DataType>(0, 0, 0, 0) {};
+	MiniCensusMatchingCost(int width, int height, int disp_min, int disp_max)
+		: DSBase<DataType>(width, height, disp_min, disp_max),
+			ct_l_(width*data().WSTEP, height), ct_r_(width*data().WSTEP,height)
+		{
+			data().l = ct_l_.data();
+			data().r = ct_r_.data();
+		}
+
+	inline void setPattern(CensusPattern p) { pattern_ = p; }
+
+	void set(cv::InputArray l, cv::InputArray r);
+	void set(const Array2D<uchar>& l, const Array2D<uchar>& r);
+	void set(const Array2D<uchar> &l, const Array2D<uchar> &r, const Array2D<uchar> &hl, const Array2D<uchar> &hr);
+	static constexpr Type COST_MAX = DataType::COST_MAX;
+
+protected:
+	Array2D<uint64_t> ct_l_;
+	Array2D<uint64_t> ct_r_;
+	CensusPattern pattern_ = CensusPattern::STANDARD;
+};
+
 class WeightedCensusMatchingCost : public DSBase<impl::WeightedCensusMatchingCost<11, 5>> {
 public:
 	typedef impl::WeightedCensusMatchingCost<11, 5> DataType;
-- 
GitLab