#include "ilw_cuda.hpp"
#include <ftl/cuda/weighting.hpp>

using ftl::cuda::TextureObject;
using ftl::rgbd::Camera;

// Number of lanes in a warp; the correspondence kernel assigns one warp per pixel.
#define WARP_SIZE 32
// Block dimension used for rows (and, in move_points, for columns) per block.
#define T_PER_BLOCK 8
// Participation mask for warp intrinsics: all 32 lanes.
#define FULL_MASK 0xffffffff

// Warp-wide minimum.
//
// Uses an xor-shuffle butterfly: after log2(WARP_SIZE) exchange steps every
// lane holds the minimum of the values passed in by all lanes. Uses FULL_MASK,
// so all 32 lanes of the warp must call this function together.
__device__ inline float warpMin(float e) {
	float smallest = e;
	int offset = WARP_SIZE >> 1;
	while (offset > 0) {
		const float partner = __shfl_xor_sync(FULL_MASK, smallest, offset, WARP_SIZE);
		smallest = min(smallest, partner);
		offset >>= 1;
	}
	return smallest;
}

// Square correspondence search window. Despite its name, COR_WIN_RADIUS is the
// window's side length in pixels: with 17, candidate offsets run from -8 to +8
// around the projected point. COR_WIN_SIZE is the number of candidate pixels.
#define COR_WIN_RADIUS 17
#define COR_WIN_SIZE (COR_WIN_RADIUS * COR_WIN_RADIUS)

// Per-pixel correspondence search from view 1 into view 2.
//
// Launch layout: one warp (WARP_SIZE lanes) per pixel of p1. blockDim.x must be
// a multiple of WARP_SIZE so that every warp maps to a single (x, y). The host
// launcher uses blockDim = (2*WARP_SIZE, T_PER_BLOCK). No shared memory.
//
// Steps:
//   1. The world point of p1 at (x, y) is transformed by pose2.
//   2. It is projected into view 2 with cam2.
//   3. The warp's lanes split the COR_WIN_SIZE candidate pixels around that
//      projection. Each lane keeps its lowest combined spatial/colour cost.
//   4. A warp reduction finds the overall minimum.
//   5. If that minimum is below 1.0f, exactly one lane holding it adds the
//      displacement (xyz) and the cost (w) to vout, and raises eout.
//
// p1, p2  : world points of views 1 / 2; x == MINF marks an invalid pixel.
// c1, c2  : colours of views 1 / 2.
// vout    : accumulated displacement (xyz) and summed cost (w) per p1 pixel.
// eout    : per-pixel energy, updated with max().
// pose2   : transform applied to p1 points before projecting (inverse pose).
// cam2    : camera of view 2.
// NOTE(review): assumes vout and eout have the same dimensions as p1.
__global__ void correspondence_energy_vector_kernel(
        TextureObject<float4> p1,
        TextureObject<float4> p2,
        TextureObject<uchar4> c1,
        TextureObject<uchar4> c2,
        TextureObject<float4> vout,
        TextureObject<float> eout,
        float4x4 pose2,  // Inverse
        Camera cam2) {

    // Each warp picks a point in p1. Because blockDim.x is a multiple of
    // WARP_SIZE, x and y are identical across the lanes of a warp. Every early
    // return below is therefore warp-uniform, and the FULL_MASK collectives
    // further down are reached by all 32 lanes.
    const int tid = (threadIdx.x + threadIdx.y * blockDim.x);
    const int lane = tid % WARP_SIZE;
    const int x = (blockIdx.x*blockDim.x + threadIdx.x) / WARP_SIZE;
    const int y = blockIdx.y*blockDim.y + threadIdx.y;

    // The grid is rounded up, so trailing warps can lie outside the image.
    // Without this guard they would write vout/eout out of bounds.
    if (x >= static_cast<int>(p1.width()) || y >= static_cast<int>(p1.height())) return;

    const float3 world1 = make_float3(p1.tex2D(x, y));
    if (world1.x == MINF) return;
    const uchar4 colour1 = c1.tex2D(x, y);

    // Project to p2 using cam2
    const float3 camPos2 = pose2 * world1;
    const uint2 screen2 = cam2.camToScreen<uint2>(camPos2);

    // Sentinel above the 1.0f acceptance threshold used below, so a lane that
    // finds no valid candidate never wins.
    float bestcost = 1.1f;
    float3 bestpoint = make_float3(0.0f, 0.0f, 0.0f);

    // Each lane evaluates every WARP_SIZE-th candidate of the search window.
    for (int i=lane; i<COR_WIN_SIZE; i+=WARP_SIZE) {
        const float u = (i % COR_WIN_RADIUS) - (COR_WIN_RADIUS / 2);
        const float v = (i / COR_WIN_RADIUS) - (COR_WIN_RADIUS / 2);

        const float3 world2 = make_float3(p2.tex2D(screen2.x+u, screen2.y+v));
        if (world2.x == MINF) continue;
        const uchar4 colour2 = c2.tex2D(screen2.x+u, screen2.y+v);

        // Determine degree of correspondence: mean of the two (1 - weight) terms.
        float cost = 1.0f - ftl::cuda::spatialWeighting(world1, world2, 0.04f);
        cost += 1.0f - ftl::cuda::colourWeighting(colour1, colour2, 50.0f);
        cost /= 2.0f;

        if (cost < bestcost) {
            bestpoint = world2;
            bestcost = cost;
        }
    }

    const float mincost = warpMin(bestcost);

    // Several lanes can hold exactly the minimum (e.g. identical candidate
    // points). Elect the lowest such lane so the read-modify-write of vout and
    // eout happens exactly once. At least one lane always matches mincost, so
    // the ballot is never zero.
    const unsigned int winners = __ballot_sync(FULL_MASK, bestcost == mincost);
    const int winner = __ffs(static_cast<int>(winners)) - 1;

    if (lane == winner && mincost < 1.0f) {
        vout(x,y) = vout.tex2D(x, y) + make_float4(
            (bestpoint.x - world1.x),
            (bestpoint.y - world1.y),
            (bestpoint.z - world1.z),
            mincost);
        eout(x,y) = max(eout(x,y), (1.0f - mincost) * 7.0f);

        // FIXME: This needs to be summed across all frames.
        // A confidence-based energy (ratio of best to second-best cost) was
        // previously sketched here but is not implemented.
    }
}

// Host launcher for correspondence_energy_vector_kernel.
//
// Enqueues the kernel on `stream` and returns without synchronising.
// cudaGetLastError() catches launch-configuration errors; faults during
// execution only surface at a later synchronising call.
//
// Launch layout: blocks of (pixelsPerBlockX * WARP_SIZE) x T_PER_BLOCK
// threads, one warp per p1 pixel. Each block therefore covers
// pixelsPerBlockX columns and T_PER_BLOCK rows of p1.
void ftl::cuda::correspondence_energy_vector(
        TextureObject<float4> &p1,
        TextureObject<float4> &p2,
        TextureObject<uchar4> &c1,
        TextureObject<uchar4> &c2,
        TextureObject<float4> &vout,
        TextureObject<float> &eout,
        float4x4 &pose2,
        const Camera &cam2,
        cudaStream_t stream) {

    // p1 pixels per block along x (one per warp). The kernel requires
    // blockDim.x to be a multiple of WARP_SIZE, which this guarantees.
    const unsigned int pixelsPerBlockX = 2;

    // An empty image would yield a zero grid dimension, which is an invalid
    // launch configuration. There is no work to do in that case.
    if (p1.width() == 0 || p1.height() == 0) return;

    const dim3 gridSize((p1.width() + pixelsPerBlockX - 1)/pixelsPerBlockX, (p1.height() + T_PER_BLOCK - 1)/T_PER_BLOCK);
    const dim3 blockSize(pixelsPerBlockX*WARP_SIZE, T_PER_BLOCK);

    correspondence_energy_vector_kernel<<<gridSize, blockSize, 0, stream>>>(
        p1, p2, c1, c2, vout, eout, pose2, cam2
    );
    cudaSafeCall( cudaGetLastError() );
}

//==============================================================================



// Applies accumulated displacement vectors to a point image.
//
// One thread per pixel with a 2D launch. Threads outside p's bounds do nothing.
// For each pixel whose v entry has a positive w component, p(x, y) is
// replaced by p(x, y) + rate * v(x, y). The addition covers all four
// components, including w. `camera` is currently unused.
__global__ void move_points_kernel(
    ftl::cuda::TextureObject<float4> p,
    ftl::cuda::TextureObject<float4> v,
    ftl::rgbd::Camera camera,
    float rate) {

    const unsigned int x = blockIdx.x*blockDim.x + threadIdx.x;
    const unsigned int y = blockIdx.y*blockDim.y + threadIdx.y;

    // Tail threads of the rounded-up grid exit early.
    if (x >= p.width() || y >= p.height()) return;

    const float4 delta = v.tex2D(static_cast<int>(x), static_cast<int>(y));

    // TODO: Calculate screen space distortion with neighbours

    // Entries with a non-positive (or NaN) w component are left untouched.
    if (!(delta.w > 0.0f)) return;

    const float4 current = p(x,y);
    p(x,y) = current + rate * delta;
}


// Host launcher for move_points_kernel.
//
// Uses one thread per pixel of p in T_PER_BLOCK x T_PER_BLOCK blocks. The
// kernel is enqueued on `stream` and not synchronised. Launch-configuration
// errors are checked via cudaGetLastError().
void ftl::cuda::move_points(
        ftl::cuda::TextureObject<float4> &p,
        ftl::cuda::TextureObject<float4> &v,
        const ftl::rgbd::Camera &camera,
        float rate,
        cudaStream_t stream) {

    // A zero-sized image would produce a zero grid dimension, which is an
    // invalid launch configuration. There is nothing to move in that case.
    if (p.width() == 0 || p.height() == 0) return;

    const dim3 gridSize((p.width() + T_PER_BLOCK - 1)/T_PER_BLOCK, (p.height() + T_PER_BLOCK - 1)/T_PER_BLOCK);
    const dim3 blockSize(T_PER_BLOCK, T_PER_BLOCK);

    move_points_kernel<<<gridSize, blockSize, 0, stream>>>(p,v,camera,rate);

    cudaSafeCall( cudaGetLastError() );
}
