15 #include "smith/smith_config.hpp"
17 #ifdef SMITH_USE_ENZYME
28 template <
typename S,
typename T,
int m>
29 MFEM_HOST_DEVICE
auto inner(
const tensor<S, m>& A,
const tensor<T, m>& B) -> decltype(S{} * T{})
31 decltype(S{} * T{}) sum{};
32 for (
int i = 0; i < m; i++) {
49 MFEM_HOST_DEVICE constexpr
auto detApIm1(
const mfem::future::tensor<T, 2, 2>& A)
57 return A(0, 0) - A(0, 1) * A(1, 0) + A(1, 1) + A(0, 0) * A(1, 1);
62 MFEM_HOST_DEVICE constexpr
auto detApIm1(
const mfem::future::tensor<T, 3, 3>& A)
68 return A(0, 0) + A(1, 1) + A(2, 2)
69 - A(0, 1) * A(1, 0) * (1 + A(2, 2))
70 + A(0, 0) * A(1, 1) * (1 + A(2, 2))
71 - A(0, 2) * A(2, 0) * (1 + A(1, 1))
72 - A(1, 2) * A(2, 1) * (1 + A(0, 0))
75 + A(0, 1) * A(1, 2) * A(2, 0)
76 + A(0, 2) * A(1, 0) * A(2, 1);
91 template <
typename Primal,
typename OrigQFn,
typename... Args>
92 struct InnerQFunction {
93 InnerQFunction(OrigQFn orig_qfn) : orig_qfn_(orig_qfn) {}
97 Primal orig_residual = mfem::future::get<0>(orig_qfn_(std::forward<Args>(args)...));
106 template <
typename OrigQFn,
typename R,
typename... Args>
107 auto makeInnerQFunction(OrigQFn orig_qfn, R (OrigQFn::*)(Args...)
const)
110 return InnerQFunction<decltype(mfem::future::type<0>(R{})), OrigQFn, Args...>{orig_qfn};
114 template <
typename OrigQFn>
115 auto makeInnerQFunction(OrigQFn orig_qfn)
117 return makeInnerQFunction(orig_qfn, &OrigQFn::operator());
127 class DfemWeakForm :
public WeakForm {
129 using SpacesT = std::vector<const mfem::ParFiniteElementSpace*>;
139 DfemWeakForm(std::string physics_name, std::shared_ptr<Mesh> mesh,
140 const mfem::ParFiniteElementSpace& output_mfem_space,
const SpacesT& input_mfem_spaces)
141 : WeakForm(physics_name),
143 output_mfem_space_(output_mfem_space),
144 input_mfem_spaces_(input_mfem_spaces),
145 weak_form_(makeFieldDescriptors({&output_mfem_space}, input_mfem_spaces.size()),
146 makeFieldDescriptors(input_mfem_spaces), mesh->mfemParMesh()),
147 v_dot_weak_form_residual_(makeFieldDescriptors({&output_mfem_space}, input_mfem_spaces.size()),
148 makeFieldDescriptors(input_mfem_spaces), mesh->mfemParMesh()),
149 residual_vector_(output_mfem_space.GetTrueVSize())
152 v_dot_weak_form_residual_.DisableTensorProductStructure();
153 residual_vector_.UseDevice(
true);
177 template <
typename BodyIntegralType,
typename InputType,
typename OutputType,
typename DerivIdsType>
178 void addBodyIntegral(mfem::Array<int> domain_attributes, BodyIntegralType body_integral, InputType integral_inputs,
179 OutputType integral_outputs,
const mfem::IntegrationRule& integration_rule,
180 DerivIdsType derivative_ids)
182 weak_form_.AddDomainIntegrator(body_integral, integral_inputs, integral_outputs, integration_rule,
183 domain_attributes, derivative_ids);
184 auto scalar_body_integral = makeInnerQFunction(body_integral);
185 v_dot_weak_form_residual_.AddDomainIntegrator(
186 scalar_body_integral, addToTupleType(integral_inputs, mfem::future::get<0>(integral_outputs)),
187 mfem::future::tuple<mfem::future::Sum<mfem::future::get<0>(integral_outputs).GetFieldId()>>{}, integration_rule,
188 domain_attributes, derivative_ids);
192 mfem::Vector residual(TimeInfo time_info,
ConstFieldPtr ,
const std::vector<ConstFieldPtr>& fields,
193 const std::vector<ConstQuadratureFieldPtr>& = {})
const override
195 dt_ = time_info.dt();
196 cycle_ = time_info.cycle();
198 weak_form_.SetParameters(getLVectors(fields));
199 weak_form_.Mult(residual_vector_, residual_vector_);
200 return residual_vector_;
204 std::unique_ptr<mfem::HypreParMatrix> jacobian(
205 TimeInfo time_info,
ConstFieldPtr ,
const std::vector<ConstFieldPtr>& ,
206 const std::vector<double>& ,
207 const std::vector<ConstQuadratureFieldPtr>& = {})
const override
209 SLIC_ERROR_ROOT(
"DfemWeakForm does not support matrix assembly");
210 dt_ = time_info.dt();
211 cycle_ = time_info.cycle();
213 return std::make_unique<mfem::HypreParMatrix>();
217 void jvp(TimeInfo time_info,
ConstFieldPtr ,
const std::vector<ConstFieldPtr>& ,
218 const std::vector<ConstQuadratureFieldPtr>& ,
ConstFieldPtr ,
219 const std::vector<ConstFieldPtr>& ,
220 const std::vector<ConstQuadratureFieldPtr>& ,
DualFieldPtr )
const override
222 SLIC_ERROR_ROOT(
"DfemWeakForm does not support jvp calculations");
227 dt_ = time_info.dt();
228 cycle_ = time_info.cycle();
245 void vjp(TimeInfo time_info,
ConstFieldPtr ,
const std::vector<ConstFieldPtr>& ,
246 const std::vector<ConstQuadratureFieldPtr>& ,
ConstFieldPtr ,
248 const std::vector<QuadratureFieldPtr>& )
const override
250 SLIC_ERROR_ROOT(
"DfemWeakForm does not support vjp calculations");
255 dt_ = time_info.dt();
256 cycle_ = time_info.cycle();
281 static std::vector<mfem::future::FieldDescriptor> makeFieldDescriptors(
282 const std::vector<const mfem::ParFiniteElementSpace*>&
spaces,
size_t offset = 0)
284 std::vector<mfem::future::FieldDescriptor> field_descriptors;
285 field_descriptors.reserve(
spaces.size());
286 for (
size_t i = 0; i <
spaces.size(); ++i) {
287 field_descriptors.emplace_back(i + offset,
spaces[i]);
289 return field_descriptors;
292 std::vector<mfem::Vector*> getLVectors(
const std::vector<ConstFieldPtr>& fields)
const
294 std::vector<mfem::Vector*> fields_l;
295 fields_l.reserve(fields.size());
296 for (
size_t i = 0; i < fields.size(); ++i) {
297 fields_l.push_back(&fields[i]->gridFunction());
302 template <
typename Tnew,
typename... Ttuple>
303 static auto addToTupleType(
const mfem::future::tuple<Ttuple...>&,
const Tnew&)
309 template <
int Id,
template <
int>
class FieldOp>
312 return mfem::future::tuple<mfem::future::Sum<Id>>{};
319 mutable size_t cycle_ = 0;
322 std::shared_ptr<Mesh> mesh_;
325 const mfem::ParFiniteElementSpace& output_mfem_space_;
328 std::vector<const mfem::ParFiniteElementSpace*> input_mfem_spaces_;
331 mutable mfem::future::DifferentiableOperator weak_form_;
334 mutable mfem::future::DifferentiableOperator v_dot_weak_form_residual_;
337 mutable mfem::Vector residual_vector_;
#define SMITH_HOST_DEVICE
Macro that evaluates to __host__ __device__ when compiling with nvcc or amdclang and does nothing on ...
This file contains the declaration of the structure that manages the MFEM objects that make up the state ...
Smith mesh class which assists in constructing the appropriate parallel mfem meshes and registering a...
Accelerator functionality.
tuple(T...) -> tuple< T... >
Class template argument deduction rule for tuples.
FiniteElementDual * DualFieldPtr
using
std::vector< const mfem::ParFiniteElementSpace * > spaces(const std::vector< FieldState > &states, const std::vector< FieldState > &params={})
Get the spaces from the primal fields of a vector of field states.
constexpr SMITH_HOST_DEVICE auto detApIm1(const tensor< T, 2, 2 > &A)
computes det(A + I) - 1, where precision is not lost when the entries A_{ij} << 1
SMITH_HOST_DEVICE auto max(dual< gradient_type > a, double b)
Implementation of max for dual numbers.
constexpr SMITH_HOST_DEVICE auto inner(const dual< S > &A, const dual< T > &B)
FiniteElementState const * ConstFieldPtr
using