Smith  0.1
Smith is an implicit thermal structural mechanics simulation code.
dfem_weak_form.hpp
Go to the documentation of this file.
1 // Copyright (c) Lawrence Livermore National Security, LLC and
2 // other Smith Project Developers. See the top-level LICENSE file for
3 // details.
4 //
5 // SPDX-License-Identifier: (BSD-3-Clause)
6 
13 #pragma once
14 
15 #include "smith/smith_config.hpp"
16 
17 #ifdef SMITH_USE_ENZYME
18 
#include <utility>

#include "smith/physics/mesh.hpp"
22 
23 // NOTE (EBC): these should be upstreamed to MFEM, so let's put them in the mfem::future namespace
24 namespace mfem {
25 namespace future {
26 
27 // Inner product of 1D tensors
28 template <typename S, typename T, int m>
29 MFEM_HOST_DEVICE auto inner(const tensor<S, m>& A, const tensor<T, m>& B) -> decltype(S{} * T{})
30 {
31  decltype(S{} * T{}) sum{};
32  for (int i = 0; i < m; i++) {
33  sum += A[i] * B[i];
34  }
35  return sum;
36 }
37 
48 template <typename T>
49 MFEM_HOST_DEVICE constexpr auto detApIm1(const mfem::future::tensor<T, 2, 2>& A)
50 {
51  // From the Cayley-Hamilton theorem, we get that for any N by N matrix A,
52  // det(A - I) - 1 = I1(A) + I2(A) + ... + IN(A),
53  // where the In are the principal invariants of A.
54  // We inline the definitions of the principal invariants to increase computational speed.
55 
56  // equivalent to tr(A) + det(A)
57  return A(0, 0) - A(0, 1) * A(1, 0) + A(1, 1) + A(0, 0) * A(1, 1);
58 }
59 
61 template <typename T>
62 MFEM_HOST_DEVICE constexpr auto detApIm1(const mfem::future::tensor<T, 3, 3>& A)
63 {
64  // For notes on the implementation, see the 2x2 version.
65 
66  // clang-format off
67  // equivalent to tr(A) + I2(A) + det(A)
68  return A(0, 0) + A(1, 1) + A(2, 2)
69  - A(0, 1) * A(1, 0) * (1 + A(2, 2))
70  + A(0, 0) * A(1, 1) * (1 + A(2, 2))
71  - A(0, 2) * A(2, 0) * (1 + A(1, 1))
72  - A(1, 2) * A(2, 1) * (1 + A(0, 0))
73  + A(0, 0) * A(2, 2)
74  + A(1, 1) * A(2, 2)
75  + A(0, 1) * A(1, 2) * A(2, 0)
76  + A(0, 2) * A(1, 0) * A(2, 1);
77  // clang-format on
78 }
79 
80 } // namespace future
81 } // namespace mfem
82 
83 namespace smith {
84 
85 // NOTE: Args needs to be on the functor struct instead of the operator() so that operator() isn't overloaded and dfem
86 // can deduce the type
91 template <typename Primal, typename OrigQFn, typename... Args>
92 struct InnerQFunction {
93  InnerQFunction(OrigQFn orig_qfn) : orig_qfn_(orig_qfn) {}
94 
95  SMITH_HOST_DEVICE inline auto operator()(Primal V, Args... args) const
96  {
97  Primal orig_residual = mfem::future::get<0>(orig_qfn_(std::forward<Args>(args)...));
98  return mfem::future::tuple{mfem::future::inner(V, orig_residual)};
99  }
100 
101  OrigQFn orig_qfn_;
102 };
103 
104 // Step 2: deduce the type of the parameters and the first tuple element of the return type of the operator()
105 // Step 3: create the InnerQFunction with the deduced types
106 template <typename OrigQFn, typename R, typename... Args>
107 auto makeInnerQFunction(OrigQFn orig_qfn, R (OrigQFn::*)(Args...) const)
108 {
109  // TODO: is there a better way to get the type of the first tuple element?
110  return InnerQFunction<decltype(mfem::future::type<0>(R{})), OrigQFn, Args...>{orig_qfn};
111 }
112 
113 // Step 1: get function pointer to operator()
114 template <typename OrigQFn>
115 auto makeInnerQFunction(OrigQFn orig_qfn)
116 {
117  return makeInnerQFunction(orig_qfn, &OrigQFn::operator());
118 }
119 
127 class DfemWeakForm : public WeakForm {
128  public:
129  using SpacesT = std::vector<const mfem::ParFiniteElementSpace*>;
130 
139  DfemWeakForm(std::string physics_name, std::shared_ptr<Mesh> mesh,
140  const mfem::ParFiniteElementSpace& output_mfem_space, const SpacesT& input_mfem_spaces)
141  : WeakForm(physics_name),
142  mesh_(mesh),
143  output_mfem_space_(output_mfem_space),
144  input_mfem_spaces_(input_mfem_spaces),
145  weak_form_(makeFieldDescriptors({&output_mfem_space}, input_mfem_spaces.size()),
146  makeFieldDescriptors(input_mfem_spaces), mesh->mfemParMesh()),
147  v_dot_weak_form_residual_(makeFieldDescriptors({&output_mfem_space}, input_mfem_spaces.size()),
148  makeFieldDescriptors(input_mfem_spaces), mesh->mfemParMesh()),
149  residual_vector_(output_mfem_space.GetTrueVSize())
150  {
151  // sum field operator doesn't work with sum factorization
152  v_dot_weak_form_residual_.DisableTensorProductStructure();
153  residual_vector_.UseDevice(true);
154  }
155 
177  template <typename BodyIntegralType, typename InputType, typename OutputType, typename DerivIdsType>
178  void addBodyIntegral(mfem::Array<int> domain_attributes, BodyIntegralType body_integral, InputType integral_inputs,
179  OutputType integral_outputs, const mfem::IntegrationRule& integration_rule,
180  DerivIdsType derivative_ids)
181  {
182  weak_form_.AddDomainIntegrator(body_integral, integral_inputs, integral_outputs, integration_rule,
183  domain_attributes, derivative_ids);
184  auto scalar_body_integral = makeInnerQFunction(body_integral);
185  v_dot_weak_form_residual_.AddDomainIntegrator(
186  scalar_body_integral, addToTupleType(integral_inputs, mfem::future::get<0>(integral_outputs)),
187  mfem::future::tuple<mfem::future::Sum<mfem::future::get<0>(integral_outputs).GetFieldId()>>{}, integration_rule,
188  domain_attributes, derivative_ids);
189  }
190 
192  mfem::Vector residual(TimeInfo time_info, ConstFieldPtr /*shape_disp*/, const std::vector<ConstFieldPtr>& fields,
193  const std::vector<ConstQuadratureFieldPtr>& /*quad_fields*/ = {}) const override
194  {
195  dt_ = time_info.dt();
196  cycle_ = time_info.cycle();
197 
198  weak_form_.SetParameters(getLVectors(fields));
199  weak_form_.Mult(residual_vector_, residual_vector_);
200  return residual_vector_;
201  }
202 
204  std::unique_ptr<mfem::HypreParMatrix> jacobian(
205  TimeInfo time_info, ConstFieldPtr /*shape_disp*/, const std::vector<ConstFieldPtr>& /*fields*/,
206  const std::vector<double>& /*jacobian_weights*/,
207  const std::vector<ConstQuadratureFieldPtr>& /*quad_fields*/ = {}) const override
208  {
209  SLIC_ERROR_ROOT("DfemWeakForm does not support matrix assembly");
210  dt_ = time_info.dt();
211  cycle_ = time_info.cycle();
212 
213  return std::make_unique<mfem::HypreParMatrix>();
214  }
215 
217  void jvp(TimeInfo time_info, ConstFieldPtr /*shape_disp*/, const std::vector<ConstFieldPtr>& /*fields*/,
218  const std::vector<ConstQuadratureFieldPtr>& /*quad_fields*/, ConstFieldPtr /*v_shape_disp*/,
219  const std::vector<ConstFieldPtr>& /*v_fields*/,
220  const std::vector<ConstQuadratureFieldPtr>& /*v_quad_fields*/, DualFieldPtr /*jvp_reaction*/) const override
221  {
222  SLIC_ERROR_ROOT("DfemWeakForm does not support jvp calculations");
223 
224  // SLIC_ERROR_IF(v_fields.size() != fields.size(),
225  // "Invalid number of field sensitivities relative to the number of fields");
226  // SLIC_ERROR_IF(jvp_reactions.size() != 1, "FunctionalResidual nonlinear systems only supports 1 output residual");
227  dt_ = time_info.dt();
228  cycle_ = time_info.cycle();
229 
230  // TODO (EBC): add in a future PR...
231  // std::vector<mfem::Vector*> test_par_gf({&fields[0]->gridFunction()});
232  // std::vector<mfem::Vector*> field_par_gf = getLVectors(fields);
233 
234  // *jvp_reactions[0] = 0.0;
235 
236  // for (size_t input_col = 0; input_col < fields.size(); ++input_col) {
237  // if (v_fields[input_col] != nullptr) {
238  // auto deriv_op = weak_form_.GetDerivative(input_col, test_par_gf, field_par_gf);
239  // deriv_op->AddMult(*v_fields[input_col], *jvp_reactions[0]);
240  // }
241  // }
242  }
243 
245  void vjp(TimeInfo time_info, ConstFieldPtr /*shape_disp*/, const std::vector<ConstFieldPtr>& /*fields*/,
246  const std::vector<ConstQuadratureFieldPtr>& /*quad_fields*/, ConstFieldPtr /*v_fields*/,
247  DualFieldPtr /*vjp_shape_disp_sensitivity*/, const std::vector<DualFieldPtr>& /*vjp_sensitivities*/,
248  const std::vector<QuadratureFieldPtr>& /*vjp_quad_field_sensitivities*/) const override
249  {
250  SLIC_ERROR_ROOT("DfemWeakForm does not support vjp calculations");
251 
252  // SLIC_ERROR_IF(vjp_sensitivities.size() != fields.size(),
253  // "Invalid number of field sensitivities relative to the number of fields");
254  // SLIC_ERROR_IF(v_fields.size() != 1, "FunctionalResidual nonlinear systems only supports 1 output residual");
255  dt_ = time_info.dt();
256  cycle_ = time_info.cycle();
257 
258  // TODO (EBC): add in a future PR...
259  // std::vector<mfem::Vector*> test_par_gf({&v_fields[0]->gridFunction()});
260  // std::vector<mfem::Vector*> field_par_gf = getLVectors(fields);
261  // // field_par_gf.push_back(&v_fields[0]->gridFunction());
262 
263  // for (size_t input_col = 0; input_col < fields.size(); ++input_col) {
264  // if (vjp_sensitivities[input_col] != nullptr) {
265  // auto deriv_op = v_dot_weak_form_residual_.GetDerivative(input_col, test_par_gf, field_par_gf);
266  // // do this entry by entry until assembly is supported
267  // mfem::Vector direction(vjp_sensitivities[input_col]->Size());
268  // direction = 0.0;
269  // for (int i = 0; i < vjp_sensitivities[input_col]->Size(); ++i) {
270  // direction[i] = 1.0;
271  // mfem::Vector value(1);
272  // deriv_op->Mult(direction, value);
273  // (*vjp_sensitivities[input_col])[i] += value[0];
274  // direction[i] = 0.0;
275  // }
276  // }
277  // }
278  }
279 
280  protected:
281  static std::vector<mfem::future::FieldDescriptor> makeFieldDescriptors(
282  const std::vector<const mfem::ParFiniteElementSpace*>& spaces, size_t offset = 0)
283  {
284  std::vector<mfem::future::FieldDescriptor> field_descriptors;
285  field_descriptors.reserve(spaces.size());
286  for (size_t i = 0; i < spaces.size(); ++i) {
287  field_descriptors.emplace_back(i + offset, spaces[i]);
288  }
289  return field_descriptors;
290  }
291 
292  std::vector<mfem::Vector*> getLVectors(const std::vector<ConstFieldPtr>& fields) const
293  {
294  std::vector<mfem::Vector*> fields_l;
295  fields_l.reserve(fields.size());
296  for (size_t i = 0; i < fields.size(); ++i) {
297  fields_l.push_back(&fields[i]->gridFunction());
298  }
299  return fields_l;
300  }
301 
302  template <typename Tnew, typename... Ttuple>
303  static auto addToTupleType(const mfem::future::tuple<Ttuple...>&, const Tnew&)
304  {
305  return mfem::future::tuple<Tnew, Ttuple...>{};
306  }
307 
308  // The field ID doesn't matter, since the test function is one
309  template <int Id, template <int> class FieldOp>
310  static auto makeVirtualWorkOutputs(mfem::future::tuple<FieldOp<Id>>)
311  {
312  return mfem::future::tuple<mfem::future::Sum<Id>>{};
313  }
314 
316  mutable double dt_ = std::numeric_limits<double>::max();
317 
319  mutable size_t cycle_ = 0;
320 
322  std::shared_ptr<Mesh> mesh_;
323 
325  const mfem::ParFiniteElementSpace& output_mfem_space_;
326 
328  std::vector<const mfem::ParFiniteElementSpace*> input_mfem_spaces_;
329 
331  mutable mfem::future::DifferentiableOperator weak_form_;
332 
334  mutable mfem::future::DifferentiableOperator v_dot_weak_form_residual_;
335 
337  mutable mfem::Vector residual_vector_;
338 };
339 
340 } // namespace smith
341 
342 #endif
#define SMITH_HOST_DEVICE
Macro that evaluates to __host__ __device__ when compiling with nvcc or amdclang and does nothing on ...
Definition: accelerator.hpp:37
This file contains the declaration of the structure that manages the MFEM objects that make up the state ...
Smith mesh class which assists in constructing the appropriate parallel mfem meshes and registering a...
Accelerator functionality.
Definition: smith.cpp:36
tuple(T...) -> tuple< T... >
Class template argument deduction rule for tuples.
FiniteElementDual * DualFieldPtr
using
Definition: field_types.hpp:33
std::vector< const mfem::ParFiniteElementSpace * > spaces(const std::vector< FieldState > &states, const std::vector< FieldState > &params={})
Get the spaces from the primal fields of a vector of field states.
constexpr SMITH_HOST_DEVICE auto detApIm1(const tensor< T, 2, 2 > &A)
computes det(A + I) - 1, where precision is not lost when the entries A_{ij} << 1
Definition: tensor.hpp:1326
SMITH_HOST_DEVICE auto max(dual< gradient_type > a, double b)
Implementation of max for dual numbers.
Definition: dual.hpp:229
constexpr SMITH_HOST_DEVICE auto inner(const dual< S > &A, const dual< T > &B)
Definition: dual.hpp:281
FiniteElementState const * ConstFieldPtr
using
Definition: field_types.hpp:36
Specifies interface for evaluating weak form residuals and their gradients.