#include <esp_log.h>
#include <esp_timer.h>
#include <math.h>

#include "MadgwickAHRS.h"
#include "config.pb.h"
#include "i2c_mutex.h"
#include "ugv_comms.hh"
#include "ugv_display.hh"
#include "ugv_io.hh"
#include "pid_controller.hh"

namespace ugv {

// Short aliases for the subsystem types used throughout this file.
using ugv::comms::CommsClass;
using ugv::comms::messages::UGV_State;
using ugv::io::IOClass;

// Log tag for every ESP_LOG* call in this file.
static const char *TAG = "ugv_main";

// Definition (with C linkage) of the global I2C bus mutex; presumably declared
// in i2c_mutex.h so C drivers can share it. It is created in State::Init().
extern "C" {
SemaphoreHandle_t i2c_mutex;
}

// Main control loop period in microseconds (100 Hz => 10000 us per tick).
constexpr uint64_t LOOP_PERIOD_US = 1000000 / 100;
// Main control loop rate in Hz, derived from LOOP_PERIOD_US (== 100).
constexpr float LOOP_FREQUENCY_HZ =
    1000000.f / static_cast<float>(LOOP_PERIOD_US);
// NOTE: despite its name, LOOP_PERIOD_S holds the loop *rate* in Hz (100),
// not the period in seconds (0.01). It keeps its historical value because
// existing users pass it as a sample frequency (e.g. Madgwick::begin). Prefer
// LOOP_FREQUENCY_HZ in new code.
constexpr float LOOP_PERIOD_S = LOOP_FREQUENCY_HZ;
constexpr float PI =
    3.1415926535897932384626433832795028841971693993751058209749445923078164062;

// Degrees -> radians conversion factor.
constexpr float RAD_PER_DEG = PI / 180.f;
// Radius of earth in meters (sphere radius used for great-circle distance).
constexpr float EARTH_RAD = 6372795.f;

// esp_timer callback for the main loop; defined after State.
extern "C" void OnTimeout(void *arg);

/**
 * Copy the most recent GPS reading into the Location message reported over
 * comms.
 *
 * @param location  message whose latitude, longitude, altitude and fix quality
 *                  are overwritten
 * @param gps_data  latest GPS reading from the IO subsystem
 */
void UpdateLocationFromGPS(comms::messages::Location &location,
                           const io::GpsData &        gps_data) {
  const io::GpsData &gps = gps_data;
  location.set_latitude(gps.latitude);
  location.set_longitude(gps.longitude);
  location.set_altitude(gps.altitude);
  location.set_fix_quality(gps.fix_quality);
}

/**
 * A geographic position in decimal degrees.
 *
 * Coordinates are stored as float, which limits resolution to roughly a metre
 * at typical latitude/longitude magnitudes.
 */
struct LatLong {
 public:
  float latitude;   // degrees
  float longitude;  // degrees

  inline LatLong() : LatLong(0., 0.) {}

  inline LatLong(double latitude_, double longitude_)
      : latitude(latitude_), longitude(longitude_) {}

  // Intentionally implicit: lets a received TargetLocation be passed directly
  // wherever a LatLong is expected (e.g. State::SetTarget).
  inline LatLong(const comms::messages::TargetLocation &loc)
      : latitude(loc.latitude()), longitude(loc.longitude()) {}

  /**
   * Great-circle (haversine) distance from this point to `target`.
   *
   * @return distance in meters on a sphere of radius EARTH_RAD.
   */
  float distance_to(const LatLong &target) const {
    float lat1      = latitude * RAD_PER_DEG;
    float lat2      = target.latitude * RAD_PER_DEG;
    float half_dlat = sinf((lat2 - lat1) / 2.f);
    float half_dlon = sinf((target.longitude - longitude) * RAD_PER_DEG / 2.f);
    float a         = half_dlat * half_dlat +
              cosf(lat1) * cosf(lat2) * half_dlon * half_dlon;
    // Float rounding can push `a` just outside [0, 1] (notably above 1 for
    // near-antipodal points), which would make sqrtf(1 - a) NaN. Plain
    // comparisons are used instead of fminf/fmaxf so that a NaN coming from
    // bad input still propagates rather than silently reading as 0 m.
    if (a > 1.f) {
      a = 1.f;
    } else if (a < 0.f) {
      a = 0.f;
    }
    float d_over_r = 2.f * atan2f(sqrtf(a), sqrtf(1.f - a));
    return d_over_r * EARTH_RAD;
  }

  /**
   * Initial great-circle bearing from this point toward `target`.
   *
   * @return bearing in degrees clockwise from north, in [0, 360).
   */
  float bearing_toward(const LatLong &target) const {
    float dlong  = (target.longitude - longitude) * RAD_PER_DEG;
    float lat1   = latitude * RAD_PER_DEG;
    float lat2   = target.latitude * RAD_PER_DEG;
    float clat2  = cosf(lat2);
    float num    = sinf(dlong) * clat2;
    float denom  = cosf(lat1) * sinf(lat2) - sinf(lat1) * clat2 * cosf(dlong);
    float course = atan2f(num, denom) / RAD_PER_DEG;
    if (course < 0.f) {
      course += 360.f;
    }
    // A tiny negative course can round up to exactly 360 after the shift;
    // keep the result in the documented half-open range.
    if (course >= 360.f) {
      course -= 360.f;
    }
    return course;
  }
};


struct State {
 public:
  CommsClass *       comms;
  IOClass *          io;
  DisplayClass *     display;
  esp_timer_handle_t timer_handle;

  io::Inputs     inputs_;
  io::Outputs    outputs_;
  int64_t        last_print_;
  Madgwick       ahrs_;
  LatLong        target_;
  PIDController  angle_controller_;
  config::Config conf_;

  State() : angle_controller_(LOOP_PERIOD_S) {
    SetTarget({34.069022, -118.443067});

    comms   = new CommsClass();
    io      = new IOClass();
    display = new DisplayClass(comms);

    SetConfig(DefaultConfig());
  }

  static config::Config DefaultConfig() {
    config::Config c;

    auto *apid = c.mutable_angle_pid();
    apid->set_kp(0.10);
    apid->set_ki(0.0);
    apid->set_kd(0.4);
    apid->set_max_output(0.5);
    apid->set_max_i_error(15.0);

    c.set_min_target_dist(10.0);
    c.set_min_flip_pitch(90.0);
    return c;
  }

  void SetConfig(const config::Config &conf) {
    auto &apid = conf.angle_pid();
    angle_controller_.SetPID(apid.kp(), apid.ki(), apid.kd());
    angle_controller_.MaxOutput(apid.max_output());
    angle_controller_.MaxIError(apid.max_i_error());
    conf_ = conf;
  }

  void SetTarget(LatLong target) { target_ = target; }

  void Init() {
    esp_timer_init();
    i2c_mutex = xSemaphoreCreateMutex();

    ahrs_.begin(LOOP_PERIOD_S);  // rough sample frequency

    io->Init();
    comms->Init();
    display->Init();

    esp_timer_create_args_t timer_args;
    timer_args.callback        = OnTimeout;
    timer_args.arg             = this;
    timer_args.dispatch_method = ESP_TIMER_TASK;
    timer_args.name            = "ugv_main_loop";
    esp_timer_create(&timer_args, &this->timer_handle);
    esp_timer_start_periodic(timer_handle, LOOP_PERIOD_US);
    last_print_ = 0;
  }

  void OnTick() {
    ESP_LOGV(TAG, "OnTick");
    int64_t time_us = esp_timer_get_time();
    // float   time_s  = ((float)time_us) / 1e6;
    io->ReadInputs(inputs_);
    {
      io::Vec3f &g = inputs_.mpu.gyro_rate, &a = inputs_.mpu.accel,
                &m = inputs_.mpu.mag;
      ahrs_.update(g.x, g.y, g.z, a.x, a.y, a.z, m.x, m.y, m.z);
    }
    if (time_us >= last_print_ + 500 * 1000) {  // 1s
      auto &mpu = inputs_.mpu;
      ESP_LOGD(
          TAG, "inputs: acc=(%f, %f, %f) gyro=(%f, %f, %f) mag=(%f, %f, %f)",
          mpu.accel.x, mpu.accel.y, mpu.accel.z, mpu.gyro_rate.x,
          mpu.gyro_rate.y, mpu.gyro_rate.z, mpu.mag.x, mpu.mag.y, mpu.mag.z);
      ESP_LOGD(TAG, "ahrs: yaw=%f, pitch=%f, roll=%f", ahrs_.getYaw(),
               ahrs_.getPitch(), ahrs_.getRoll());
      ESP_LOGD(TAG, "PID: error: %f", angle_controller_.Error());
      last_print_ = time_us;
    }

    comms->Lock();
    UpdateLocationFromGPS(*(comms->status.mutable_location()), inputs_.gps);
    comms->status.set_yaw_angle(ahrs_.getYaw());
    UGV_State ugv_state = comms->status.state();
    if (comms->new_target) {
      SetTarget(*comms->new_target);
      ESP_LOGI(TAG, "Updating target to (%f, %f)", target_.latitude,
               target_.longitude);
      delete comms->new_target;
      comms->new_target = nullptr;
    }
    if (comms->new_config) {
      ESP_LOGI(TAG, "Updating config");
      SetConfig(*comms->new_config);
      delete comms->new_config;
      comms->new_config = nullptr;
    }
    comms->Unlock();
    UGV_State next_state = ugv_state;

    angle_controller_.Input(ahrs_.getYaw());
    float drive_power    = 0.;
    outputs_.left_motor  = 0.0;
    outputs_.right_motor = 0.0;

    float pitch = ahrs_.getPitch();

    auto min_flip_pitch = conf_.min_flip_pitch();
    bool is_upside_down = (pitch > min_flip_pitch) || (pitch < -min_flip_pitch);

    switch (ugv_state) {
      default:
        ESP_LOGW(TAG, "unhandled state: %d", ugv_state);
        // fall through
      case UGV_State::STATE_IDLE:
      case UGV_State::STATE_FINISHED: angle_controller_.Disable(); break;
      case UGV_State::STATE_AQUIRING: {
        if (is_upside_down) {
          next_state = UGV_State::STATE_FLIPPING;
          break;
        }
        angle_controller_.Disable();
        TickType_t current_tick    = xTaskGetTickCount();
        TickType_t ticks_since_gps = current_tick - inputs_.gps.last_update;
        bool       not_old         = ticks_since_gps <= pdMS_TO_TICKS(2000);
        bool       not_invalid = inputs_.gps.fix_quality != io::GPS_FIX_INVALID;
        if (not_old && not_invalid) {
          next_state = UGV_State::STATE_TURNING;
        }
        break;
      }
      case UGV_State::STATE_FLIPPING: {
        angle_controller_.Disable();
        outputs_.left_motor  = -1.0;
        outputs_.right_motor = -1.0;
        if (!is_upside_down) {
          next_state = UGV_State::STATE_AQUIRING;
          break;
        }
        break;
      }
      case UGV_State::STATE_TURNING: {
        if (is_upside_down) {
          next_state = UGV_State::STATE_FLIPPING;
          break;
        }
        if (inputs_.gps.fix_quality == io::GPS_FIX_INVALID) {
          next_state = UGV_State::STATE_AQUIRING;
          break;
        }

        LatLong current_pos = {inputs_.gps.latitude, inputs_.gps.longitude};
        float   tgt_bearing = current_pos.bearing_toward(target_);
        angle_controller_.Enable();
        angle_controller_.Setpoint(tgt_bearing);

        if (fabs(angle_controller_.Error()) <= 5.0) {
          next_state = UGV_State::STATE_DRIVING;
        }
        break;
      }
      case UGV_State::STATE_DRIVING: {
        if (is_upside_down) {
          next_state = UGV_State::STATE_FLIPPING;
          break;
        }
        if (inputs_.gps.fix_quality == io::GPS_FIX_INVALID) {
          next_state = UGV_State::STATE_AQUIRING;
          break;
        }

        LatLong current_pos = {inputs_.gps.latitude, inputs_.gps.longitude};
        float   tgt_dist    = current_pos.distance_to(target_);

        if (tgt_dist <= conf_.min_target_dist()) {
          ESP_LOGI(TAG, "Finished driving to target");
          next_state = UGV_State::STATE_FINISHED;
          break;
        }

        float tgt_bearing = current_pos.bearing_toward(target_);
        angle_controller_.Enable();
        angle_controller_.Setpoint(tgt_bearing);
        break;
      }
      case UGV_State::STATE_TEST:
#ifdef BASIC_TEST
        outputs.left_motor  = sinf(time_s * PI);
        outputs.right_motor = cosf(time_s * PI);
#else
        angle_controller_.Enable();
        angle_controller_.Setpoint(90.0);
#endif
        break;
      case UGV_State::STATE_DRIVE_HEADING:
        angle_controller_.Enable();
        angle_controller_.Setpoint(comms->drive_heading.heading());
        drive_power = comms->drive_heading.power();
        break;
    }

    if (angle_controller_.Enabled()) {
      float angle_pwr      = angle_controller_.Update();
      outputs_.left_motor  = drive_power - angle_pwr;
      outputs_.right_motor = drive_power + angle_pwr;
    }

    io->WriteOutputs(outputs_);

    comms->Lock();
    comms->status.set_state(next_state);
    comms->Unlock();
  }
};

extern "C" void OnTimeout(void *arg) {
  State *state = (State *)arg;
  state->OnTick();
}

State *state;

void Setup(void) {
  ESP_LOGI(TAG, "Starting UAS UGV");
  state = new State();
  state->Init();
  ESP_LOGI(TAG, "Setup finished");
}

}  // namespace ugv

// ESP-IDF application entry point: hand off to the UGV setup routine.
extern "C" void app_main() {
  ugv::Setup();
}