// ============================================================================
// File:               $File$
//
// Project:            
//
// Purpose:            
//
// Author:             Rammi
//
// Copyright Notice:   (c) 2008  Rammi (rammi@caff.de)
//                     This code is in the public domain.
//                     Use at own risk.
//                     No guarantees given.
//
// Latest change:      $Date$
//
// History:	       $Log$
//=============================================================================
package de.caff.asteroid;

import static de.caff.asteroid.Communication.MAX_FRAMES_KEPT;
import de.caff.util.Pair;
import de.caff.util.Tools;

import java.awt.*;
import java.awt.geom.Point2D;
import java.util.*;
import java.util.List;

/**
 *  A much improved version of a velocity preparer.
 */
public class ImprovedVelocityPreparer
        extends SimpleVelocityPreparer
{
  /** Lower bound (inclusive) of the distance between a new bullet and the ship's next location for the bullet to count as fired by the ship. */
  private static final double MIN_BULLET_DELTA = 18;
  /** Upper bound (inclusive) of the distance between a new bullet and the ship's next location for the bullet to count as fired by the ship. */
  private static final double MAX_BULLET_DELTA = 22;
  /** Shared zero point handed out when no velocity or position correction is known. */
  private static final Point2D NULL_POINT = new Point();
  /** Next identity given to an asteroid which did not receive one from a predecessor. */
  private int asteroidIdCounter = 0;
  /** Next identity given to a bullet which did not receive one from a predecessor. */
  private int bulletIdCounter = 0;

  /**
   *  Sort key for a candidate pairing of a bullet in the older frame
   *  with a bullet in the newer frame.
   *  Keys are ordered by squared distance first, then by index delta,
   *  then by the index of the newer bullet.
   */
  private static class BulletKey
          implements Comparable<BulletKey>
  {
    /** Squared distance. */
    private final double dist2;
    /** Index delta (new index minus old index). */
    private final int deltaIndex;
    /** Index of the bullet in the newer frame. */
    private final int firstIndex;

    /**
     *  Constructor.
     *  @param dist2    squared distance of the pairing
     *  @param newIndex index of the bullet in the newer frame
     *  @param oldIndex index of the bullet in the older frame
     */
    BulletKey(double dist2, int newIndex, int oldIndex)
    {
      this.dist2 = dist2;
      this.firstIndex = newIndex;
      this.deltaIndex = newIndex - oldIndex;
    }

    /**
     * Compares this object with the specified object for order.
     * Smaller squared distance sorts first; ties are broken by the index delta
     * and finally by the index of the newer bullet.
     *
     * @param o the Object to be compared.
     * @return a negative integer, zero, or a positive integer as this object
     *         is less than, equal to, or greater than the specified object.
     */
    public int compareTo(BulletKey o)
    {
      int comp = Double.compare(dist2, o.dist2);
      if (comp == 0) {
        comp = compareInts(deltaIndex, o.deltaIndex);
        if (comp == 0) {
          comp = compareInts(firstIndex, o.firstIndex);
        }
      }
      return comp;
    }

    /**
     *  Three-way comparison of two ints.
     *  Unlike a subtraction this cannot overflow and flip the sign.
     *  @param a first value
     *  @param b second value
     *  @return -1, 0 or 1 if a is less than, equal to or greater than b
     */
    private static int compareInts(int a, int b)
    {
      return a < b ? -1 : (a == b ? 0 : 1);
    }
  }

  /**
   *  Sliding window over the most recent per-frame displacements of one object.
   *  Keeps at most {@link #MAX_COUNT} samples together with their running sum,
   *  so the average can be read without iterating over the samples.
   */
  private static class AverageVelocity
  {
    /** Maximum needed for velocity calculation. */
    protected static final int MAX_COUNT = 8;
    /** Step size used when searching for the position correction. */
    private static final double CORRECTION = 1.0/MAX_COUNT;
    /** The last {@link #MAX_COUNT} velocities, oldest first. */
    protected List<Point> lastVelocities = new LinkedList<Point>();
    /** Sum of the x components of all stored velocities. */
    private int sumX;
    /** Sum of the y components of all stored velocities. */
    private int sumY;

    /**
     *  Create an empty window without any sample.
     */
    public AverageVelocity()
    {
    }

    /**
     *  Create a window holding one initial sample.
     *  The components are normalized via {@link GameObject} before they are stored.
     *  @param vx  velocity's x component
     *  @param vy  velocity's y component
     */
    public AverageVelocity(int vx, int vy)
    {
      sumX = GameObject.normalizeDeltaX(vx);
      sumY = GameObject.normalizeDeltaY(vy);
      lastVelocities.add(new Point(sumX, sumY));
    }

    /**
     *  Create a window holding one initial sample.
     *  @param v  velocity
     */
    public AverageVelocity(Point v)
    {
      this(v.x, v.y);
    }

    /**
     *  Append a sample, dropping the oldest one if the window is already full.
     *  The components are normalized via {@link GameObject} before they are stored.
     *  @param velX  velocity's x component
     *  @param velY  velocity's y component
     */
    public void add(int velX, int velY)
    {
      if (lastVelocities.size() == MAX_COUNT) {
        // window is full: forget the oldest sample first
        Point oldest = lastVelocities.remove(0);
        sumX -= oldest.x;
        sumY -= oldest.y;
      }
      Point sample = new Point(GameObject.normalizeDeltaX(velX),
                               GameObject.normalizeDeltaY(velY));
      lastVelocities.add(sample);
      sumX += sample.x;
      sumY += sample.y;
    }

    /**
     *  Append a sample.
     *  @param v velocity
     */
    public void add(Point v)
    {
      add(v.x, v.y);
    }

    /**
     *  Get the average velocity in double precision.
     *  @return average velocity, (0,0) if there is no sample yet
     */
    public Point2D getAverage()
    {
      final int samples = lastVelocities.size();
      if (samples == 0) {
        return new Point2D.Double(0, 0);
      }
      final double scale = 1.0/samples;
      return new Point2D.Double(sumX*scale,
                                sumY*scale);
    }

    /**
     *  Get the average velocity's x component in double precision.
     *  @return x component, 0 if there is no sample yet
     */
    public double getAverageX()
    {
      if (lastVelocities.isEmpty()) {
        return 0;
      }
      return sumX/(double)lastVelocities.size();
    }

    /**
     *  Get the average velocity's y component in double precision.
     *  @return y component, 0 if there is no sample yet
     */
    public double getAverageY()
    {
      if (lastVelocities.isEmpty()) {
        return 0;
      }
      return sumY/(double)lastVelocities.size();
    }

    /**
     *  Get the offset from the latest screen position to the correct internal MAME position.
     *  While the window is not yet full the shared zero point is returned.
     *  Otherwise the stored samples are replayed: per axis an offset is adjusted in steps of
     *  1/{@link #MAX_COUNT} until the truncated sum of averages plus offset
     *  agrees with the sum of the samples seen so far.
     *  @return position correction
     */
    public Point2D getPositionCorrection()
    {
      if (lastVelocities.size() < MAX_COUNT) {
        return NULL_POINT;
      }
      double corrX = 0;
      double corrY = 0;
      final double avgX = getAverageX();
      final double avgY = getAverageY();
      final boolean evenSplit = sumX % MAX_COUNT == 0  &&  sumY % MAX_COUNT == 0;
      if (!evenSplit) {
        int seenX = 0;
        int seenY = 0;
        double idealX = 0;
        double idealY = 0;
        for (Point step: lastVelocities) {
          seenX += step.x;
          seenY += step.y;
          idealX += avgX;
          idealY += avgY;
          corrX = adjustCorrection(corrX, idealX, seenX);
          corrY = adjustCorrection(corrY, idealY, seenY);
        }
      }
      return new Point2D.Double(corrX, corrY);
    }

    /**
     *  Move a correction in steps of {@link #CORRECTION} until the truncated
     *  value of ideal plus correction no longer overshoots (or undershoots)
     *  the actually observed sum.
     *  @param correction correction so far
     *  @param ideal      accumulated average
     *  @param actual     accumulated observed value
     *  @return adjusted correction, unchanged if truncation already matches
     */
    private static double adjustCorrection(double correction, double ideal, int actual)
    {
      final int predicted = (int)(ideal + correction);
      if (predicted > actual) {
        do {
          correction -= CORRECTION;
        } while ((int)(ideal + correction) > actual);
      }
      else if (predicted < actual) {
        do {
          correction += CORRECTION;
        } while ((int)(ideal + correction) < actual);
      }
      return correction;
    }

    /**
     *  Textual form showing both sums over the number of samples.
     *  @return string representation
     */
    @Override
    public String toString()
    {
      final int samples = lastVelocities.size();
      return String.format("(%d/%d, %d/%d)",
                           sumX, samples,
                           sumY, samples);
    }
  }

  /**
   *  Average velocity which starts from a known initial velocity.
   *  As long as fewer than {@link #MAX_COUNT} samples were collected,
   *  the initial velocity still contributes to the reported components.
   */
  private static class AverageBulletVelocity
          extends AverageVelocity
  {
    /** Initial velocity's x component. */
    private final double initVX;
    /** Initial velocity's y component. */
    private final double initVY;

    /**
     *  Constructor.
     *  @param initV initial velocity
     */
    AverageBulletVelocity(Point2D initV)
    {
      this(initV.getX(), initV.getY());
    }

    /**
     *  Constructor.
     *  @param initVX initial velocity's x component
     *  @param initVY initial velocity's y component
     */
    AverageBulletVelocity(double initVX, double initVY)
    {
      super();
      this.initVX = initVX;
      this.initVY = initVY;
    }

    /**
     * Get the average speed in double precision.
     * Returns the plain initial velocity until the window is full.
     * NOTE(review): unlike {@link #getAverageX()} and {@link #getAverageY()} this does
     * not blend in the samples collected so far -- confirm whether that is intended.
     *
     * @return average speed
     */
    @Override
    public Point2D getAverage()
    {
      if (lastVelocities.size() >= MAX_COUNT) {
        return super.getAverage();
      }
      return new Point2D.Double(initVX, initVY);
    }

    /**
     * Get the average speed's x component in double precision.
     * Until the window is full the result is a weighted mix of the
     * initial value and the average of the samples collected so far.
     *
     * @return x component
     */
    @Override
    public double getAverageX()
    {
      final int samples = lastVelocities.size();
      if (samples >= MAX_COUNT) {
        return super.getAverageX();
      }
      if (samples == 0) {
        return initVX;
      }
      return blend(initVX, super.getAverageX(), samples);
    }

    /**
     * Get the average speed's y component in double precision.
     * Until the window is full the result is a weighted mix of the
     * initial value and the average of the samples collected so far.
     *
     * @return y component
     */
    @Override
    public double getAverageY()
    {
      final int samples = lastVelocities.size();
      if (samples >= MAX_COUNT) {
        return super.getAverageY();
      }
      if (samples == 0) {
        return initVY;
      }
      return blend(initVY, super.getAverageY(), samples);
    }

    /**
     *  Mix the initial value with the measured average, weighting the
     *  measurement by the number of samples and the initial value by
     *  the number of samples still missing to a full window.
     *  @param init     initial value
     *  @param measured average of the samples collected so far
     *  @param samples  number of samples collected so far
     *  @return weighted mix
     */
    private static double blend(double init, double measured, int samples)
    {
      return (init * (MAX_COUNT - samples) + measured * samples)/MAX_COUNT;
    }
  }


  /** Velocity trackers of asteroids: first index is the frame index modulo MAX_FRAMES_KEPT, second index the asteroid's index (has to be below 256). */
  private AverageVelocity[][] averageAsteroidVelocities = new AverageVelocity[MAX_FRAMES_KEPT][256];
  /** Velocity trackers of bullets: first index is the frame index modulo MAX_FRAMES_KEPT, second index the bullet's index (has to be below 256). */
  private AverageVelocity[][] averageBulletVelocities = new AverageVelocity[MAX_FRAMES_KEPT][256];
  /** Velocity tracker of the space ship, reset to <code>null</code> whenever the ship is missing in the previous or the current frame. */
  private AverageVelocity shipVelocity;

  /**
   * Prepare the asteroids.
   * Called if there are at least two frames.
   * Each asteroid of the previous frame is paired with its best match in the
   * current frame. The match takes over the (updated) velocity tracker of its
   * predecessor, gets velocity and location correction from it and inherits
   * from the predecessor. Asteroids still without identity afterwards
   * receive a fresh one.
   *
   * @param frameInfos all frame infos
   * @param prevFrame  the second to last frame
   * @param currFrame  the last frame
   */
  @Override
  protected void prepareAsteroids(LinkedList<FrameInfo> frameInfos, FrameInfo prevFrame, FrameInfo currFrame)
  {
    AsteroidSelector selector = new AsteroidSelector(currFrame.getAsteroids());
    int oldPhase = prevFrame.getIndex() % MAX_FRAMES_KEPT;
    int newPhase = currFrame.getIndex() % MAX_FRAMES_KEPT;
    // the slot of the current frame is recycled, so forget what it held before
    AverageVelocity[] currentTrackers = averageAsteroidVelocities[newPhase];
    Arrays.fill(currentTrackers, null);

    for (Asteroid previous: prevFrame.getAsteroids()) {
      Asteroid match = selector.getBestMatch(previous);
      if (match == null) {
        continue;
      }
      AverageVelocity tracker = averageAsteroidVelocities[oldPhase][previous.getIndex()];
      int dx = match.getX() - previous.getX();
      int dy = match.getY() - previous.getY();
      if (tracker == null) {
        // first movement seen for this asteroid
        tracker = new AverageVelocity(dx, dy);
      }
      else {
        tracker.add(dx, dy);
      }
      match.setVelocity(tracker.getAverageX(),
                        tracker.getAverageY());
      Point2D offset = tracker.getPositionCorrection();
      match.correctLocation(offset.getX(), offset.getY());
      match.inheret(previous);
      currentTrackers[match.getIndex()] = tracker;
    }

    for (Asteroid asteroid: currFrame.getAsteroids()) {
      if (asteroid.getIdentity() != null) {
        continue;
      }
      // probably new
      asteroid.setIdentity(asteroidIdCounter++);
    }
  }


  /**
   *  Prepare the bullets.
   *  Called if there are at least two frames.
   *  Works in three steps: all pairings of an old bullet with a current bullet
   *  near the old bullet's predicted position are collected in sorted order,
   *  then pairings are accepted from the front, discarding every later pairing
   *  which reuses one of the two bullets, and finally bullets still without
   *  identity get one and are classified as friendly or not.
   *  NOTE(review): unlike the sibling prepare methods this one carries no
   *  <code>@Override</code> -- confirm whether the superclass declares it.
   *  @param frameInfos all frame infos
   *  @param prevFrame  the second to last frame
   *  @param currFrame  the last frame
   */
  protected void prepareBullets(LinkedList<FrameInfo> frameInfos,
                                FrameInfo prevFrame,
                                FrameInfo currFrame)
  {
    int oldPhase = prevFrame.getIndex() % MAX_FRAMES_KEPT;
    int newPhase = currFrame.getIndex() % MAX_FRAMES_KEPT;
    // the slot of the current frame is recycled, so forget what it held before
    AverageVelocity[] currentTrackers = averageBulletVelocities[newPhase];
    Arrays.fill(currentTrackers, null);

    // step 1: collect candidate pairings, sorted by their key
    SortedMap<BulletKey, Pair<Bullet>> ranked = new TreeMap<BulletKey, Pair<Bullet>>();
    for (Bullet previous: prevFrame.getBullets()) {
      double expectedX = previous.getX() + previous.getVelocityX();
      double expectedY = previous.getY() + previous.getVelocityY();

      for (Bullet current: currFrame.getBullets()) {
        double dist2 = current.getSquaredDistance(expectedX, expectedY);
        if (dist2 < MAX_SQUARED_BULLET_VELOCITY) {
          BulletKey key = new BulletKey(dist2, current.getIndex(), previous.getIndex());
          ranked.put(key, new Pair<Bullet>(previous, current));
        }
      }
    }

    // step 2: accept pairings from the front, each bullet is used only once
    LinkedList<Pair<Bullet>> open = new LinkedList<Pair<Bullet>>(ranked.values());
    while (!open.isEmpty()) {
      Pair<Bullet> accepted = open.removeFirst();
      AverageVelocity tracker = averageBulletVelocities[oldPhase][accepted.first.getIndex()];
      int dx = accepted.second.getX() - accepted.first.getX();
      int dy = accepted.second.getY() - accepted.first.getY();
      if (tracker == null) {
        tracker = new AverageVelocity(dx, dy);
      }
      else {
        tracker.add(dx, dy);
      }
      accepted.second.setVelocity(tracker.getAverageX(),
                                  tracker.getAverageY());
      Point2D offset = tracker.getPositionCorrection();
      accepted.second.correctLocation(offset.getX(), offset.getY());
      accepted.second.inheret(accepted.first);
      currentTrackers[accepted.second.getIndex()] = tracker;

      // drop all remaining pairings which share a bullet with the accepted one
      Iterator<Pair<Bullet>> rest = open.iterator();
      while (rest.hasNext()) {
        Pair<Bullet> other = rest.next();
        if (other.first.equals(accepted.first)  ||  other.second.equals(accepted.second)) {
          rest.remove();
        }
      }
    }

    // step 3: bullets without identity are new
    for (Bullet bullet: currFrame.getBullets()) {
      if (bullet.getIdentity() != null) {
        continue;
      }
      bullet.setIdentity(bulletIdCounter++);

      boolean friendly = false;
      SpaceShip ship = prevFrame.getSpaceShip();
      if (ship != null) {
        Point delta = bullet.getDelta(ship.getNextLocation());
        double deltaLen = Tools.getLength(delta);
        if (deltaLen >= MIN_BULLET_DELTA  &&  deltaLen <= MAX_BULLET_DELTA) {
          // appeared at shooting distance from the ship: start with the expected shot velocity
          Point2D shotVelocity = prevFrame.getNextShootingDirection().getBulletVelocity();
          bullet.setVelocity(shotVelocity.getX() + ship.getVelocityX(),
                             shotVelocity.getY() + ship.getVelocityY());
          AverageVelocity tracker = new AverageBulletVelocity(bullet.getVelocity());
          currentTrackers[bullet.getIndex()] = tracker;
          friendly = true;
        }
      }
      bullet.setFriendly(friendly);
    }
  }

  /**
   * Prepare the space ship.
   * Called if there are at least two frames.
   * If the ship is present in both frames its movement is fed into the ship's
   * velocity tracker and the current ship inherits from the previous one and
   * receives the averaged velocity. Otherwise the tracker is discarded.
   *
   * @param frameInfos all frame infos
   * @param prevFrame  the second to last frame
   * @param currFrame  the last frame
   */
  @Override
  protected void prepareSpaceShip(LinkedList<FrameInfo> frameInfos, FrameInfo prevFrame, FrameInfo currFrame)
  {
    SpaceShip before = prevFrame.getSpaceShip();
    SpaceShip after = currFrame.getSpaceShip();
    if (before == null  ||  after == null) {
      // no continuous movement: start over when the ship is back
      shipVelocity = null;
      return;
    }

    Point movement = before.getDelta(after);
    if (shipVelocity != null) {
      shipVelocity.add(movement);
    }
    else {
      shipVelocity = new AverageVelocity(movement);
    }
    after.inheret(before);
    after.setVelocity(shipVelocity.getAverageX(), shipVelocity.getAverageY());
  }

  /**
   * Prepare the ufo.
   * Called if there are at least two frames.
   * Runs the inherited preparation first; afterwards, if both frames contain
   * an ufo, the current ufo inherits from the previous one.
   *
   * @param frameInfos all frame infos
   * @param prevFrame  the second to last frame
   * @param currFrame  the last frame
   */
  @Override
  protected void prepareUfo(LinkedList<FrameInfo> frameInfos, FrameInfo prevFrame, FrameInfo currFrame)
  {
    super.prepareUfo(frameInfos, prevFrame, currFrame);

    Ufo before = prevFrame.getUfo();
    if (before == null) {
      return;
    }
    Ufo after = currFrame.getUfo();
    if (after == null) {
      return;
    }
    after.inheret(before);
  }

  /**
   *  Get the average velocity for an asteroid.
   *  The lookup uses the frame's index modulo <code>MAX_FRAMES_KEPT</code>, so this
   *  only gives correct results for the last <code>MAX_FRAMES_KEPT</code> frames received.
   *  @param frameInfo frame info containing asteroid
   *  @param ast asteroid
   *  @return average velocity or <code>NULL_POINT</code> if no velocity tracker is
   *          stored for the asteroid (e.g. because the asteroid is new)
   */
  Point2D getAverageVelocity(FrameInfo frameInfo, Asteroid ast)
  {
    int phase = frameInfo.getIndex() % MAX_FRAMES_KEPT;
    AverageVelocity tracker = averageAsteroidVelocities[phase][ast.getIndex()];
    if (tracker == null) {
      return NULL_POINT;
    }
    return tracker.getAverage();
  }
}
