Commit bab2cc12 authored by Praetorius, Simon's avatar Praetorius, Simon
Browse files

Merge branch 'feature/function_wrapper' into 'master'

Cleanup of Function and LocalFunction interface

See merge request extensions/dune-vtk!12
parents 96f8bfb5 801badd8
......@@ -11,13 +11,13 @@ std::vector<T> DataCollectorInterface<GV,D,P>
::cellDataImpl (VtkFunction const& fct) const
{
std::vector<T> data;
data.reserve(this->numCells() * fct.ncomps());
data.reserve(this->numCells() * fct.numComponents());
auto localFct = localFunction(fct);
for (auto const& e : elements(gridView_, partition)) {
localFct.bind(e);
auto refElem = referenceElement<T,dim>(e.type());
for (int comp = 0; comp < fct.ncomps(); ++comp)
for (int comp = 0; comp < fct.numComponents(); ++comp)
data.emplace_back(localFct.evaluate(comp, refElem.position(0,0)));
localFct.unbind();
}
......
......@@ -119,7 +119,7 @@ namespace Dune
template <class T, class GlobalFunction>
std::vector<T> pointDataImpl (GlobalFunction const& fct) const
{
std::vector<T> data(numPoints_ * fct.ncomps());
std::vector<T> data(numPoints_ * fct.numComponents());
auto const& indexSet = gridView_.indexSet();
auto localFct = localFunction(fct);
for (auto const& e : elements(gridView_, partition)) {
......@@ -127,8 +127,8 @@ namespace Dune
Vtk::CellType cellType{e.type()};
auto refElem = referenceElement(e.geometry());
for (unsigned int j = 0; j < e.subEntities(dim); ++j) {
std::size_t idx = fct.ncomps() * indexMap_[indexSet.subIndex(e,cellType.permutation(j),dim)];
for (int comp = 0; comp < fct.ncomps(); ++comp)
std::size_t idx = fct.numComponents() * indexMap_[indexSet.subIndex(e,cellType.permutation(j),dim)];
for (int comp = 0; comp < fct.numComponents(); ++comp)
data[idx + comp] = T(localFct.evaluate(comp, refElem.position(cellType.permutation(j),dim)));
}
localFct.unbind();
......
......@@ -102,7 +102,7 @@ namespace Dune
template <class T, class GlobalFunction>
std::vector<T> pointDataImpl (GlobalFunction const& fct) const
{
std::vector<T> data(numPoints_ * fct.ncomps());
std::vector<T> data(numPoints_ * fct.numComponents());
auto const& indexSet = gridView_.indexSet();
auto localFct = localFunction(fct);
for (auto const& e : elements(gridView_, partition)) {
......@@ -110,8 +110,8 @@ namespace Dune
Vtk::CellType cellType{e.type()};
auto refElem = referenceElement(e.geometry());
for (unsigned int j = 0; j < e.subEntities(dim); ++j) {
std::size_t idx = fct.ncomps() * indexMap_[indexSet.subIndex(e, cellType.permutation(j), dim)];
for (int comp = 0; comp < fct.ncomps(); ++comp)
std::size_t idx = fct.numComponents() * indexMap_[indexSet.subIndex(e, cellType.permutation(j), dim)];
for (int comp = 0; comp < fct.numComponents(); ++comp)
data[idx + comp] = T(localFct.evaluate(comp, refElem.position(cellType.permutation(j),dim)));
}
localFct.unbind();
......
......@@ -154,7 +154,7 @@ namespace Dune
template <class T, class GlobalFunction>
std::vector<T> pointDataImpl (GlobalFunction const& fct) const
{
int nComps = fct.ncomps();
int nComps = fct.numComponents();
std::vector<T> data(this->numPoints() * nComps);
auto const& indexSet = gridView_.indexSet();
......
......@@ -113,7 +113,7 @@ namespace Dune
template <class T, class GlobalFunction>
std::vector<T> pointDataImpl (GlobalFunction const& fct) const
{
std::vector<T> data(this->numPoints() * fct.ncomps());
std::vector<T> data(this->numPoints() * fct.numComponents());
auto const& indexSet = gridView_.indexSet();
auto localFct = localFunction(fct);
for (auto const& e : elements(gridView_, partition)) {
......@@ -122,14 +122,14 @@ namespace Dune
auto refElem = referenceElement(e.geometry());
for (unsigned int j = 0; j < e.subEntities(dim); ++j) {
int k = cellType.permutation(j);
std::size_t idx = fct.ncomps() * indexSet.subIndex(e, k, dim);
for (int comp = 0; comp < fct.ncomps(); ++comp)
std::size_t idx = fct.numComponents() * indexSet.subIndex(e, k, dim);
for (int comp = 0; comp < fct.numComponents(); ++comp)
data[idx + comp] = T(localFct.evaluate(comp, refElem.position(k, dim)));
}
for (unsigned int j = 0; j < e.subEntities(dim-1); ++j) {
int k = cellType.permutation(e.subEntities(dim) + j);
std::size_t idx = fct.ncomps() * (indexSet.subIndex(e, k, dim-1) + gridView_.size(dim));
for (int comp = 0; comp < fct.ncomps(); ++comp)
std::size_t idx = fct.numComponents() * (indexSet.subIndex(e, k, dim-1) + gridView_.size(dim));
for (int comp = 0; comp < fct.numComponents(); ++comp)
data[idx + comp] = T(localFct.evaluate(comp, refElem.position(k, dim-1)));
}
localFct.unbind();
......
......@@ -21,14 +21,11 @@ namespace Dune
using Entity = typename Interface::Entity;
using LocalCoordinate = typename Interface::LocalCoordinate;
template <class F, class D>
using Range = std::decay_t<decltype(std::declval<F>()(std::declval<D>()))>;
public:
/// Constructor. Stores a copy of the passed `localFct` in a local variable.
template <class LocalFct,
disableCopyMove<Self, LocalFct> = 0>
LocalFunctionWrapper (LocalFct&& localFct)
explicit LocalFunctionWrapper (LocalFct&& localFct)
: localFct_(std::forward<LocalFct>(localFct))
{}
......@@ -67,22 +64,22 @@ namespace Dune
return comp < N ? vec[comp] : 0.0;
}
// Evaluate a component of a vector valued data
template <class T,
std::enable_if_t<IsIndexable<T,int>::value, int> = 0>
double evaluateImpl (int comp, T const& value) const
{
return value[comp];
}
// Evaluate a component of a vector valued data
template <class T,
std::enable_if_t<IsIndexable<T,int>::value, int> = 0>
double evaluateImpl (int comp, T const& value) const
{
return value[comp];
}
// Return the scalar values
template <class T,
std::enable_if_t<not IsIndexable<T,int>::value, int> = 0>
double evaluateImpl (int comp, T const& value) const
{
assert(comp == 0);
return value;
}
// Return the scalar values
template <class T,
std::enable_if_t<not IsIndexable<T,int>::value, int> = 0>
double evaluateImpl (int comp, T const& value) const
{
assert(comp == 0);
return value;
}
private:
LocalFunction localFct_;
......
#pragma once
#include <optional>
#include <numeric>
#include <type_traits>
#include <dune/common/std/type_traits.hh>
#include <dune/common/typetraits.hh>
#include <dune/common/version.hh>
#include "localfunction.hh"
#include "types.hh"
#include "utility/arguments.hh"
namespace Dune
{
......@@ -26,74 +28,130 @@ namespace Dune
template <class F>
using LocalFunction = decltype(localFunction(std::declval<F>()));
using Domain = typename GridView::template Codim<0>::Entity::Geometry::LocalCoordinate;
template <class LF, class E>
using HasBind = decltype(std::declval<std::decay_t<LF>&>().bind(std::declval<E>()));
template <class F>
using Range = std::decay_t<std::result_of_t<F(Domain)>>;
using Element = typename GridView::template Codim<0>::Entity;
using LocalDomain = typename Element::Geometry::LocalCoordinate;
template <class F, class D>
using Range = std::decay_t<std::result_of_t<F(D)>>;
private:
template <class T, int N>
static auto sizeOfImpl (FieldVector<T,N> const&)
-> std::integral_constant<int, N> { return {}; }
static auto sizeOfImpl (FieldVector<T,N>) -> std::integral_constant<int, N> { return {}; }
template <class T, int N, int M>
static auto sizeOfImpl (FieldMatrix<T,N,M> const&)
-> std::integral_constant<int, N*M> { return {}; }
static auto sizeOfImpl (FieldMatrix<T,N,M>) -> std::integral_constant<int, N*M> { return {}; }
static auto sizeOfImpl (...)
-> std::integral_constant<int, 1> { return {}; }
static auto sizeOfImpl (...) -> std::integral_constant<int, 1> { return {}; }
template <class T>
static constexpr int sizeOf () { return decltype(sizeOfImpl(std::declval<T>()))::value; }
static std::vector<int> allComponents(int n)
{
std::vector<int> components(n);
std::iota(components.begin(), components.end(), 0);
return components;
}
public:
/// Constructor VtkFunction from legacy VTKFunction
/// (1) Construct from a LocalFunction directly
/**
* \param fct The VTKFunction to wrap
* \param type The VTK datatype how to write the function values to the output [Vtk::DataTypes::FLOAT64]
* \param localFct A local-function, providing a `bind(Element)` and an `operator()(LocalDomain)`
* \param name The name to use as identification in the VTK file
* \param components A vector of component indices to extract from the range type
* \param category The \ref Vtk::RangeTypes category for the range. [Vtk::RangeTypes::AUTO]
* \param dataType The \ref Vtk::DataTypes used in the output. [Vtk::DataTypes::FLOAT32]
*
* The arguments `category` and `dataType` can be passed in any order.
*
* NOTE: Stores the localFunction by value.
**/
Function (std::shared_ptr<VTKFunction<GridView> const> const& fct,
std::optional<Vtk::DataTypes> type = {})
: localFct_(fct)
, name_(fct->name())
, ncomps_(fct->ncomps())
, type_(type ? *type : Vtk::DataTypes::FLOAT32)
{}
template <class LF, class... Args,
class = void_t<HasBind<LF,Element>>,
class R = Range<LF,LocalDomain> >
Function (LF&& localFct, std::string name, std::vector<int> components, Args const&... args)
: localFct_(std::forward<LF>(localFct))
, name_(std::move(name))
{
setComponents(std::move(components));
setRangeType(getArg<Vtk::RangeTypes>(args..., Vtk::RangeTypes::AUTO), components_.size());
setDataType(getArg<Vtk::DataTypes>(args..., Vtk::DataTypes::FLOAT32));
}
/// Construct VtkFunction from dune-functions GridFunction with Signature
// NOTE: Stores the localFunction(fct) by value.
/// (2) Construct from a LocalFunction directly
/**
* \param fct A Grid(View)-function, providing a `localFunction(fct)`
* \param name The name to use component identification in the VTK file
* \param localFct A local-function, providing a `bind(Element)` and an `operator()(LocalDomain)`
* \param name The name to use as identification in the VTK file
* \param ncomps Number of components of the pointwise data. Is extracted
* from the range type of the GridFunction if not given.
* \param type The \ref Vtk::DataTypes used in the output. E.g. FLOAT32,
* or FLOAT64. Is extracted from the range type of the
* GridFunction if not given.
*
* Forwards all the other parmeters to the constructor (1)
*
* NOTE: Stores the localFunction by value.
**/
template <class F,
class = void_t<LocalFunction<F>> >
Function (F&& fct, std::string name,
std::optional<int> ncomps = {},
std::optional<Vtk::DataTypes> type = {})
: localFct_(localFunction(std::forward<F>(fct)))
, name_(std::move(name))
{
using R = Range<LocalFunction<F>>;
template <class LF, class... Args,
class = void_t<HasBind<LF,Element>>,
class R = Range<LF,LocalDomain> >
Function (LF&& localFct, std::string name, Args const&... args)
: Function(std::forward<LF>(localFct), std::move(name),
allComponents(getArg<int,unsigned int,long,unsigned long>(args..., sizeOf<R>())),
getArg<Vtk::RangeTypes>(args..., Vtk::RangeTypes::AUTO),
getArg<Vtk::DataTypes>(args..., Vtk::DataTypes::FLOAT32))
{}
ncomps_ = ncomps ? *ncomps : sizeOf<R>();
type_ = type ? *type : Vtk::Map::type<R>();
}
/// (3) Construct from a Vtk::Function
template <class... Args>
Function (Function<GridView> const& fct, std::string name, Args const&... args)
: Function(fct.localFct_, std::move(name),
getArg<int,unsigned int,long,unsigned long,std::vector<int>>(args..., fct.components_),
getArg<Vtk::RangeTypes>(args..., fct.rangeType_),
getArg<Vtk::DataTypes>(args..., fct.dataType_))
{}
/// Constructor that forward the number of components and data type to the other constructor
template <class F,
/// (4) Construct from a GridFunction
/**
* \param fct A Grid(View)-function, providing a `localFunction(fct)`
* \param name The name to use as identification in the VTK file
*
* Forwards all other arguments to the constructor (1) or (2).
*
* NOTE: Stores the localFunction(fct) by value.
*/
template <class F, class... Args,
disableCopyMove<Function, F> = 0,
class = void_t<LocalFunction<F>> >
Function (F&& fct, Vtk::FieldInfo fieldInfo,
std::optional<Vtk::DataTypes> type = {})
: Function(std::forward<F>(fct), fieldInfo.name(), fieldInfo.ncomps(), type)
Function (F&& fct, std::string name, Args&&... args)
: Function(localFunction(std::forward<F>(fct)), std::move(name), std::forward<Args>(args)...)
{}
/// (5) Constructor that forwards the number of components and data type to the other constructor
template <class F>
Function (F&& fct, Vtk::FieldInfo info)
: Function(std::forward<F>(fct), info.name(), info.size(), info.rangeType(), info.dataType())
{}
/// (6) Construct from legacy VTKFunction
/**
* \param fct The Dune::VTKFunction to wrap
**/
explicit Function (std::shared_ptr<VTKFunction<GridView> const> const& fct)
: localFct_(fct)
, name_(fct->name())
{
setComponents(fct->ncomps());
#if DUNE_VERSION_LT(DUNE_GRID,2,7)
setDataType(Vtk::DataTypes::FLOAT32);
#else
setDataType(dataTypeOf(fct->precision()));
#endif
setRangeType(rangeTypeOf(fct->ncomps()));
}
/// (7) Default constructor. After construction, the function is an an invalid state.
Function () = default;
/// Create a LocalFunction
......@@ -108,23 +166,74 @@ namespace Dune
return name_;
}
/// Return the number of components of the Range
int ncomps () const
/// Set the function name
void setName (std::string name)
{
name_ = std::move(name);
}
/// Return the number of components of the Range as it is written to the file
int numComponents () const
{
return rangeType_ == Vtk::RangeTypes::SCALAR ? 1 :
rangeType_ == Vtk::RangeTypes::VECTOR ? 3 :
rangeType_ == Vtk::RangeTypes::TENSOR ? 9 : int(components_.size());
}
/// Set the components of the Range to visualize
void setComponents (std::vector<int> components)
{
components_ = components;
localFct_.setComponents(components_);
}
/// Set the number of components of the Range and generate component range [0...ncomps)
void setComponents (int ncomps)
{
return ncomps_ > 3 ? 9 : ncomps_ > 1 ? 3 : 1; // tensor, vector, scalar
setComponents(allComponents(ncomps));
}
/// Return the VTK Datatype associated with the functions range type
Vtk::DataTypes type () const
Vtk::DataTypes dataType () const
{
return dataType_;
}
/// Set the data-type for the components
void setDataType (Vtk::DataTypes type)
{
dataType_ = type;
}
/// The category of the range, SCALAR, VECTOR, TENSOR, or UNSPECIFIED
Vtk::RangeTypes rangeType () const
{
return rangeType_;
}
/// Set the category of the range, SCALAR, VECTOR, TENSOR, or UNSPECIFIED
void setRangeType (Vtk::RangeTypes type, std::size_t ncomp = 1)
{
rangeType_ = type;
if (type == Vtk::RangeTypes::AUTO)
rangeType_ = rangeTypeOf(ncomp);
}
/// Set all the parameters from a FieldInfo object
void setFieldInfo (Vtk::FieldInfo info)
{
return type_;
setName(info.name());
setComponents(info.size());
setRangeType(info.rangeType());
setDataType(info.dataType());
}
private:
Vtk::LocalFunction<GridView> localFct_;
std::string name_;
int ncomps_ = 1;
Vtk::DataTypes type_ = Vtk::DataTypes::FLOAT32;
std::vector<int> components_;
Vtk::DataTypes dataType_ = Vtk::DataTypes::FLOAT32;
Vtk::RangeTypes rangeType_ = Vtk::RangeTypes::UNSPECIFIED;
};
} // end namespace Vtk
......
......@@ -21,7 +21,7 @@ namespace Dune
public:
/// Constructor. Stores a shared pointer to the passed Dune::VTKFunction
VTKLocalFunctionWrapper (std::shared_ptr<VTKFunction<GridView> const> const& fct)
explicit VTKLocalFunctionWrapper (std::shared_ptr<VTKFunction<GridView> const> const& fct)
: fct_(fct)
{}
......
......@@ -3,7 +3,7 @@
#include <memory>
#include <type_traits>
#include <dune/common/std/type_traits.hh>
#include <dune/common/typetraits.hh>
#include "localfunctioninterface.hh"
#include "legacyvtkfunction.hh"
......@@ -28,21 +28,52 @@ namespace Dune
template <class LF, class E>
using HasBind = decltype(std::declval<LF>().bind(std::declval<E>()));
private:
struct RangeProxy
{
using value_type = double;
using field_type = double;
RangeProxy (LocalFunctionInterface<GridView> const& localFct,
std::vector<int> const& components,
LocalCoordinate const& local)
: localFct_(localFct)
, components_(components)
, local_(local)
{}
std::size_t size () const
{
return components_.size();
}
double operator[] (std::size_t i) const
{
return i < size() ? localFct_.evaluate(components_[i], local_) : 0.0;
}
private:
LocalFunctionInterface<GridView> const& localFct_;
std::vector<int> const& components_;
LocalCoordinate local_;
};
public:
/// Construct the Vtk::LocalFunction from any function object that has a bind(element) method.
template <class LF,
disableCopyMove<Self, LF> = 0,
class = void_t<HasBind<LF,Entity>> >
LocalFunction (LF&& lf)
explicit LocalFunction (LF&& lf)
: localFct_(std::make_shared<LocalFunctionWrapper<GridView,LF>>(std::forward<LF>(lf)))
{}
/// Construct a Vtk::LocalFunction from a legacy VTKFunction
LocalFunction (std::shared_ptr<VTKFunction<GridView> const> const& lf)
explicit LocalFunction (std::shared_ptr<VTKFunction<GridView> const> const& lf)
: localFct_(std::make_shared<VTKLocalFunctionWrapper<GridView>>(lf))
{}
/// Allow the default construction of a Vtk::LocalFunction
/// Allow the default construction of a Vtk::LocalFunction. After construction, the
/// LocalFunction is in an invalid state.
LocalFunction () = default;
/// Bind the function to the grid entity
......@@ -59,15 +90,28 @@ namespace Dune
localFct_->unbind();
}
/// Evaluate the `comp` component of the Range value at local coordinate `xi`
double evaluate (int comp, LocalCoordinate const& xi) const
/// Return a proxy object to access the components of the range vector
RangeProxy operator() (LocalCoordinate const& xi) const
{
assert(bool(localFct_));
return localFct_->evaluate(comp, xi);
return {*localFct_, components_, xi};
}
/// Evaluate the `c`th component of the Range value at local coordinate `xi`
double evaluate (int c, LocalCoordinate const& xi) const
{
assert(bool(localFct_));
return c < components_.size() ? localFct_->evaluate(components_[c], xi) : 0.0;
}
void setComponents (std::vector<int> components)
{
components_ = std::move(components);
}
private:
std::shared_ptr<LocalFunctionInterface<GridView>> localFct_ = nullptr;
std::vector<int> components_;
};
} // end namespace Vtk
......
dune_add_test(SOURCES test-map-datatypes.cc
LINK_LIBRARIES dunevtk)
dune_add_test(SOURCES test-function.cc
LINK_LIBRARIES dunevtk)
dune_add_test(SOURCES test-typededuction.cc
LINK_LIBRARIES dunevtk)
......
#include <config.h>
#include <optional>
#include <dune/grid/io/file/vtk/common.hh>
#include <dune/grid/utility/structuredgridfactory.hh>
#include <dune/vtk/function.hh>
#include <dune/vtk/vtkwriter.hh>
#if HAVE_DUNE_UGGRID
#include <dune/grid/uggrid.hh>
using GridType = Dune::UGGrid<2>;
#else
#include <dune/grid/yaspgrid.hh>
using GridType = Dune::YaspGrid<2>;
#endif
// Wrapper for global-coordinate functions F
template <class GridView, class F>
class GlobalFunction
{
using Element = typename GridView::template Codim<0>::Entity;
using Geometry = typename Element::Geometry;
public:
GlobalFunction (GridView const& gridView, F const& f)
: gridView_(gridView)
, f_(f)
{}
void bind(Element const& element) { geometry_.emplace(element.geometry()); }
void unbind() { geometry_.reset(); }
auto operator() (typename Geometry::LocalCoordinate const& local) const
{
assert(!!geometry_);
return f_(geometry_->global(local));
}
private:
GridView gridView_;
F f_;
std::optional<Geometry> geometry_;
};
int main (int argc, char** argv)
{
using namespace Dune;