#include "carver.hpp"
#include <cudatl/fixed.hpp>
#include <ftl/cuda/weighting.hpp>

/**
 * Depth-error coefficient derived from camera intrinsics.
 *
 * @param cam   Camera providing the baseline and horizontal focal length (fx).
 * @param disps Numerator of the ratio, presumably a disparity uncertainty in
 *              pixels — TODO confirm. Defaults to 1.
 * @return disps / (cam.baseline * cam.fx).
 */
__device__ inline float depthErrorCoef(const ftl::rgbd::Camera &cam, float disps=1.0f) {
	const float focalBaseline = cam.baseline*cam.fx;
	return disps / focalBaseline;
}

// ==== Reverse Verify Result ==================================================

/**
 * Depth carving without any colour-scale calculations.
 *
 * Each thread handles one pixel (x, y) of the virtual view described by
 * vintrin. The pixel's depth is back-projected, moved into the original
 * camera's frame with transformR and projected into depth_original. While the
 * original camera sees a valid surface (d2 < ointrin.maxDepth) that is further
 * away than the reprojected point by more than d2*d2*err_coef, the virtual
 * depth is pushed back by 0.002 (2mm, assuming depth is in metres) and the
 * test is repeated.
 *
 * Up to 10 steps are tried. If every one of them still indicates carving, the
 * pixel is removed (depth set to 0). Otherwise the adjusted depth is written
 * back, so at most 9 steps (1.8cm) of carving are ever kept.
 *
 * @param depth_in       Virtual-view depth, updated in place.
 * @param depth_original Depth image of the original camera (read only).
 * @param pitch4         Row pitch of depth_in in floats.
 * @param opitch4        Row pitch of depth_original in floats.
 * @param transformR     Transform from virtual-camera to original-camera space.
 * @param vintrin        Virtual-view intrinsics; defines the launch domain.
 * @param ointrin        Original-camera intrinsics.
 *
 * Launch: 2D grid covering vintrin.width x vintrin.height threads; no shared memory.
 */
__global__ void reverse_check_kernel(
	float* __restrict__ depth_in,
	const float* __restrict__ depth_original,
	int pitch4,
	int opitch4,
	float4x4 transformR,
	ftl::rgbd::Camera vintrin,
	ftl::rgbd::Camera ointrin
) {
	const int x = blockIdx.x*blockDim.x + threadIdx.x;
	const int y = blockIdx.y*blockDim.y + threadIdx.y;

	if (x < 0 || x >= vintrin.width || y < 0 || y >= vintrin.height) return;

	float d = depth_in[y*pitch4+x];

	// Pixels without depth (0, negative or NaN) have nothing to carve. Without
	// this guard the loop below starts marching from depth 0 and can turn an
	// empty pixel into a small, bogus non-zero depth.
	if (!(d > 0.0f)) return;

	const float err_coef = 0.001f; //depthErrorCoef(ointrin);
	const float step = 0.002f;  // Carving increment per iteration.

	int count = 10;  // Max carving steps; exhausting all of them removes the point.
	while (--count >= 0) {
		const float3 campos = transformR * vintrin.screenToCam(x,y,d);
		const int2 spos = ointrin.camToScreen<int2>(campos);
		const int ox = spos.x;
		const int oy = spos.y;

		// Behind the original camera or outside its image: no evidence to carve.
		if (!(campos.z > 0.0f && ox >= 0 && ox < ointrin.width && oy >= 0 && oy < ointrin.height)) break;

		const float d2 = depth_original[oy*opitch4+ox];

		// TODO: Threshold comes from depth error characteristics
		// If the value is significantly further then carve. Depth error
		// is not always easy to calculate, depends on source.
		if (!(d2 < ointrin.maxDepth && d2 - campos.z > d2*d2*err_coef)) break;

		d += step;
	}

	// Too much carving means just outright remove the point.
	depth_in[y*pitch4+x] = (count < 0) ? 0.0f : d;
}

/**
 * Depth carving with colour-scale output (colour estimation currently disabled).
 *
 * Each thread handles one pixel (x, y) of the virtual view (vintrin). Its depth
 * is back-projected, moved into the original camera's frame by transformR and
 * projected into depth_original. While the original camera sees a valid
 * surface (d2 < ointrin.maxDepth) further away than the reprojected point by
 * more than d2*d2*err_coef, the virtual depth is pushed back by 0.002 (2mm,
 * assuming metres) and the test repeats. If all 10 steps carve, the pixel is
 * removed (depth 0). Otherwise the adjusted depth is written back, so at most
 * 9 steps (1.8cm) of carving are kept.
 *
 * colour_scale receives 0 (no intensity adjustment) for every pixel in the
 * launch domain. A per-pixel luminance difference between ref_colour and
 * in_colour at the reprojected location used to be stored there, but that
 * computation is disabled. The colour inputs and their pitches/sizes are
 * therefore currently unused.
 *
 * @param depth_in       Virtual-view depth, updated in place.
 * @param depth_original Depth image of the original camera (read only).
 * @param in_colour      Virtual-view colour (unused while colour scaling is disabled).
 * @param ref_colour     Original-camera colour (unused while colour scaling is disabled).
 * @param colour_scale   Per-pixel intensity offset output.
 * @param pitch4         Row pitch of depth_in in floats.
 * @param pitch          Row pitch of colour_scale in elements.
 * @param opitch4        Row pitch of depth_original in floats.
 * @param in_col_pitch4  Row pitch of in_colour in uchar4 (unused).
 * @param o_col_pitch4   Row pitch of ref_colour in uchar4 (unused).
 * @param cwidth         Colour image width (unused).
 * @param cheight        Colour image height (unused).
 * @param transformR     Transform from virtual-camera to original-camera space.
 * @param vintrin        Virtual-view intrinsics; defines the launch domain.
 * @param ointrin        Original-camera intrinsics.
 *
 * Launch: 2D grid covering vintrin.width x vintrin.height threads; no shared memory.
 */
__global__ void reverse_check_kernel(
	float* __restrict__ depth_in,
	const float* __restrict__ depth_original,
	const uchar4* __restrict__ in_colour,
	const uchar4* __restrict__ ref_colour,
	int8_t* __restrict__ colour_scale,
	int pitch4,
	int pitch,
	int opitch4,
	int in_col_pitch4,
	int o_col_pitch4,
	int cwidth,
	int cheight,
	float4x4 transformR,
	ftl::rgbd::Camera vintrin,
	ftl::rgbd::Camera ointrin
) {
	const int x = blockIdx.x*blockDim.x + threadIdx.x;
	const int y = blockIdx.y*blockDim.y + threadIdx.y;

	if (x < 0 || x >= vintrin.width || y < 0 || y >= vintrin.height) return;

	// depth_carve only create()s colour_scale, which does not clear memory, so
	// always store the neutral value. Otherwise anything consuming the scale
	// reads uninitialised data.
	colour_scale[x+pitch*y] = 0;

	float d = depth_in[y*pitch4+x];

	// Pixels without depth (0, negative or NaN) have nothing to carve.
	if (!(d > 0.0f)) return;

	// TODO: Externally provide the error coefficient
	const float err_coef = 0.0005f; //depthErrorCoef(ointrin);
	const float step = 0.002f;  // TODO: Should this be += error or what?

	int count = 10;  // Max carving steps; exhausting all of them removes the point.
	while (--count >= 0) {
		const float3 campos = transformR * vintrin.screenToCam(x,y,d);
		const int2 spos = ointrin.camToScreen<int2>(campos);
		const int ox = spos.x;
		const int oy = spos.y;

		// Behind the original camera or outside its image: no evidence to carve.
		if (!(campos.z > 0.0f && ox >= 0 && ox < ointrin.width && oy >= 0 && oy < ointrin.height)) break;

		const float d2 = depth_original[oy*opitch4+ox];

		// TODO: Threshold comes from depth error characteristics
		// If the value is significantly further then carve. Depth error
		// is not always easy to calculate, depends on source.
		if (!(d2 < ointrin.maxDepth && d2 - campos.z > d2*d2*err_coef)) break;

		d += step;
	}

	// Too much carving means just outright remove the point.
	depth_in[y*pitch4+x] = (count < 0) ? 0.0f : d;
}

/**
 * Host launcher for the colour-aware reverse_check_kernel overload.
 *
 * Ensures colour_scale is a CV_8U matrix of depth_in's size via create(). It
 * then enqueues one thread per depth_in pixel, in 16x8 blocks, on the given
 * stream. Pitches are passed in element units: floats for the depth maps,
 * bytes for colour_scale and uchar4 for the colour images.
 * Only launch errors are checked here; the kernel itself runs asynchronously.
 */
void ftl::cuda::depth_carve(
	cv::cuda::GpuMat &depth_in,
	const cv::cuda::GpuMat &depth_original,
	const cv::cuda::GpuMat &in_colour,
	const cv::cuda::GpuMat &ref_colour,
	cv::cuda::GpuMat &colour_scale,
	const float4x4 &transformR,
	const ftl::rgbd::Camera &vintrin,
	const ftl::rgbd::Camera &ointrin,
	cudaStream_t stream)
{
	constexpr int kTileW = 16;
	constexpr int kTileH = 8;

	colour_scale.create(depth_in.size(), CV_8U);

	// Round up so partially filled tiles at the right/bottom edges are covered.
	const dim3 threads(kTileW, kTileH);
	const dim3 blocks(
		(depth_in.cols + kTileW - 1) / kTileW,
		(depth_in.rows + kTileH - 1) / kTileH);

	// Row pitches, expressed in elements of each buffer's type.
	const int depthPitch = static_cast<int>(depth_in.step1());
	const int scalePitch = static_cast<int>(colour_scale.step1());
	const int origPitch = static_cast<int>(depth_original.step1());
	const int inColPitch = static_cast<int>(in_colour.step1()/4);
	const int refColPitch = static_cast<int>(ref_colour.step1()/4);

	reverse_check_kernel<<<blocks, threads, 0, stream>>>(
		depth_in.ptr<float>(),
		depth_original.ptr<float>(),
		in_colour.ptr<uchar4>(),
		ref_colour.ptr<uchar4>(),
		colour_scale.ptr<int8_t>(),
		depthPitch,
		scalePitch,
		origPitch,
		inColPitch,
		refColPitch,
		in_colour.cols,
		in_colour.rows,
		transformR,
		vintrin,
		ointrin);

	cudaSafeCall( cudaGetLastError() );
}

// ==== Multi image MLS ========================================================

/*
 * Gather points for Moving Least Squares, from each source image.
 *
 * Only pixels of the origin depth map whose depth lies strictly inside
 * (minDepth, maxDepth) are processed. For each such pixel, its 3D point X is
 * projected into the source camera (o_2_in) and a (2*SEARCH_RADIUS+1)^2
 * neighbourhood of source pixels is visited. Each neighbour with valid depth
 * contributes its point Xi and normal Ni, both transformed into the origin
 * frame by in_2_o. The weight is ftl::cuda::spatialWeighting(X, Xi, smoothing),
 * or 0 when Ni.x+Ni.y+Ni.z <= 0.
 *
 * normals_out, centroid_out and contrib_out are accumulators. They are read,
 * the weighted contributions are added, and the result is written back. This
 * allows one call per source image. Presumably the caller zeroes them before
 * the first source — TODO confirm.
 *
 * Pixels whose point lies on or behind the source camera plane receive no
 * contribution, since projecting them is meaningless.
 *
 * Launch: 2D grid covering camera_origin.width x camera_origin.height threads;
 * no shared memory.
 */
template <int SEARCH_RADIUS>
__global__ void mls_gather_kernel(
	const half4* __restrict__ normals_in,
	half4* __restrict__ normals_out,
	const float* __restrict__ depth_origin,
	const float* __restrict__ depth_in,
	float4* __restrict__ centroid_out,
	float* __restrict__ contrib_out,
	float smoothing,
	float4x4 o_2_in,
	float4x4 in_2_o,
	ftl::rgbd::Camera camera_origin,
	ftl::rgbd::Camera camera_in,
	int npitch_out,
	int cpitch_out,
	int wpitch_out,
	int dpitch_o,
	int dpitch_i,
	int npitch_in
) {
	const int x = blockIdx.x*blockDim.x + threadIdx.x;
	const int y = blockIdx.y*blockDim.y + threadIdx.y;

	if (x < 0 || y < 0 || x >= camera_origin.width || y >= camera_origin.height) return;

	const float d0 = depth_origin[x+y*dpitch_o];
	if (d0 <= camera_origin.minDepth || d0 >= camera_origin.maxDepth) return;

	const float3 X = camera_origin.screenToCam(x, y, d0);

	// X expressed in the source camera's frame. A point on or behind that
	// camera cannot be projected sensibly, so it gathers nothing from this source.
	const float3 Xs = o_2_in * X;
	if (Xs.z <= 0.0f) return;
	const int2 s = camera_in.camToScreen<int2>(Xs);

	// Rotation part of in_2_o, hoisted out of the neighbourhood loop.
	auto R = in_2_o.getFloat3x3();

	// Running sums from sources gathered so far.
	float3 nX = make_float3(normals_out[y*npitch_out+x]);
	float3 aX = make_float3(centroid_out[y*cpitch_out+x]);
	float contrib = contrib_out[y*wpitch_out+x];

	// Neighbourhood
	for (int v=-SEARCH_RADIUS; v<=SEARCH_RADIUS; ++v) {
		for (int u=-SEARCH_RADIUS; u<=SEARCH_RADIUS; ++u) {
			const int sx = s.x+u;
			const int sy = s.y+v;
			const float d = (sx >= 0 && sx < camera_in.width && sy >= 0 && sy < camera_in.height) ? depth_in[sx+sy*dpitch_i] : 0.0f;
			if (d <= camera_in.minDepth || d >= camera_in.maxDepth) continue;

			// Point and normal of neighbour, in the origin frame
			const float3 Xi = in_2_o * camera_in.screenToCam(sx, sy, d);
			const float3 Ni = R * make_float3(normals_in[sx+sy*npitch_in]);

			// Gauss approx weighting function using point distance.
			// NOTE(review): the component-sum test presumably rejects zero
			// (missing) normals, but it also drops valid normals whose components
			// sum to <= 0 — confirm this orientation assumption.
			const float w = (Ni.x+Ni.y+Ni.z > 0.0f) ? ftl::cuda::spatialWeighting(X,Xi,smoothing) : 0.0f;

			aX += Xi*w;
			nX += Ni*w;
			contrib += w;
		}
	}

	normals_out[y*npitch_out+x] = make_half4(nX, 0.0f);
	centroid_out[y*cpitch_out+x] = make_float4(aX, 0.0f);
	contrib_out[y*wpitch_out+x] = contrib;
}

/**
 * Convert accumulated values into estimate of depth and normals at pixel.
 *
 * Inputs are weighted sums, presumably the accumulators produced by
 * mls_gather_kernel: a normal sum (normals), a point sum (centroid) and the
 * total weight (contrib_out). A pixel is processed only if its depth lies
 * strictly inside (minDepth, maxDepth), the same test the gather kernel uses,
 * and its total weight is positive. For such a pixel:
 *   1. The sums become weighted averages nX and aX.
 *   2. The pixel's 3D point X is moved by -nX * (nX . (X - aX)).
 *   3. The new z is written back into depth.
 *
 * normals_out receives the unit averaged normal, or zero where no estimate is
 * produced. Pixels with zero total weight keep their depth. Previously they
 * received NaN from the 0/0 average.
 *
 * NOTE(review): the SDF step uses the averaged normal before normalisation, so
 * the size of the correction scales with that normal's length — confirm this
 * is intended.
 *
 * Launch: 2D grid covering camera.width x camera.height threads; no shared memory.
 */
__global__ void mls_reduce_kernel(
	const float4* __restrict__ centroid,
	const half4* __restrict__ normals,
	const float* __restrict__ contrib_out,
	half4* __restrict__ normals_out,
	float* __restrict__ depth,
	ftl::rgbd::Camera camera,
	int npitch_in,
	int cpitch_in,
	int wpitch,
	int npitch,
	int dpitch
) {
	const int x = blockIdx.x*blockDim.x + threadIdx.x;
	const int y = blockIdx.y*blockDim.y + threadIdx.y;

	if (x < 0 || y < 0 || x >= camera.width || y >= camera.height) return;

	// Default: zero normal marks "no estimate" for this pixel.
	normals_out[x+y*npitch] = make_half4(0.0f, 0.0f, 0.0f, 0.0f);

	const float d0 = depth[x+y*dpitch];
	if (d0 <= camera.minDepth || d0 >= camera.maxDepth) return;

	const float contrib = contrib_out[y*wpitch+x];
	// No neighbour contributed: averaging would divide by zero and write NaN
	// into the depth map, so leave this pixel's depth untouched.
	if (!(contrib > 0.0f)) return;

	float3 nX = make_float3(normals[y*npitch_in+x]);
	float3 aX = make_float3(centroid[y*cpitch_in+x]);
	nX /= contrib;  // Weighted average normal
	aX /= contrib;  // Weighted average point (centroid)

	float3 X = camera.screenToCam(x, y, d0);

	// Signed-Distance Field function
	const float fX = nX.x * (X.x - aX.x) + nX.y * (X.y - aX.y) + nX.z * (X.z - aX.z);

	// Calculate new point using SDF function to adjust depth (and position)
	X = X - nX * fX;

	depth[x+y*dpitch] = X.z;

	// Normals that cancelled out cannot be normalised; keep the zero normal.
	const float nlen = length(nX);
	if (nlen > 0.0f) {
		normals_out[x+y*npitch] = make_half4(nX / nlen, 0.0f);
	}
}

// Edge length of the square (T_PER_BLOCK x T_PER_BLOCK) thread blocks used by
// the MLS gather/reduce launchers below.
#define T_PER_BLOCK 8

/**
 * Host launcher for mls_gather_kernel with a fixed search radius of 2.
 *
 * Ensures the three accumulator outputs match depth_origin's size via create():
 * normals CV_16FC4, centroids CV_32FC4, weights CV_32F. It then enqueues one
 * thread per depth_origin pixel, in T_PER_BLOCK x T_PER_BLOCK blocks, on stream.
 * NOTE(review): create() does not clear memory, but the kernel accumulates
 * into these buffers — presumably the caller zeroes them before the first
 * source; confirm.
 */
void ftl::cuda::mls_gather(
	const cv::cuda::GpuMat &normals_in,		// Source frame
	cv::cuda::GpuMat &normals_out,
	const cv::cuda::GpuMat &depth_origin,  // Rendered image
	const cv::cuda::GpuMat &depth_in,
	cv::cuda::GpuMat &centroid_out,
	cv::cuda::GpuMat &contrib_out,
	float smoothing,
	const float4x4 &o_2_in,
	const float4x4 &in_2_o,
	const ftl::rgbd::Camera &camera_origin,  // Virtual camera
	const ftl::rgbd::Camera &camera_in,
	cudaStream_t stream
) {
	normals_out.create(depth_origin.size(), CV_16FC4);
	centroid_out.create(depth_origin.size(), CV_32FC4);
	contrib_out.create(depth_origin.size(), CV_32F);

	const dim3 threads(T_PER_BLOCK, T_PER_BLOCK);
	const dim3 blocks(
		(depth_origin.cols + T_PER_BLOCK - 1) / T_PER_BLOCK,
		(depth_origin.rows + T_PER_BLOCK - 1) / T_PER_BLOCK);

	// Row pitches in elements (half4 / float4 / float).
	const int normalsOutPitch = static_cast<int>(normals_out.step1()/4);
	const int centroidPitch = static_cast<int>(centroid_out.step1()/4);
	const int weightPitch = static_cast<int>(contrib_out.step1());
	const int originDepthPitch = static_cast<int>(depth_origin.step1());
	const int sourceDepthPitch = static_cast<int>(depth_in.step1());
	const int normalsInPitch = static_cast<int>(normals_in.step1()/4);

	mls_gather_kernel<2><<<blocks, threads, 0, stream>>>(
		normals_in.ptr<half4>(),
		normals_out.ptr<half4>(),
		depth_origin.ptr<float>(),
		depth_in.ptr<float>(),
		centroid_out.ptr<float4>(),
		contrib_out.ptr<float>(),
		smoothing,
		o_2_in,
		in_2_o,
		camera_origin,
		camera_in,
		normalsOutPitch,
		centroidPitch,
		weightPitch,
		originDepthPitch,
		sourceDepthPitch,
		normalsInPitch
	);
	cudaSafeCall( cudaGetLastError() );
}

/**
 * Host launcher for mls_reduce_kernel.
 *
 * Ensures normals_out is a CV_16FC4 matrix of depth's size via create(). It
 * then enqueues one thread per depth pixel, in T_PER_BLOCK x T_PER_BLOCK
 * blocks, on stream. The kernel rewrites depth in place.
 */
void ftl::cuda::mls_reduce(
	const cv::cuda::GpuMat &centroid,
	const cv::cuda::GpuMat &normals,
	const cv::cuda::GpuMat &contrib,
	cv::cuda::GpuMat &normals_out,
	cv::cuda::GpuMat &depth,
	const ftl::rgbd::Camera &camera,
	cudaStream_t stream
) {
	normals_out.create(depth.size(), CV_16FC4);

	const dim3 threads(T_PER_BLOCK, T_PER_BLOCK);
	const dim3 blocks(
		(depth.cols + T_PER_BLOCK - 1) / T_PER_BLOCK,
		(depth.rows + T_PER_BLOCK - 1) / T_PER_BLOCK);

	// Row pitches in elements (half4 / float4 / float).
	const int normalsInPitch = static_cast<int>(normals.step1()/4);
	const int centroidPitch = static_cast<int>(centroid.step1()/4);
	const int weightPitch = static_cast<int>(contrib.step1());
	const int normalsOutPitch = static_cast<int>(normals_out.step1()/4);
	const int depthPitch = static_cast<int>(depth.step1());

	mls_reduce_kernel<<<blocks, threads, 0, stream>>>(
		centroid.ptr<float4>(),
		normals.ptr<half4>(),
		contrib.ptr<float>(),
		normals_out.ptr<half4>(),
		depth.ptr<float>(),
		camera,
		normalsInPitch,
		centroidPitch,
		weightPitch,
		normalsOutPitch,
		depthPitch
	);
	cudaSafeCall( cudaGetLastError() );
}

// ==== Apply colour scale =====================================================

/**
 * Adds a locally averaged intensity offset to a colour image, in place.
 *
 * Colour pixel (x, y) maps to scale pixel (sx, sy) by proportional scaling of
 * the coordinates, truncated to int. The int8 offsets in the
 * (2*RADIUS+1)^2 window around (sx, sy) are averaged, skipping entries whose
 * magnitude is >= 30. When the window is not fully inside the scale image,
 * every tap counts as 0, so the offset is 0. The average is added to the x, y
 * and z channels, each clamped to [0, 255], and w is set to 255.
 *
 * Launch: 2D grid covering cwidth x cheight threads; no shared memory.
 */
template <int RADIUS>
__global__ void apply_colour_scaling_kernel(
	const int8_t* __restrict__ scale,
	uchar4* __restrict__ colour,
	int spitch,
	int cpitch,
	int swidth,
	int sheight,
	int cwidth,
	int cheight
) {
	const int x = blockIdx.x*blockDim.x + threadIdx.x;
	const int y = blockIdx.y*blockDim.y + threadIdx.y;

	if (x < 0 || x >= cwidth || y < 0 || y >= cheight) return;

	const int sx = (float(swidth) / float(cwidth)) * float(x);
	const int sy = (float(sheight) / float(cheight)) * float(y);

	// Whole sampling window lies inside the scale image.
	const bool interior = sx >= RADIUS && sy >= RADIUS && sx < swidth-RADIUS && sy < sheight-RADIUS;

	float total = 0.0f;
	int taps = 0;

	for (int v=-RADIUS; v<=RADIUS; ++v) {
		#pragma unroll
		for (int u=-RADIUS; u<=RADIUS; ++u) {
			const float sample = interior ? float(scale[(sy+v)*spitch + sx+u]) : 0.0f;
			// Large offsets are treated as outliers and excluded from the mean.
			if (fabsf(sample) < 30) {
				total += sample;
				++taps;
			}
		}
	}

	if (taps > 0) total /= float(taps);

	const uchar4 pixel = colour[x+y*cpitch];
	const float cx = max(0.0f, min(255.0f, float(pixel.x) + total));
	const float cy = max(0.0f, min(255.0f, float(pixel.y) + total));
	const float cz = max(0.0f, min(255.0f, float(pixel.z) + total));
	colour[x+y*cpitch] = make_uchar4(cx, cy, cz, 255.0f);
}

/**
 * Enqueues apply_colour_scaling_kernel instantiated for a compile-time RADIUS.
 * Helper for ftl::cuda::apply_colour_scaling.
 */
template <int RADIUS>
static void launch_colour_scaling_kernel(
	const cv::cuda::GpuMat &scale,
	cv::cuda::GpuMat &colour,
	const dim3 &gridSize,
	const dim3 &blockSize,
	cudaStream_t stream)
{
	apply_colour_scaling_kernel<RADIUS><<<gridSize, blockSize, 0, stream>>>(
		scale.ptr<int8_t>(),
		colour.ptr<uchar4>(),
		scale.step1(),
		colour.step1()/4,
		scale.cols,
		scale.rows,
		colour.cols,
		colour.rows
	);
}

/**
 * Applies a per-pixel intensity offset image (scale, int8) to a colour image
 * (uchar4), in place, on the given stream. One thread per colour pixel, in
 * 16x8 blocks.
 *
 * @param scale  Offset image; may differ in size from colour, in which case it
 *               is sampled proportionally.
 * @param colour Colour image, modified in place.
 * @param radius Half-width of the averaging window in scale pixels. The kernel
 *               radius is a template parameter, so values are clamped to
 *               [0, 4]. The previous implementation ignored this argument and
 *               always used 2.
 * @param stream CUDA stream to enqueue on. Only launch errors are checked here.
 */
void ftl::cuda::apply_colour_scaling(
	const cv::cuda::GpuMat &scale,
	cv::cuda::GpuMat &colour,
	int radius,
	cudaStream_t stream)
{
	static constexpr int THREADS_X = 16;
	static constexpr int THREADS_Y = 8;

	const dim3 gridSize((colour.cols + THREADS_X - 1)/THREADS_X, (colour.rows + THREADS_Y - 1)/THREADS_Y);
	const dim3 blockSize(THREADS_X, THREADS_Y);

	const int r = (radius < 0) ? 0 : ((radius > 4) ? 4 : radius);

	switch (r) {
	case 0: launch_colour_scaling_kernel<0>(scale, colour, gridSize, blockSize, stream); break;
	case 1: launch_colour_scaling_kernel<1>(scale, colour, gridSize, blockSize, stream); break;
	case 2: launch_colour_scaling_kernel<2>(scale, colour, gridSize, blockSize, stream); break;
	case 3: launch_colour_scaling_kernel<3>(scale, colour, gridSize, blockSize, stream); break;
	case 4:
	default: launch_colour_scaling_kernel<4>(scale, colour, gridSize, blockSize, stream); break;
	}

	cudaSafeCall( cudaGetLastError() );
}