Skip to content
Snippets Groups Projects
Commit bf281645 authored by Nicolas Pope's avatar Nicolas Pope
Browse files

Allow selection of CT shape

parent 4afea678
No related branches found
No related tags found
1 merge request!351Circle census transform or SGM
Pipeline #33077 failed
......@@ -28,6 +28,7 @@ limitations under the License.
#include <stdint.h>
#include "libsgm_config.h"
#include "libsgm_parameters.hpp"
#include <cuda_runtime.h>
#if defined(LIBSGM_SHARED)
......@@ -71,7 +72,8 @@ namespace sgm {
int P2;
float uniqueness;
bool subpixel;
Parameters(int P1 = 10, int P2 = 120, float uniqueness = 0.95f, bool subpixel = false) : P1(P1), P2(P2), uniqueness(uniqueness), subpixel(subpixel) {}
CensusShape ct_shape;
Parameters(int P1 = 10, int P2 = 120, float uniqueness = 0.95f, bool subpixel = false, CensusShape ct_shape = CensusShape::CIRCLE_4_2) : P1(P1), P2(P2), uniqueness(uniqueness), subpixel(subpixel), ct_shape(ct_shape) {}
};
/**
......
#pragma once
namespace sgm {
enum class CensusShape {
CT_5X5,
CS_CT_9X7,
CIRCLE_4_2
};
}
......@@ -21,14 +21,11 @@ namespace sgm {
namespace {
static constexpr int WINDOW_WIDTH = 5;
static constexpr int WINDOW_HEIGHT = 5;
static constexpr int BLOCK_SIZE = 128;
static constexpr int LINES_PER_BLOCK = 16;
/* Centre symmetric census */
template <typename T>
template <typename T, int WINDOW_WIDTH, int WINDOW_HEIGHT>
__global__ void cs_census_transform_kernel(
feature_type *dest,
const T *src,
......@@ -104,7 +101,7 @@ __global__ void cs_census_transform_kernel(
}
}
template <typename T>
template <typename T, int WINDOW_WIDTH, int WINDOW_HEIGHT>
__global__ void census_transform_kernel(
feature_type* __restrict__ dest,
const T* __restrict__ src,
......@@ -145,24 +142,27 @@ void enqueue_census_transform(
int width,
int height,
int pitch,
sgm::CensusShape ct_shape,
cudaStream_t stream)
{
/* Disable the original center symmetric algorithm */
if (false) {
const int width_per_block = BLOCK_SIZE - WINDOW_WIDTH + 1;
if (ct_shape == sgm::CensusShape::CS_CT_9X7) {
const int width_per_block = BLOCK_SIZE - 9 + 1;
const int height_per_block = LINES_PER_BLOCK;
const dim3 gdim(
(width + width_per_block - 1) / width_per_block,
(height + height_per_block - 1) / height_per_block);
const dim3 bdim(BLOCK_SIZE);
cs_census_transform_kernel<<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
} else {
cs_census_transform_kernel<T, 9, 7><<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
} else if (ct_shape == sgm::CensusShape::CT_5X5) {
static constexpr int THREADS_X = 16;
static constexpr int THREADS_Y = 16;
const dim3 gdim((width + THREADS_X - 1)/THREADS_X, (height + THREADS_Y - 1)/THREADS_Y);
const dim3 bdim(THREADS_X, THREADS_Y);
census_transform_kernel<<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
census_transform_kernel<T, 5, 5><<<gdim, bdim, 0, stream>>>(dest, src, width, height, pitch);
} else if (ct_shape == sgm::CensusShape::CIRCLE_4_2) {
}
}
......@@ -180,13 +180,14 @@ void CensusTransform<T>::enqueue(
int width,
int height,
int pitch,
sgm::CensusShape ct_shape,
cudaStream_t stream)
{
if(m_feature_buffer.size() < static_cast<size_t>(width * height)){
m_feature_buffer = DeviceBuffer<feature_type>(width * height);
}
enqueue_census_transform(
m_feature_buffer.data(), src, width, height, pitch, stream);
m_feature_buffer.data(), src, width, height, pitch, ct_shape, stream);
}
template class CensusTransform<uint8_t>;
......
......@@ -19,6 +19,7 @@ limitations under the License.
#include "device_buffer.hpp"
#include "types.hpp"
#include "libsgm_parameters.hpp"
namespace sgm {
......@@ -43,6 +44,7 @@ public:
int width,
int height,
int pitch,
sgm::CensusShape ct_shape,
cudaStream_t stream);
};
......
......@@ -58,12 +58,13 @@ public:
float uniqueness,
bool subpixel,
int min_disp,
sgm::CensusShape ct_shape,
cudaStream_t stream)
{
m_census_left.enqueue(
src_left, width, height, src_pitch, stream);
src_left, width, height, src_pitch, ct_shape, stream);
m_census_right.enqueue(
src_right, width, height, src_pitch, stream);
src_right, width, height, src_pitch, ct_shape, stream);
m_path_aggregation.enqueue(
m_census_left.get_output(),
m_census_right.get_output(),
......@@ -109,6 +110,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::execute(
float uniqueness,
bool subpixel,
int min_disp,
sgm::CensusShape ct_shape,
cudaStream_t stream)
{
m_impl->enqueue(
......@@ -119,7 +121,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::execute(
penalty1, penalty2,
weights, weights_pitch,
uniqueness, subpixel,
min_disp,
min_disp, ct_shape,
stream);
//cudaStreamSynchronize(0);
}
......@@ -141,6 +143,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::enqueue(
float uniqueness,
bool subpixel,
int min_disp,
sgm::CensusShape ct_shape,
cudaStream_t stream)
{
m_impl->enqueue(
......@@ -151,7 +154,7 @@ void SemiGlobalMatching<T, MAX_DISPARITY>::enqueue(
penalty1, penalty2,
weights, weights_pitch,
uniqueness, subpixel,
min_disp,
min_disp, ct_shape,
stream);
}
......
......@@ -20,6 +20,7 @@ limitations under the License.
#include <memory>
#include <cstdint>
#include "types.hpp"
#include "libsgm_parameters.hpp"
namespace sgm {
......@@ -54,6 +55,7 @@ public:
float uniqueness,
bool subpixel,
int min_disp,
sgm::CensusShape ct_shape,
cudaStream_t stream);
void enqueue(
......@@ -72,6 +74,7 @@ public:
float uniqueness,
bool subpixel,
int min_disp,
sgm::CensusShape ct_shape,
cudaStream_t stream);
};
......
......@@ -29,7 +29,7 @@ namespace sgm {
public:
using output_type = sgm::output_type;
virtual void execute(output_type* dst_L, output_type* dst_R, const void* src_L, const void* src_R,
int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, cudaStream_t stream) = 0;
int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, sgm::CensusShape ct_shape, cudaStream_t stream) = 0;
virtual ~SemiGlobalMatchingBase() {}
};
......@@ -38,9 +38,9 @@ namespace sgm {
class SemiGlobalMatchingImpl : public SemiGlobalMatchingBase {
public:
void execute(output_type* dst_L, output_type* dst_R, const void* src_L, const void* src_R,
int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, cudaStream_t stream) override
int w, int h, int sp, int dp, unsigned int P1, const uint8_t *P2, const uint8_t *weights, int weights_pitch, float uniqueness, bool subpixel, int min_disp, sgm::CensusShape ct_shape, cudaStream_t stream) override
{
sgm_engine_.execute(dst_L, dst_R, (const input_type*)src_L, (const input_type*)src_R, w, h, sp, dp, P1, P2, weights, weights_pitch, uniqueness, subpixel, min_disp, stream);
sgm_engine_.execute(dst_L, dst_R, (const input_type*)src_L, (const input_type*)src_R, w, h, sp, dp, P1, P2, weights, weights_pitch, uniqueness, subpixel, min_disp, ct_shape, stream);
}
private:
SemiGlobalMatching<input_type, DISP_SIZE> sgm_engine_;
......@@ -176,7 +176,7 @@ namespace sgm {
d_left_disp = dst; // when threre is no device-host copy or type conversion, use passed buffer
cu_res_->sgm_engine->execute((uint16_t*)d_tmp_left_disp, (uint16_t*)d_tmp_right_disp,
d_input_left, d_input_right, width, height, src_pitch, dst_pitch, param_.P1, P2, weights, weights_pitch, param_.uniqueness, param_.subpixel, min_disp, stream);
d_input_left, d_input_right, width, height, src_pitch, dst_pitch, param_.P1, P2, weights, weights_pitch, param_.uniqueness, param_.subpixel, min_disp, param_.ct_shape, stream);
sgm::details::median_filter((uint16_t*)d_tmp_left_disp, (uint16_t*)d_left_disp, width, height, dst_pitch, stream);
sgm::details::median_filter((uint16_t*)d_tmp_right_disp, (uint16_t*)d_right_disp, width, height, dst_pitch, stream);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment