package fntd.core.tracking;

import fntd.core.options.LinkerOptions;
import java.util.Arrays;

import java.util.ArrayList;

/**
 * Links detected track spots across a stack of slices into tracks, and re-links trajectories
 * that were interrupted.
 *
 * <p>{@link #linkParticles} processes the stack slice by slice. For slice {@code n} it builds a
 * cost matrix between the eligible spots of that slice and the eligible spots of the next
 * {@code linkRange} slices, greedily assigns every candidate target to its cheapest source, and
 * writes the chosen links into the spots' {@code prev/nextTrackSpot*} fields.
 *
 * <p>The per-run working state lives in unsynchronized instance fields, so an instance must not
 * be used by several threads at once.
 *
 * @author Jasper Kouwenberg
 */
public class KouwenbergLinker {

    /**
     * Current (maximum) link range. Clamped so it never reaches past the last slice, and
     * decremented near the end of the stack so the final slices are still used as link sources.
     */
    int curr_linkrange;

    /**
     * Cost matrix for the slice being processed. Rows are the eligible spots of the current
     * slice; columns are the eligible spots of the following {@code curr_linkrange} slices,
     * concatenated in slice order. {@code Float.MAX_VALUE} marks a forbidden link.
     */
    float[][] C;

    /**
     * Association vector: {@code G[i]} is the column of {@link #C} that source spot {@code i} is
     * linked to, or {@code -1} when it is unlinked.
     */
    int[] G;

    /** Maximum allowed linking cost, {@code (2 * maxDisplacement)^2}. */
    float max_cost;

    /** Spot lists of the source slice ({@code p1}) and the target slice ({@code p2}) being linked. */
    ArrayList<FNTDtrackSpot> p1, p2;

    /** Per-slice spots that take part in linking (flagged {@code FILTERED_TRACKSPOT_NO}). */
    ArrayList<FNTDtrackSpot>[] eligibleParticles;

    /** Per-slice spots excluded from linking; appended back after the eligible spots. */
    ArrayList<FNTDtrackSpot>[] filteredParticles;

    /**
     * Links the spots of a stack of slices into tracks.
     *
     * <p>Only spots flagged {@code FNTDtrackSpot.FILTERED_TRACKSPOT_NO} are linked. On return every
     * entry of {@code particles} has been replaced by a new list that holds the eligible spots
     * first, in their original relative order, followed by the filtered spots. Because the
     * eligible spots come first, the {@code prev/nextTrackSpotNo} indices written by the linker
     * are also valid indices into the new lists. Track numbers are assigned by
     * {@link #determineTrackNos()}.
     *
     * @param particles per-slice spot lists; the array entries are replaced in place
     * @param lo linker options (link range, displacement limit, cost weights)
     * @return always {@code true}
     */
    public boolean linkParticles(ArrayList<FNTDtrackSpot>[] particles, LinkerOptions lo) {
        partitionParticles(particles);

        final int nSlices = eligibleParticles.length;

        this.max_cost = (float) Math.pow(2 * lo.maxDisplacement, 2);

        // Keep the link range inside [0, nSlices - 1]. The lower bound prevents an empty stack
        // (or a negative link range) from indexing outside the slice array below.
        this.curr_linkrange = Math.max(0, Math.min(lo.linkRange, nSlices - 1));

        for (int n = 0; n < nSlices - this.curr_linkrange; n++) {
            linkSlice(n, lo);

            // Near the end of the stack fewer slices remain ahead: shrink the link range, which
            // also extends the loop bound so the remaining slices are processed as sources.
            if (n == nSlices - this.curr_linkrange - 1 && this.curr_linkrange > 1) {
                this.curr_linkrange--;
            }
        }

        determineTrackNos();
        copyBackParticles(particles);
        return true;
    }

    /** Splits every slice of {@code particles} into the eligible and filtered spot lists. */
    private void partitionParticles(ArrayList<FNTDtrackSpot>[] particles) {
        eligibleParticles = newSpotListArray(particles.length);
        filteredParticles = newSpotListArray(particles.length);

        for (int s = 0; s < particles.length; s++) {
            eligibleParticles[s] = new ArrayList<>();
            filteredParticles[s] = new ArrayList<>();
            for (FNTDtrackSpot ts : particles[s]) {
                if (ts.filteredTrackspot == FNTDtrackSpot.FILTERED_TRACKSPOT_NO) {
                    eligibleParticles[s].add(ts);
                } else {
                    filteredParticles[s].add(ts);
                }
            }
        }
    }

    /** Replaces every slice of {@code particles} by its eligible spots followed by its filtered spots. */
    private void copyBackParticles(ArrayList<FNTDtrackSpot>[] particles) {
        for (int s = 0; s < particles.length; s++) {
            ArrayList<FNTDtrackSpot> merged =
                    new ArrayList<>(eligibleParticles[s].size() + filteredParticles[s].size());
            merged.addAll(eligibleParticles[s]);
            merged.addAll(filteredParticles[s]);
            particles[s] = merged;
        }
    }

    /** Creates an empty generic list array; the unchecked cast is safe since no element exists yet. */
    @SuppressWarnings("unchecked")
    private static ArrayList<FNTDtrackSpot>[] newSpotListArray(int length) {
        return (ArrayList<FNTDtrackSpot>[]) new ArrayList<?>[length];
    }

    /**
     * Builds the cost matrix for source slice {@code n}, solves the assignment and writes the
     * chosen links into the spots.
     *
     * @param n index of the source slice
     * @param lo linker options
     */
    private void linkSlice(int n, LinkerOptions lo) {
        final int range = this.curr_linkrange;
        final int dimPAn = eligibleParticles[n].size();

        // accFramePart[k] is the first column of target slice n + k + 1;
        // accFramePart[range] is the total number of columns.
        int[] accFramePart = new int[range + 1];
        for (int k = 1; k <= range; k++) {
            accFramePart[k] = accFramePart[k - 1] + eligibleParticles[n + k].size();
        }
        final int dimPAnK = accFramePart[range];

        this.G = new int[dimPAn];
        Arrays.fill(this.G, -1);
        this.C = new float[dimPAn][dimPAnK];

        // Fill the cost matrix.
        for (int k = 0; k < range; k++) {
            this.p1 = eligibleParticles[n];
            this.p2 = eligibleParticles[n + (k + 1)];
            for (int i = 0; i < dimPAn; i++) {
                FNTDtrackSpot from = this.p1.get(i);
                for (int j = accFramePart[k]; j < accFramePart[k + 1]; j++) {
                    FNTDtrackSpot to = this.p2.get(j - accFramePart[k]);
                    float cost = linkCost(from, to, k, lo);

                    // With a link range > 1 the target may already be linked from an earlier
                    // slice. Only allow replacing that link when this one is cheaper; this rule
                    // applies to every candidate, whichever cost branch produced it.
                    if (to.prevTrackSpotNo > -1 && cost > to.cost) {
                        cost = Float.MAX_VALUE;
                    }
                    this.C[i][j] = cost;
                }
            }
        }

        // Assign every candidate target column to its cheapest source.
        for (int j = 0; j < dimPAnK; j++) {
            findBestLink(j, dimPAn);
        }

        // Write the chosen links into the spots, target slice by target slice.
        for (int k = 0; k < range; k++) {
            this.p2 = eligibleParticles[n + (k + 1)];
            for (int i = 0; i < dimPAn; i++) {
                int j = this.G[i];
                if (j >= accFramePart[k] && j < accFramePart[k + 1]) {
                    int p2Index = j - accFramePart[k];
                    applyLink(n, k, i, p2Index, this.p1.get(i), this.p2.get(p2Index), this.C[i][j], lo);
                }
            }
        }
    }

    /**
     * Computes the cost of linking {@code from} (slice {@code n}) to {@code to} (slice
     * {@code n + k + 1}).
     *
     * <p>Returns as soon as the running cost exceeds {@link #max_cost}, because such a link can
     * never be selected. {@code Float.MAX_VALUE} is returned for links rejected by the angle limit.
     *
     * @param from source spot
     * @param to target spot
     * @param k number of skipped slices (0 for the directly following slice)
     * @param lo linker options
     * @return the link cost
     */
    private float linkCost(FNTDtrackSpot from, FNTDtrackSpot to, int k, LinkerOptions lo) {
        final float distance = from.distanceSq(to);

        // Squared displacement plus a penalty that grows with the number of skipped slices.
        float cost = (float) lo.lDynamic * (distance + k * k);
        if (cost > this.max_cost) {
            return cost;
        }

        // Differences in the intensity moments m0 and m2.
        cost += lo.l_i / (lo.radius * lo.radius * Math.PI)
                * (Math.pow(from.m0 - to.m0, 2) + Math.pow(from.m2 - to.m2, 2));
        if (cost > this.max_cost) {
            return cost;
        }

        // Velocity feature, only for sources that were reached by an earlier link.
        if (lo.l_v > 0 && from.distance > 0) {
            // NOTE(review): if iX/iY are integral types this is integer division — confirm.
            float lx = (to.iX - from.iX) / (k + 1);
            float ly = (to.iY - from.iY) / (k + 1);

            float fMagnSq;
            if (lo.angleHistory > 1) {
                fMagnSq = (lx * lx + ly * ly) - from.linkModuleHSq();
            } else {
                fMagnSq = (lx * lx + ly * ly) - from.linkModuleSq();
            }

            cost += lo.l_v * Math.abs(fMagnSq);
            if (cost > this.max_cost) {
                return cost;
            }
        }

        // Angle feature.
        if (lo.l_a > 0) {
            float l1m;
            float lx1;
            float ly1;
            if (lo.angleHistory > 1) { // averaged velocity over the track history
                l1m = from.linkModuleH();
                lx1 = from.lxh / l1m;
                ly1 = from.lyh / l1m;
            } else { // velocity of the last link only
                l1m = from.linkModule();
                lx1 = from.lx / l1m;
                ly1 = from.ly / l1m;
            }

            final double maxAngleCos = Math.cos(lo.maxAngle * Math.PI / 180.0f);

            if (l1m > 0) { // the source has a velocity vector
                if (Math.sqrt(distance) <= lo.radius / 2 && l1m <= lo.radius / 2) {
                    // Nearly stationary spot: a fixed cost replaces the angle cost so that small
                    // jitter of spots with a small incidence angle is not rejected.
                    // NOTE(review): this term is weighted by lDynamic, the other angle terms by l_a.
                    cost += lo.lDynamic * Math.pow((maxAngleCos - 1) / 2 * lo.maxDisplacement, 2);
                    return cost;
                }

                float lx2 = to.iX - from.iX;
                float ly2 = to.iY - from.iY;
                float l2m = (float) Math.sqrt(lx2 * lx2 + ly2 * ly2);
                lx2 /= l2m;
                ly2 /= l2m;

                float cosPhi = lx1 * lx2 + ly1 * ly2;
                if (cosPhi >= maxAngleCos) {
                    cost += lo.l_a * Math.pow((cosPhi - 1) / 2 * lo.maxDisplacement, 2);
                } else {
                    cost = Float.MAX_VALUE; // angle above the user limit: forbidden link
                }
            } else {
                // The source would start a new trajectory. Charge slightly more than the
                // maximum-angle cost, so that extending an existing track is preferred.
                cost += lo.l_a * 1.2 * Math.pow((maxAngleCos - 1) / 2 * lo.maxDisplacement, 2);
            }
        }
        return cost;
    }

    /**
     * Stores the link from source spot {@code i} of slice {@code n} to target spot
     * {@code p2Index} of slice {@code n + k + 1}. Any link the target already had is broken,
     * and the target's velocity fields used by later cost evaluations are updated.
     */
    private void applyLink(int n, int k, int i, int p2Index, FNTDtrackSpot from, FNTDtrackSpot to,
            float cost, LinkerOptions lo) {
        if (to.prevTrackSpotSliceIdx >= 0) {
            // The target was linked from an earlier slice (link range > 1): detach that
            // earlier partner so the target keeps a single predecessor.
            FNTDtrackSpot oldPartner =
                    eligibleParticles[to.prevTrackSpotSliceIdx].get(to.prevTrackSpotNo);
            oldPartner.nextTrackSpotNo = -1;
            oldPartner.nextTrackSpotSliceIdx = -1;
        }

        from.nextTrackSpotNo = p2Index;
        from.nextTrackSpotSliceIdx = n + (k + 1);
        to.prevTrackSpotNo = i;
        to.prevTrackSpotSliceIdx = n;

        // NOTE(review): sqrt is applied to distance(...) while the costs use distanceSq(...); if
        // distance(...) already returns the Euclidean distance, this stores its square root.
        // Within this class only the sign of the stored value is read. Confirm intent.
        to.distance = (float) Math.sqrt(from.distance(to));
        to.cost = cost;

        if (lo.l_v > 0 || lo.l_a > 0) {
            // NOTE(review): if iX/iY are integral types this is integer division — confirm.
            to.lx = (to.iX - from.iX) / (k + 1);
            to.ly = (to.iY - from.iY) / (k + 1);

            if (lo.angleHistory > 1) {
                updateAveragedVelocity(from, to, lo);
            }
        }
    }

    /**
     * Sets {@code to.lxh/lyh} to the mean of {@code to}'s own velocity and the velocities of the
     * spots walked backwards from {@code from}. The walk stops at the start of the trajectory or
     * after {@code lo.angleHistory} spots.
     */
    private void updateAveragedVelocity(FNTDtrackSpot from, FNTDtrackSpot to, LinkerOptions lo) {
        int count = 1;
        float sumX = 0;
        float sumY = 0;
        FNTDtrackSpot walker = from;

        while (count <= lo.angleHistory) {
            if (walker.prevTrackSpotSliceIdx < 0 || walker.prevTrackSpotNo < 0) {
                break; // reached the start of the trajectory
            }
            sumX += walker.lx;
            sumY += walker.ly;
            count++;
            walker = eligibleParticles[walker.prevTrackSpotSliceIdx].get(walker.prevTrackSpotNo);
        }

        to.lxh = (to.lx + sumX) / count;
        to.lyh = (to.ly + sumY) / count;
    }

    /**
     * Assigns candidate column {@code column} of {@link #C} to its cheapest source spot. If that
     * source gave up another column, a new source is then searched for the displaced column.
     *
     * <p>Source {@code i} is eligible when {@code C[i][column]} meets three conditions: it is
     * below the running minimum (initially {@link #max_cost}); it is not a forbidden link
     * ({@code Float.MAX_VALUE}); and it is cheaper than the link {@code i} currently holds, if any.
     * A re-assignment always moves a source to a strictly cheaper column, so the loop terminates.
     *
     * @param column index of the target column to link
     * @param dimPAn number of source spots (rows of {@link #C})
     */
    private void findBestLink(int column, int dimPAn) {
        int current = column;
        while (true) {
            float min = this.max_cost;
            int best = -1;

            for (int i = 0; i < dimPAn; i++) {
                float c = this.C[i][current];
                boolean improvesOwnLink = this.G[i] < 0 || c < this.C[i][this.G[i]];
                if (c < min && c < Float.MAX_VALUE && improvesOwnLink) {
                    best = i;
                    min = c;
                }
            }

            if (best == -1) {
                return; // no source can take this column
            }

            int displaced = this.G[best];
            this.G[best] = current;
            if (displaced < 0) {
                return;
            }
            current = displaced; // re-link the column that 'best' just gave up
        }
    }

    /**
     * Joins pairs of trajectories that look like one interrupted track. Afterwards the link
     * fields of the spots in {@code particles} are rewritten from the resulting trajectories.
     *
     * <p>Each trajectory, in list order, is merged with its cheapest partner among the later list
     * entries. A pair qualifies when all of the following hold:
     * <ul>
     *   <li>both trajectories have at least three spots;</li>
     *   <li>one starts between 1 and {@code 2 * linkRange + 1} slices after the other ends;</li>
     *   <li>the joined end points lie within {@code relink_distance};</li>
     *   <li>the two trajectory directions, and the direction connecting them, are within
     *       {@code maxAngle} degrees.</li>
     * </ul>
     * The merged trajectory replaces the later list entry and keeps that entry's track number, so
     * it can be merged again further on. The earlier entry is removed.
     *
     * @param trajectories trajectories to relink; modified in place
     * @param particles per-slice spot lists, indexed by the spots' {@code sliceIdx} and
     *     {@code trackSpotIdx}
     * @param lo linker options (relink distance and cost factor, weights, link range, maximum angle)
     */
    public static void reLinkTrajectories(ArrayList<FNTDtrack> trajectories, ArrayList<FNTDtrackSpot>[] particles, LinkerOptions lo) {
        final int minParticles = 3;
        final int maxGap = 2 * lo.linkRange + 1;
        final int total = trajectories.size();
        final boolean[] mergedAway = new boolean[total];

        for (int l = 0; l < total; l++) {
            FNTDtrack t1 = trajectories.get(l);
            if (t1.trackSpots.size() < minParticles) {
                continue; // too short to estimate a direction
            }

            int bestIndex = -1;
            boolean bestT1First = false;
            // Maximum allowed cost, scaled by the user-given cost factor; lowered as better
            // candidates are found.
            double bestCost = lo.relink_cost_factor * 4 * lo.relink_distance * lo.relink_distance;

            // Only later entries are considered, so every pair is examined once.
            for (int m = l + 1; m < total; m++) {
                FNTDtrack t2 = trajectories.get(m);
                if (t2.trackSpots.size() < minParticles) {
                    continue;
                }

                // Order the pair in time; the same gap bound applies in both directions.
                boolean t1First;
                if (followsWithinGap(t1, t2, maxGap)) {
                    t1First = true;
                } else if (followsWithinGap(t2, t1, maxGap)) {
                    t1First = false;
                } else {
                    continue;
                }
                FNTDtrack first = t1First ? t1 : t2;
                FNTDtrack second = t1First ? t2 : t1;

                // The end points that would actually be joined must be close enough.
                FNTDtrackSpot tail = first.trackSpots.get(first.trackSpots.size() - 1);
                FNTDtrackSpot head = second.trackSpots.get(0);
                double joinDistSq = Math.pow(tail.iX - head.iX, 2)
                        + Math.pow(tail.iY - head.iY, 2)
                        + Math.pow(tail.iZ - head.iZ, 2);
                if (Math.sqrt(joinDistSq) > lo.relink_distance) {
                    continue;
                }

                // Directions of both trajectories must agree within the maximum angle.
                double l1m = Math.sqrt(Math.pow(t1.lx, 2) + Math.pow(t1.ly, 2));
                double l2m = Math.sqrt(Math.pow(t2.lx, 2) + Math.pow(t2.ly, 2));
                if (!(l1m > 0) || !(l2m > 0)) {
                    continue; // no usable direction: the normalised vectors (and cost) would be NaN
                }
                double lx1 = t1.lx / l1m;
                double ly1 = t1.ly / l1m;
                double lx2 = t2.lx / l2m;
                double ly2 = t2.ly / l2m;
                if (exceedsMaxAngle(lx1 * lx2 + ly1 * ly2, lo)) {
                    continue;
                }

                // Mean unit vector connecting the last minParticles spots of the earlier
                // trajectory to the first minParticles spots of the later one.
                double lX = 0;
                double lY = 0;
                for (int i = 0; i < minParticles; i++) {
                    FNTDtrackSpot a = second.trackSpots.get(i);
                    FNTDtrackSpot b = first.trackSpots.get(first.trackSpots.size() - (1 + i));
                    double tx = a.iX - b.iX;
                    double ty = a.iY - b.iY;
                    double segLen = Math.sqrt(Math.pow(tx, 2) + Math.pow(ty, 2));
                    lX += tx / (minParticles * segLen);
                    lY += ty / (minParticles * segLen);
                }

                // The connection must align with the mean trajectory direction.
                double cosPhi = lX * (lx1 + lx2) / 2 + lY * (ly1 + ly2) / 2;
                if (exceedsMaxAngle(cosPhi, lo)) {
                    continue;
                }

                double cost = lo.lDynamic * joinDistSq;
                cost += lo.l_i / (lo.radius * lo.radius * Math.PI)
                        * (Math.pow(t1.m0 - t2.m0, 2)
                        + Math.pow(t1.m2 - t2.m2, 2));
                cost += lo.l_v * Math.abs(Math.pow(t1.lx, 2) + Math.pow(t1.ly, 2) - (Math.pow(t2.lx, 2) + Math.pow(t2.ly, 2)));
                cost += lo.l_a * Math.pow((cosPhi - 1) / 2 * lo.relink_distance, 2);

                if (cost < bestCost) {
                    bestIndex = m;
                    bestCost = cost;
                    bestT1First = t1First;
                }
            }

            if (bestIndex > -1) {
                FNTDtrack t2 = trajectories.get(bestIndex);
                FNTDtrack first = bestT1First ? t1 : t2;
                FNTDtrack second = bestT1First ? t2 : t1;

                ArrayList<FNTDtrackSpot> ordered =
                        new ArrayList<>(first.trackSpots.size() + second.trackSpots.size());
                ordered.addAll(first.trackSpots);
                ordered.addAll(second.trackSpots);

                FNTDtrack merged = new FNTDtrack(t2.trackNo);
                for (FNTDtrackSpot p : ordered) {
                    merged.addParticle(p);
                }

                // The later entry is replaced by the merged trajectory, so it is revisited (and
                // can be merged again) when the outer loop reaches bestIndex.
                trajectories.set(bestIndex, merged);
                mergedAway[l] = true;
            }
        }

        // Drop the entries that were merged into a later one, by position and in a single pass.
        ArrayList<FNTDtrack> survivors = new ArrayList<>(total);
        for (int l = 0; l < total; l++) {
            if (!mergedAway[l]) {
                survivors.add(trajectories.get(l));
            }
        }
        trajectories.clear();
        trajectories.addAll(survivors);

        // Rewrite the linking information of the spots from the relinked trajectories.
        for (FNTDtrack t : trajectories) {
            int size = t.trackSpots.size();
            for (int i = 0; i < size; i++) {
                FNTDtrackSpot spot = t.trackSpots.get(i);
                FNTDtrackSpot temp = particles[spot.sliceIdx].get(spot.trackSpotIdx);

                temp.trackNo = t.trackNo;
                temp.prevTrackSpotNo = -1;
                temp.prevTrackSpotSliceIdx = -1;
                temp.nextTrackSpotNo = -1;
                temp.nextTrackSpotSliceIdx = -1;

                if (i > 0) {
                    FNTDtrackSpot prev = t.trackSpots.get(i - 1);
                    temp.prevTrackSpotNo = prev.trackSpotIdx;
                    temp.prevTrackSpotSliceIdx = prev.sliceIdx;
                }
                if (i < size - 1) {
                    FNTDtrackSpot next = t.trackSpots.get(i + 1);
                    temp.nextTrackSpotNo = next.trackSpotIdx;
                    temp.nextTrackSpotSliceIdx = next.sliceIdx;
                }
            }
        }
    }

    /** Returns whether {@code later} starts between 1 and {@code maxGap} slices after {@code earlier} ends. */
    private static boolean followsWithinGap(FNTDtrack earlier, FNTDtrack later, int maxGap) {
        return later.firstSlice - earlier.lastSlice > 0 && later.firstSlice - earlier.lastSlice <= maxGap;
    }

    /** Returns whether the angle whose cosine is {@code cosPhi} exceeds {@code lo.maxAngle} degrees. */
    private static boolean exceedsMaxAngle(double cosPhi, LinkerOptions lo) {
        // Clamp rounding excursions outside [-1, 1]: acos would return NaN, and the comparison
        // below would then silently accept the pair.
        double clamped = Math.max(-1.0, Math.min(1.0, cosPhi));
        return Math.acos(clamped) > lo.maxAngle * Math.PI / 180;
    }

    /**
     * Assigns track numbers to the eligible spots of the last {@link #linkParticles} run.
     *
     * <p>Slices are visited in order. A spot with a predecessor inherits the predecessor's track
     * number; predecessors always lie in an earlier slice, so they are already numbered. A spot
     * with only a successor starts a new track, and tracks are numbered 1, 2, ... in the order
     * their first spots are met. A spot without links gets track number 0.
     * Does nothing if {@link #linkParticles} has not been run yet.
     */
    public void determineTrackNos() {
        if (eligibleParticles == null) {
            return;
        }
        int nTracks = 0;
        for (ArrayList<FNTDtrackSpot> slice : eligibleParticles) {
            for (FNTDtrackSpot p : slice) {
                if (p.prevTrackSpotNo > -1 && p.prevTrackSpotSliceIdx > -1) {
                    p.trackNo = eligibleParticles[p.prevTrackSpotSliceIdx].get(p.prevTrackSpotNo).trackNo;
                } else if (p.nextTrackSpotNo > -1 && p.nextTrackSpotSliceIdx > -1) {
                    p.trackNo = ++nTracks;
                } else {
                    p.trackNo = 0;
                }
            }
        }
    }

}
