// -*- tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*-
// vi: set et ts=4 sw=2 sts=2:
#ifndef DUNE_CURVED_SURFACE_GRID_SPHERE_GRIDFUNCTION_HH
#define DUNE_CURVED_SURFACE_GRID_SPHERE_GRIDFUNCTION_HH

#include <type_traits>

#include <dune/common/math.hh>
#include <dune/functions/common/defaultderivativetraits.hh>

#include "analyticgridfunction.hh"

namespace Dune
{
  // Ellipsoid functor
  template< class T >
  class EllipsoidProjection
  {
    T a_;
    T b_;
    T c_;

  public:
    //! Constructor of ellipsoid by major axes
    EllipsoidProjection (T a, T b, T c)
      : a_(a)
      , b_(b)
      , c_(c)
    {}

    //! project the coordinate to the ellipsoid
    // NOTE: This is not a closes-point projection, but a spherical-coordinate projection
    template< class Domain >
    Domain operator() (const Domain& X) const
    {
      using std::sin; using std::cos;
      auto [phi,theta] = angles(X);
      return {a_*cos(phi)*sin(theta), b_*sin(phi)*sin(theta), c_*cos(theta)};
    }

    //! derivative of the projection
    friend auto derivative (const EllipsoidProjection& ellipsoid)
    {
      return [a=ellipsoid.a_,b=ellipsoid.b_,c=ellipsoid.c_](auto const& X)
      {
        using std::sqrt;
        using Domain = std::decay_t<decltype(X)>;
        using DerivativeTraits = Functions::DefaultDerivativeTraits<Domain(Domain)>;
        typename DerivativeTraits::Range out;

        T x = X[0], y = X[1], z = X[2];
        T x2 = x*x, y2 = y*y, z2 = z*z;
        T x5 = x2*x2*x;

        T nrm0 = x2 + y2;
        T nrm1 = x2 + y2 + z2;
        T nrm2 = sqrt(nrm0/nrm1);
        T nrm3 = sqrt(nrm0/x2);
        T nrm4 = sqrt(nrm1)*nrm1;
        T nrm5 = sqrt(nrm0)*nrm4;

        return {
          {
             a*x*nrm3*(y2 + z2)/nrm5 ,
            -b*y*nrm0/(nrm3*nrm5) ,
            -c*z*x/nrm4
          },
          {
            -a*x2*y*nrm3/nrm5 ,
             b*nrm0*nrm0*nrm0*(x2 + z2)/(x5*power(nrm3, 5)*nrm5) ,
            -c*y*z/nrm4
          },
          {
            -a*z*nrm0/(nrm3*nrm5) ,
            -b*y*z*nrm0/(x*nrm3*nrm5) ,
             c*(x2 + y2)/nrm4
          }
        };
      };
    }

    //! Normal vector
    template< class Domain >
    Domain normal (const Domain& X) const
    {
      using std::sqrt;
      T x = X[0], y = X[1], z = X[2];
      T a2 = a_*a_, b2 = b_*b_, c2 = c_*c_;

      auto div = sqrt(b2*b2*c2*c2*x*x + a2*a2*c2*c2*y*y + a2*a2*b2*b2*z*z);
      return {b2*c2*x/div, a2*c2*y/div, a2*b2*z/div};
    }
    
    //! Mean curvature
    template< class Domain >
    T mean_curvature (const Domain& X) const
    {
      using std::sqrt; using std::abs;
      T x = X[0], y = X[1], z = X[2];
      T a2 = a_*a_, b2 = b_*b_, c2 = c_*c_;

      auto div = 2*a2*b2*c2*power(sqrt(x*x/(a2*a2) + y*y/(b2*b2) + z*z/(c2*c2)), 3);
      return abs(x*x + y*y + z*z - a2 - b2 - c2)/div;
    }

    //! Gaussian curvature
    template< class Domain >
    T gauss_curvature (const Domain& X) const 
    {
      T x = X[0], y = X[1], z = X[2];
      T a2 = a_*a_, b2 = b_*b_, c2 = c_*c_;

      auto div = a2*b2*c2*power(x*x/(a2*a2) + y*y/(b2*b2) + z*z/(c2*c2), 2);
      return T(1)/div;
    }

  private:
    FieldVector<T,2> angles (Domain x) const 
    {
      using std::acos; using std::atan2;
      x /= x.two_norm();

      return {atan2(x[1], x[0]), acos(x[2])};
    }
  };

  //! Construct a grid function representing the ellipsoid parametrization
  //! EllipsoidProjection<T> with semi-axes a (x), b (y) and c (z).
  //! The projection is wrapped via analyticGridFunction<Grid>.
  template< class Grid, class T >
  auto ellipsoidGridFunction (T a, T b, T c)
  {
    return analyticGridFunction<Grid>(EllipsoidProjection<T>{a,b,c});
  }

} // end namespace Dune

#endif // DUNE_CURVED_SURFACE_GRID_SPHERE_GRIDFUNCTION_HH