#include <esp_err.h>
#include <esp_log.h>
#include <esp_timer.h>
#include "MadgwickAHRS.h"
#include "i2c_mutex.h"
#include "ugv_comms.hh"
#include "ugv_display.hh"
#include "ugv_io.hh"

#include <math.h>

namespace ugv {

using ugv::comms::CommsClass;
using ugv::comms::messages::UGV_State;
using ugv::io::IOClass;

static const char *TAG = "ugv_main";

// Shared mutex, created in State::Init(). It has C linkage so that non-C++
// code can refer to it by name (presumably declared in i2c_mutex.h and used
// to serialize I2C bus access -- confirm).
extern "C" {
SemaphoreHandle_t i2c_mutex;
}

// Period of the main control loop in microseconds (100 Hz).
constexpr uint64_t LOOP_PERIOD_US = 1000000 / 100;
constexpr float    PI             = 3.14159265358979323846f;

// esp_timer callback that forwards to State::OnTick(); defined further down.
extern "C" void OnTimeout(void *arg);

/**
 * Copy the latest GPS reading into the outgoing location message.
 *
 * @param location  message to overwrite (latitude, longitude, altitude and
 *                  fix quality are all set)
 * @param gps_data  most recent GPS reading
 *
 * The caller visible in this file (State::OnTick) holds comms->Lock() around
 * this call.
 */
void UpdateLocationFromGPS(comms::messages::Location &location,
                           const io::GpsData &        gps_data) {
  location.set_latitude(gps_data.latitude);
  location.set_longitude(gps_data.longitude);
  location.set_altitude(gps_data.altitude);
  location.set_fix_quality(gps_data.fix_quality);
}

// Multiply a value in degrees by this to get radians.
static const float RAD_PER_DEG = PI / 180.f;
// Radius of the earth in meters; scales LatLong::distance_to's result.
static const float EARTH_RAD = 6372795.f;

// Base motor power applied to both sides while driving to the target.
static const float DRIVE_POWER = 0.5f;
// Proportional steering gain: differential motor power per degree of
// heading error.
static const float ANGLE_P = 0.02f;
// Once within this many meters of the target, driving is finished.
static const float MIN_DIST = 10.0f;

/**
 * A geographic position in decimal degrees.
 *
 * Coordinates are stored as float, so positions are quantized to roughly a
 * meter or less at typical latitudes/longitudes.
 */
struct LatLong {
 public:
  float latitude;   // degrees
  float longitude;  // degrees

  inline LatLong(double latitude_, double longitude_)
      : latitude(latitude_), longitude(longitude_) {}

  /**
   * Return distance from this LatLong to target, in meters.
   *
   * Great-circle distance via the haversine formula on a sphere of radius
   * EARTH_RAD.
   */
  float distance_to(const LatLong &target) const {
    float lat1  = latitude * RAD_PER_DEG;
    float lat2  = target.latitude * RAD_PER_DEG;
    float dlat  = lat2 - lat1;
    // Subtract in degrees first to keep precision for nearby points.
    float dlong = (target.longitude - longitude) * RAD_PER_DEG;

    float sin_half_dlat  = sinf(dlat / 2.f);
    float sin_half_dlong = sinf(dlong / 2.f);
    float a = sin_half_dlat * sin_half_dlat +
              cosf(lat1) * cosf(lat2) * sin_half_dlong * sin_half_dlong;
    // Float rounding can push `a` slightly outside [0, 1] (e.g. at the poles
    // or for near-antipodal points). That would make one of the square roots
    // below NaN.
    a = fminf(fmaxf(a, 0.f), 1.f);

    float d_over_r = 2.f * atan2f(sqrtf(a), sqrtf(1.f - a));
    return d_over_r * EARTH_RAD;
  }

  /**
   * Initial great-circle bearing from this point toward `target`, in degrees
   * measured clockwise from north, normalized to [0, 360].
   */
  float bearing_toward(const LatLong &target) const {
    float dlong  = (target.longitude - longitude) * RAD_PER_DEG;
    float sdlong = sinf(dlong);
    float cdlong = cosf(dlong);
    float lat1   = latitude * RAD_PER_DEG;
    float lat2   = target.latitude * RAD_PER_DEG;
    float slat1  = sinf(lat1);
    float clat1  = cosf(lat1);
    float slat2  = sinf(lat2);
    float clat2  = cosf(lat2);
    float num    = sdlong * clat2;                             // east component
    float denom  = (clat1 * slat2) - (slat1 * clat2 * cdlong);  // north component
    float course = atan2f(num, denom);                           // [-pi, pi]
    if (course < 0.f) {
      course += 2.f * PI;
    }
    return course / RAD_PER_DEG;
  }
};

struct State {
 public:
  CommsClass *       comms;
  IOClass *          io;
  DisplayClass *     display;
  esp_timer_handle_t timer_handle;
  io::Inputs         inputs;
  io::Outputs        outputs;
  int64_t            last_print;
  Madgwick           ahrs_;
  LatLong            target;

  State() : target{34.069022, -118.443067} {
    comms   = new CommsClass();
    io      = new IOClass();
    display = new DisplayClass(comms);
  }

  void Init() {
    esp_timer_init();
    i2c_mutex = xSemaphoreCreateMutex();

    ahrs_.begin(1000000.f /
                static_cast<float>(LOOP_PERIOD_US));  // rough sample frequency

    io->Init();
    comms->Init();
    display->Init();

    esp_timer_create_args_t timer_args;
    timer_args.callback        = OnTimeout;
    timer_args.arg             = this;
    timer_args.dispatch_method = ESP_TIMER_TASK;
    timer_args.name            = "ugv_main_loop";
    esp_timer_create(&timer_args, &this->timer_handle);
    esp_timer_start_periodic(timer_handle, LOOP_PERIOD_US);
    last_print = 0;
  }

  void OnTick() {
    ESP_LOGV(TAG, "OnTick");
    int64_t time_us = esp_timer_get_time();
    float   time_s  = ((float)time_us) / 1e6;
    io->ReadInputs(inputs);
    {
      io::Vec3f &g = inputs.mpu.gyro_rate, &a = inputs.mpu.accel,
                &m = inputs.mpu.mag;
      ahrs_.update(g.x, g.y, g.z, a.x, a.y, a.z, m.x, m.y, m.z);
    }
    if (time_us >= last_print + 500 * 1000) {  // 1s
      ESP_LOGD(TAG,
               "inputs: acc=(%f, %f, %f) gyro=(%f, %f, %f) mag=(%f, %f, %f)",
               inputs.mpu.accel.x, inputs.mpu.accel.y, inputs.mpu.accel.z,
               inputs.mpu.gyro_rate.x, inputs.mpu.gyro_rate.y,
               inputs.mpu.gyro_rate.z, inputs.mpu.mag.x, inputs.mpu.mag.y,
               inputs.mpu.mag.z);
      ESP_LOGD(TAG, "ahrs: yaw=%f, pitch=%f, roll=%f", ahrs_.getYaw(),
               ahrs_.getPitch(), ahrs_.getRoll());
      last_print = time_us;
    }

    comms->Lock();
    UpdateLocationFromGPS(comms->location, inputs.gps);
    UGV_State ugv_state = comms->ugv_state;
    comms->Unlock();

    switch (ugv_state) {
      default:
        ESP_LOGW(TAG, "unhandled state: %d", ugv_state);
        // fall through
      case UGV_State::STATE_IDLE:
      case UGV_State::STATE_FINISHED:
        outputs.left_motor  = 0.0;
        outputs.right_motor = 0.0;
        break;
      case UGV_State::STATE_AQUIRING: {
        TickType_t current_tick    = xTaskGetTickCount();
        TickType_t ticks_since_gps = current_tick - inputs.gps.last_update;
        bool       not_old         = ticks_since_gps <= pdMS_TO_TICKS(2000);
        bool       not_invalid = inputs.gps.fix_quality != io::GPS_FIX_INVALID;
        outputs.left_motor     = 0.0;
        outputs.right_motor    = 0.0;
        if (not_old && not_invalid) {
          comms->ugv_state = UGV_State::STATE_DRIVING;
        }
        break;
      }
      case UGV_State::STATE_DRIVING: {
        LatLong current_pos = {inputs.gps.latitude, inputs.gps.longitude};
        float   tgt_dist    = current_pos.distance_to(target);

        if (tgt_dist <= MIN_DIST) {
          ESP_LOGI(TAG, "Finished driving to target");
          comms->ugv_state = UGV_State::STATE_FINISHED;
          break;
        }

        float tgt_bearing = current_pos.bearing_toward(target);
        float cur_bearing = ahrs_.getYaw();
        float angle_delta = tgt_bearing - cur_bearing;
        if (angle_delta < 180.f) angle_delta += 360.f;
        if (angle_delta > 180.f) angle_delta -= 360.f;
        float angle_pwr = angle_delta * ANGLE_P;

        outputs.left_motor  = DRIVE_POWER + angle_pwr;
        outputs.right_motor = DRIVE_POWER - angle_pwr;
        break;
      }
      case UGV_State::STATE_TEST:
        outputs.left_motor  = sinf(time_s * PI);
        outputs.right_motor = cosf(time_s * PI);
        break;
    }
    io->WriteOutputs(outputs);
  }
};

extern "C" void OnTimeout(void *arg) {
  State *state = (State *)arg;
  state->OnTick();
}

// Controller instance created by Setup(); null until then.
State *state = nullptr;

/** Create the controller, then initialize it (which starts the loop timer). */
void Setup(void) {
  ESP_LOGI(TAG, "Starting UAS UGV");
  state = new State;
  state->Init();
  ESP_LOGI(TAG, "Setup finished");
}

}  // namespace ugv

// ESP-IDF application entry point; all setup is delegated to ugv::Setup().
extern "C" void app_main(void) {
  ugv::Setup();
}