// -*- tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*-
// vi: set et ts=4 sw=2 sts=2:
#ifndef DUNE_CURVED_SURFACE_GRID_INTERSECTION_HH
#define DUNE_CURVED_SURFACE_GRID_INTERSECTION_HH

#include <type_traits>
#include <utility>

#include <dune/common/fvector.hh>
#include <dune/common/std/optional.hh>
#include <dune/geometry/referenceelements.hh>

namespace Dune
{
  namespace CGeo
  {

    // Intersection
    // ------------

    /**
     * \brief Intersection of the curved grid, wrapping an intersection of the host grid.
     *
     * Topological queries (neighborhood, indices, boundary information, local
     * geometries) are forwarded to the host intersection. The world geometry and
     * the normals are computed from the curved, grid-function based geometry of
     * the inside element.
     *
     * \tparam Grid              The (possibly const) grid type providing `Traits`.
     * \tparam HostIntersection  The intersection type of the host grid.
     **/
    template< class Grid, class HostIntersection >
    class Intersection
    {
      using HostGeometry = typename HostIntersection::Geometry;
      using HostLocalGeometry = typename HostIntersection::LocalGeometry;

      using Traits = typename std::remove_const_t<Grid>::Traits;
      using GridFunction = typename Traits::GridFunction;

    public:
      using ctype = typename Traits::ctype;

      static const int dimension = Traits::dimension;
      static const int dimensionworld = Traits::dimensionworld;

      using Entity = typename Traits::template Codim<0>::Entity;
      using Geometry = typename Traits::template Codim<1>::Geometry;
      using LocalGeometry = typename Traits::template Codim<1>::LocalGeometry;
      using ElementGeometry = typename Traits::template Codim<0>::Geometry;

    private:
      using EntityImpl = typename Traits::template Codim<0>::EntityImpl;
      using GeometryImpl = typename Traits::template Codim<1>::GeometryImpl;

    public:
      //! Default-constructed intersection; holds no grid function and must not be queried.
      Intersection() = default;

      //! Wrap `hostIntersection`; `gridFunction` is stored by pointer and must outlive this object.
      Intersection (const HostIntersection& hostIntersection, const GridFunction& gridFunction)
        : hostIntersection_(hostIntersection)
        , gridFunction_(&gridFunction)
      {}

      //! Two intersections are equal if their host intersections compare equal.
      bool equals (const Intersection& other) const
      {
        return hostIntersection_ == other.hostIntersection_;
      }

      //! Validity as reported by the host intersection.
      explicit operator bool () const { return bool(hostIntersection_); }

      //! Element on the inside of the intersection (host inside element wrapped with the grid function).
      Entity inside () const
      {
        return EntityImpl(gridFunction(), hostIntersection().inside());
      }

      //! Element on the outside of the intersection; only meaningful if `neighbor()` is true.
      Entity outside () const
      {
        return EntityImpl(gridFunction(), hostIntersection().outside());
      }

      bool boundary () const { return hostIntersection().boundary(); }

      bool conforming () const { return hostIntersection().conforming(); }

      bool neighbor () const { return hostIntersection().neighbor(); }

      size_t boundarySegmentIndex () const
      {
        return hostIntersection().boundarySegmentIndex();
      }

      //! Geometry of the intersection in the local coordinates of the inside element (from the host).
      LocalGeometry geometryInInside () const
      {
        return hostIntersection().geometryInInside();
      }

      //! Geometry of the intersection in the local coordinates of the outside element (from the host).
      LocalGeometry geometryInOutside () const
      {
        return hostIntersection().geometryInOutside();
      }

      //! Construct the curved world geometry of the intersection.
      /**
       * The geometry is obtained by binding the local grid function to the inside
       * element and restricting it to the intersection via `geometryInInside()`.
       * It is built lazily on the first call and cached afterwards.
       *
       * This only works correctly if the intersection is a complete subentity of
       * both the inside and the outside element, and if the trace of the local
       * basis functions on that subentity is again a local basis of codim 1.
       **/
      Geometry geometry () const
      {
        if (!geo_)
        {
          auto localFct = localFunction(gridFunction());
          localFct.bind(hostIntersection().inside());
          geo_.emplace(type(), localFct, hostIntersection().geometryInInside());
        }
        return Geometry(*geo_);
      }

      GeometryType type () const { return hostIntersection().type(); }

      int indexInInside () const
      {
        return hostIntersection().indexInInside();
      }

      int indexInOutside () const
      {
        return hostIntersection().indexInOutside();
      }

      //! Outer normal at `local` (intersection coordinates); its length is not normalized.
      FieldVector<ctype, dimensionworld> outerNormal (const FieldVector<ctype, dimension-1>& local) const
      {
        return outerNormalImpl(local, false);
      }

      //! Outer normal at `local` scaled to the integration element of the intersection.
      FieldVector<ctype, dimensionworld> integrationOuterNormal (const FieldVector<ctype, dimension-1>& local) const
      {
        return outerNormalImpl(local, true);
      }

      //! Outer normal at `local` normalized to unit Euclidean length.
      FieldVector<ctype, dimensionworld> unitOuterNormal (const FieldVector<ctype, dimension-1>& local) const
      {
        FieldVector<ctype, dimensionworld> normal = outerNormal(local);
        return normal /= normal.two_norm();
      }

      //! Unit outer normal evaluated at the center of the reference face.
      FieldVector<ctype, dimensionworld> centerUnitOuterNormal () const
      {
        auto refFace = referenceElement<ctype, dimension-1>(type());
        return unitOuterNormal(refFace.position(0, 0));
      }

      const HostIntersection& hostIntersection () const
      {
        return hostIntersection_;
      }

      const GridFunction& gridFunction () const { return *gridFunction_; }

    private:
      //! Compute the outer normal at the intersection-local coordinate `local`.
      /**
       * The reference-element integration outer normal of face `indexInInside()` is
       * mapped with the jacobianInverseTransposed of the (cached) inside element
       * geometry, evaluated at the corresponding point of the inside element.
       *
       * \param local                      Local coordinate on the intersection.
       * \param scaleByIntegrationElement  If true, additionally scale the normal so that
       *                                   its length is the integration element of the
       *                                   intersection (see `integrationOuterNormal()`).
       **/
      FieldVector<ctype, dimensionworld>
      outerNormalImpl (const FieldVector<ctype, dimension-1>& local, bool scaleByIntegrationElement) const
      {
        if (!insideGeo_)
          insideGeo_.emplace(inside().impl().geometry());

        const LocalGeometry geoInInside = geometryInInside();
        const int idxInInside = indexInInside();
        auto refElement = referenceElement<ctype, dimension>(insideGeo_->type());

        const FieldVector<ctype, dimension> x(geoInInside.global(local));
        const auto& jit = insideGeo_->jacobianInverseTransposed(x);
        const FieldVector<ctype, dimension> refNormal = refElement.integrationOuterNormal(idxInInside);

        FieldVector<ctype, dimensionworld> normal;
        jit.mv(refNormal, normal);

        if (scaleByIntegrationElement) {
          // The reference integration outer normal has the size of the whole reference face;
          // a non-conforming intersection only covers part of it.
          if (!conforming())
            normal *= geoInInside.volume() / refElement.template geometry<1>(idxInInside).volume();

          // J^{-T} n_ref does not carry the volume change of the element mapping. Multiply by
          // the integration element of the inside geometry (cofactor form det(J) J^{-T} n_ref),
          // as GeometryGrid does via jit.detInv() / integrationElement(x).
          normal *= insideGeo_->integrationElement(x);
        }

        return normal;
      }

    private:
      HostIntersection hostIntersection_;
      const GridFunction* gridFunction_ = nullptr;

      // geometry caches, filled lazily by outerNormalImpl() and geometry()
      mutable Std::optional<ElementGeometry> insideGeo_;
      mutable Std::optional<GeometryImpl> geo_;
    };

  } // namespace CGeo
} // namespace Dune

#endif // DUNE_CURVED_SURFACE_GRID_INTERSECTION_HH