Skip to content
Snippets Groups Projects
normalgridviewfunction.hh 6.36 KiB
// -*- tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*-
// vi: set et ts=4 sw=2 sts=2:
#ifndef DUNE_CURVED_SURFACE_GRID_NORMAL_GRIDVIEWFUNCTION_HH
#define DUNE_CURVED_SURFACE_GRID_NORMAL_GRIDVIEWFUNCTION_HH

#include <array>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

#include <dune/common/fvector.hh>
#include <dune/functions/backends/istlvectorbackend.hh>
#include <dune/functions/functionspacebases/basistags.hh>
#include <dune/functions/functionspacebases/defaultglobalbasis.hh>
#include <dune/functions/functionspacebases/lagrangebasis.hh>
#include <dune/functions/functionspacebases/powerbasis.hh>
#include <dune/functions/gridfunctions/gridviewentityset.hh>
#include <dune/grid/utility/hierarchicsearch.hh>
#include <dune/istl/bvector.hh>


namespace Dune 
{
  //! Grid-view function representing averaged normal vector
  /**
   * \tparam GridView   The grid-view this grid-view-function is defined on
   * \tparam ORDER      Polynomial order of the lagrange bases used for representing the normals
   * \tparam T          Value type used for the basis and the coefficients
   **/
  template< class GridView, int ORDER = -1, class T = double >
  class NormalGridViewFunction
  {
    static auto makeBasis (const GridView& gridView, int order)
    {
      namespace BF = BasisFactory;
      return BF::makeBasis(gridView, BF::power<GridView::dimensionworld>(BF::lagrange<T>(order), BF::blockedInterleaved()));
    }

    using Basis = decltype(makeBasis(std::declval<GridView>(), ORDER));

  public:
    using EntitySet = GridViewEntitySet<GridView,0>;

    using Domain = typename EntitySet::GlobalCoordinate;
    using Range = FieldVector<T,GridView::dimensionworld>;

    using VectorType = BlockVector<Range>;

  private:
    class LocalFunction
    {
      using LocalView = typename Basis::LocalView;
      using LocalContext = typename LocalView::Element;

      using Domain = typename EntitySet::LocalCoordinate;
      using Range = typename NormalGridViewFunction::Range;

    public:
      LocalFunction (LocalView&& localView, const VectorType& normals)
        : localView_(std::move(localView))
        , normals_(normals)
      {}

      //! Collect the normal vector from all element DOFs into a local
      //! vector that can be accessed in the operator() for interpolation
      void bind (const LocalContext& element)
      {
        localView_.bind(element);

        const auto& leafNode = localView_.tree().child(0);
        localNormals_.resize(leafNode.size());
        // collect local normal vectors
        for (std::size_t i = 0; i < localNormals_.size(); ++i) {
          auto idx = localView_.index(leafNode.localIndex(i));
          localNormals_[i] = normals_[idx[0]];
        }

        bound_ = true;
      }

      void unbind ()
      {
        localView_.unbind();
        bound_ = false;
      }

      // evaluate normal vectors in local coordinate
      // by interpolation of stored local normals.
      Range operator() (const Domain& local) const 
      {
        assert(bound_);

        const auto& leafNode = localView_.tree().child(0);
        const auto& lfe = leafNode.finiteElement();

        // evaluate basis functions in local coordinate
        lfe.localBasis().evaluateFunction(local, shapeValues_);
        assert(localNormals_.size() == shapeValues_.size());

        Range n(0);
        for (std::size_t i = 0; i < localNormals_.size(); ++i)
          n.axpy(shapeValues_[i], localNormals_[i]);

        // return normalized vector
        return n / n.two_norm();
      }

    private:
      LocalView localView_;
      const VectorType& normals_;

      std::vector<Range> localNormals_;
      mutable std::vector<FieldVector<T,1>> shapeValues_;
      bool bound_ = false;
    };

  public:
    //! Constructor of the grid function. 
    /** 
     * Creates a global basis of a power of langrange nodes of given order.
     * The constructor argument `order` is defaulted to the class template parameter.
     **/
    NormalGridViewFunction (const GridView& gridView, int order = ORDER)
      : entitySet_(gridView)
      , basis_(makeBasis(gridView, order))
    {
      update(gridView);
    }

    //! Epdate the grid function. 
    /**
     * This calculates a mean average of normal vectors in the  DOFs of the basis. 
     * Those averages are stored normalized in the coefficients vector.
     **/
    void update (const GridView& gridView)
    {
      entitySet_ = EntitySet{gridView};
      basis_.update(gridView);

      normals_.resize(basis_.size());
      normals_ = 0;

      // compute normal vectors by mean averaging
      auto localView = basis_.localView();
      for (const auto& e : elements(basis_.gridView()))
      {
        localView.bind(e);
        auto geometry = e.geometry();

        const auto& leafNode = localView.tree().child(0);
        const auto& lfe = leafNode.finiteElement();

        // interpolate normal of geometry
        std::vector<Range> localNormals;
        lfe.localInterpolation().interpolate([&](const auto& local) -> Range {
          return Dune::normal(geometry, local);
        }, localNormals);

        // copy to global vector
        for (std::size_t i = 0; i < localNormals.size(); ++i) {
          auto idx = localView.index(leafNode.localIndex(i));
          normals_[idx[0]] += localNormals[i];
        }
      }

      // normalize vector
      for (std::size_t i = 0; i < normals_.size(); ++i)
        normals_[i] /= normals_[i].two_norm();
    }

    //! Evaluate normal vectors in global coordinates
    // NOTE: expensive
    Range operator() (const Domain& x) const 
    {
      using Grid = typename GridView::Grid;
      using IS = typename GridView::IndexSet;

      const auto& gv = basis_.gridView();
      HierarchicSearch<Grid,IS> hsearch{gv.grid(), gv.indexSet()};

      auto element = hsearch.findEntity(x);
      auto geometry = element.geometry();
      auto localFct = localFunction(*this);
      localFct.bind(element);
      return localFct(geometry.local(x));
    }

    //! Create a local function of this grifunction
    friend LocalFunction localFunction (const NormalGridViewFunction& gf)
    {
      return LocalFunction{gf.basis_.localView(), gf.normals_};
    }

    //! obtain the stored \ref GridViewEntitySet
    const EntitySet& entitySet () const
    {
      return entitySet_;
    }

  private:
    EntitySet entitySet_;
    Basis basis_;
    VectorType normals_;
  };

} // end namespace Dune 

#endif // DUNE_CURVED_SURFACE_GRID_NORMAL_GRIDVIEWFUNCTION_HH