Smith 0.1
Smith is an implicit thermal structural mechanics simulation code.
nonlinear_solve.cpp
8 
9 namespace smith {
10 
12 void applyBoundaryConditions(double time, const smith::BoundaryConditionManager* bc_manager,
13  smith::FEFieldPtr& primal_field, const smith::FEFieldPtr& bc_field_ptr)
14 {
15  if (bc_field_ptr) {
16  auto constrained_dofs = bc_manager->allEssentialTrueDofs();
17  for (int i = 0; i < constrained_dofs.Size(); i++) {
18  int j = constrained_dofs[i];
19  (*primal_field)[j] = (*bc_field_ptr)(j);
20  }
21  } else {
22  for (auto& bc : bc_manager->essentials()) {
23  bc.setDofs(*primal_field, time);
24  }
25  }
26 }
27 
48 FieldState nonlinearSolve(const WeakForm* residual_eval, const FieldState& shape_disp,
49  const std::vector<FieldState>& states, const std::vector<FieldState>& params,
50  const std::vector<double>& state_update_weights, size_t primal_solve_state_index,
51  size_t dirichlet_state_index, const TimeInfo& time_info, const DifferentiableSolver* solver,
52  const BoundaryConditionManager* bc_manager, const FieldState* bc_field = nullptr)
53 {
55  SLIC_ERROR_IF(states.size() != state_update_weights.size(), "State and state weight fields are inconsistent");
56  SLIC_ERROR_IF(state_update_weights[primal_solve_state_index] != 1.0, "Primal state must have a weight of 1.0");
57 
58  std::vector<gretl::StateBase> allFields;
59  for (auto& s : states) allFields.push_back(s);
60  for (auto& p : params) allFields.push_back(p);
61  allFields.push_back(shape_disp);
62 
63  bool have_bc_field = bc_field;
64  if (have_bc_field) {
65  allFields.push_back(*bc_field);
66  }
67 
68  FieldState sol = states[primal_solve_state_index].clone(allFields);
69 
70  sol.set_eval([=](const gretl::UpstreamStates& inputs, gretl::DownstreamState& output) {
71  SMITH_MARK_BEGIN("solve forward");
72 
73  const size_t num_states = state_update_weights.size();
74 
75  std::vector<size_t> non_primal_to_state_index;
76  for (size_t i = 0; i < num_states; ++i) {
77  if (i != primal_solve_state_index) {
78  non_primal_to_state_index.push_back(i);
79  }
80  }
81 
82  const size_t num_extra_args = have_bc_field ? 2 : 1;
83  const size_t num_fields = inputs.size() - num_extra_args;
84 
85  std::vector<FEFieldPtr> corrected_fields(num_fields);
86  for (size_t field_index = 0; field_index < num_fields; ++field_index) {
87  if (field_index < state_update_weights.size() && state_update_weights[field_index] != 0.0) {
88  corrected_fields[field_index] = std::make_shared<FiniteElementState>(*inputs[field_index].get<FEFieldPtr>());
89  } else {
90  corrected_fields[field_index] = inputs[field_index].get<FEFieldPtr>();
91  }
92  }
93 
94  const FEFieldPtr shape_disp_ptr = inputs[num_fields].get<FEFieldPtr>();
95 
96  FEFieldPtr bc_field_ptr;
97  if (have_bc_field) {
98  bc_field_ptr = inputs[num_fields + num_extra_args - 1].get<FEFieldPtr>();
99  }
100 
101  FEFieldPtr s0 = corrected_fields[primal_solve_state_index];
102  FEFieldPtr s = std::make_shared<FiniteElementState>(s0->space(), "s");
103 
104  if (bc_manager && (dirichlet_state_index == primal_solve_state_index)) {
105  applyBoundaryConditions(time_info.time(), bc_manager, s0, bc_field_ptr);
106  }
107 
108  s = solver->solve(
109  *s0, // initial guess when solving for the primal index field
110  [=](const FiniteElementState& s_) {
111  FEFieldPtr primal_field = corrected_fields[primal_solve_state_index];
112  *primal_field = s_;
113 
114  if (bc_manager && (dirichlet_state_index == primal_solve_state_index)) {
115  applyBoundaryConditions(time_info.time(), bc_manager, primal_field, bc_field_ptr);
116  }
117 
118  for (size_t corrected_field_index : non_primal_to_state_index) {
119  if (state_update_weights[corrected_field_index] != 0.0) {
120  *corrected_fields[corrected_field_index] = *inputs[corrected_field_index].get<FEFieldPtr>();
121  corrected_fields[corrected_field_index]->Add(state_update_weights[corrected_field_index], *primal_field);
122  corrected_fields[corrected_field_index]->Add(-state_update_weights[corrected_field_index], *s0);
123  }
124  }
125 
126  auto r = residual_eval->residual(time_info, shape_disp_ptr.get(), getConstFieldPointers(corrected_fields));
127 
128  if (bc_manager) {
129  if (dirichlet_state_index == primal_solve_state_index) {
130  auto constrained_dofs = bc_manager->allEssentialTrueDofs();
131  r.SetSubVector(constrained_dofs, 0.0);
132  }
133  }
134 
135  return r;
136  },
137  [=](const FiniteElementState& s_) {
138  FEFieldPtr primal_field = corrected_fields[primal_solve_state_index];
139  *primal_field = s_;
140 
141  if (bc_manager && (dirichlet_state_index == primal_solve_state_index)) {
142  applyBoundaryConditions(time_info.time(), bc_manager, primal_field, bc_field_ptr);
143  }
144 
145  for (size_t corrected_field_index : non_primal_to_state_index) {
146  if (state_update_weights[corrected_field_index] != 0.0) {
147  *corrected_fields[corrected_field_index] = *inputs[corrected_field_index].get<FEFieldPtr>();
148  corrected_fields[corrected_field_index]->Add(state_update_weights[corrected_field_index], *primal_field);
149  corrected_fields[corrected_field_index]->Add(-state_update_weights[corrected_field_index], *s0);
150  }
151  }
152 
153  auto J = residual_eval->jacobian(time_info, shape_disp_ptr.get(), getConstFieldPointers(corrected_fields),
154  state_update_weights);
155 
156  if (bc_manager) {
157  if (dirichlet_state_index == primal_solve_state_index) {
158  J->EliminateBC(bc_manager->allEssentialTrueDofs(), mfem::Operator::DiagonalPolicy::DIAG_ONE);
159  }
160  }
161  return J;
162  });
163 
164  output.set<FEFieldPtr, FEDualPtr>(s);
165 
166  SMITH_MARK_END("solve forward");
167  });
168 
169  sol.set_vjp([=](gretl::UpstreamStates& inputs, const gretl::DownstreamState& output) {
170  SMITH_MARK_BEGIN("solve reverse");
171  const FEFieldPtr s = output.get<FEFieldPtr>(); // get the final solution
172  const FEDualPtr s_dual = output.get_dual<FEDualPtr, FEFieldPtr>(); // get the dual load
173 
174  const size_t num_states = state_update_weights.size();
175 
176  std::vector<size_t> non_primal_to_state_index;
177  for (size_t i = 0; i < num_states; ++i) {
178  if (i != primal_solve_state_index) {
179  non_primal_to_state_index.push_back(i);
180  }
181  }
182 
183  const size_t num_extra_args = have_bc_field ? 2 : 1;
184  const size_t num_fields = inputs.size() - num_extra_args;
185 
186  std::vector<FEFieldPtr> corrected_fields(num_fields);
187  for (size_t field_index = 0; field_index < num_fields; ++field_index) {
188  if (field_index < state_update_weights.size() && state_update_weights[field_index] != 0.0) {
189  corrected_fields[field_index] = std::make_shared<FiniteElementState>(*inputs[field_index].get<FEFieldPtr>());
190  } else {
191  corrected_fields[field_index] = inputs[field_index].get<FEFieldPtr>();
192  }
193  }
194 
195  const FEFieldPtr shape_disp_ptr = inputs[num_fields].get<FEFieldPtr>();
196 
197  const FEFieldPtr s0 = inputs[primal_solve_state_index].get<FEFieldPtr>();
198 
199  *corrected_fields[primal_solve_state_index] = *s;
200  for (size_t corrected_field_index : non_primal_to_state_index) {
201  if (state_update_weights[corrected_field_index] != 0.0) {
202  corrected_fields[corrected_field_index]->Add(state_update_weights[corrected_field_index], *s);
203  corrected_fields[corrected_field_index]->Add(-state_update_weights[corrected_field_index], *s0);
204  }
205  }
206 
207  solver->clearMemory();
208  auto J = residual_eval->jacobian(time_info, shape_disp_ptr.get(), getConstFieldPointers(corrected_fields),
209  state_update_weights, {});
210 
211  if (bc_manager) {
212  if (dirichlet_state_index == primal_solve_state_index) {
213  J->EliminateBC(bc_manager->allEssentialTrueDofs(), mfem::Operator::DiagonalPolicy::DIAG_ONE);
214  }
215  }
216 
217  auto J_T = std::unique_ptr<mfem::HypreParMatrix>(J->Transpose());
218  J.reset();
219 
220  auto s_adjoint_ptr = solver->solveAdjoint(*s_dual, std::move(J_T));
221 
222  if (bc_manager) {
223  if (dirichlet_state_index == primal_solve_state_index) {
224  s_adjoint_ptr->SetSubVector(bc_manager->allEssentialTrueDofs(), 0.0);
225  }
226  }
227 
228  *s_adjoint_ptr *= -1.0;
229 
230  std::vector<DualFieldPtr> field_sensitivities(num_fields, nullptr);
231  FEDualPtr shape_disp_sensitivity = inputs[num_fields].get_dual<FEDualPtr, FEFieldPtr>();
232  for (size_t state_index = 0; state_index < num_states; ++state_index) {
233  field_sensitivities[state_index] = inputs[state_index].get_dual<FEDualPtr, FEFieldPtr>().get();
234  }
235  for (size_t param_index = num_states; param_index < num_fields; ++param_index) {
236  field_sensitivities[param_index] = inputs[param_index].get_dual<FEDualPtr, FEFieldPtr>().get();
237  }
238 
239  auto primal_sensitivity = std::make_shared<FiniteElementDual>(*field_sensitivities[primal_solve_state_index]);
240  field_sensitivities[primal_solve_state_index] = primal_sensitivity.get();
241  *field_sensitivities[primal_solve_state_index] = *s_dual;
242 
243  residual_eval->vjp(time_info, shape_disp_ptr.get(), getConstFieldPointers(corrected_fields), {},
244  s_adjoint_ptr.get(), shape_disp_sensitivity.get(), field_sensitivities, {});
245 
246  if (bc_manager && have_bc_field && dirichlet_state_index == primal_solve_state_index) {
247  auto bc_dual_ptr = inputs[num_fields + num_extra_args - 1].get_dual<FEDualPtr, FEFieldPtr>();
248  field_sensitivities[primal_solve_state_index]->SetSubVectorComplement(bc_manager->allEssentialTrueDofs(), 0.0);
249  *bc_dual_ptr += *field_sensitivities[primal_solve_state_index];
250  }
251 
252  SMITH_MARK_END("solve reverse");
253  });
254 
255  sol.finalize();
256 
257  return sol;
258 }
259 
260 FieldState solve(const WeakForm& weak_form, const FieldState& shape_disp, const std::vector<FieldState>& states,
261  const std::vector<FieldState>& params, const TimeInfo& time_info, const DifferentiableSolver& solver,
262  const DirichletBoundaryConditions& bcs, size_t unknown_state_index)
263 {
264  std::vector<double> state_update_weights(states.size(), 0.0);
265  state_update_weights[unknown_state_index] = 1.0;
266  return nonlinearSolve(&weak_form, shape_disp, states, params, state_update_weights, unknown_state_index,
267  unknown_state_index, time_info, &solver, &bcs.getBoundaryConditionManager());
268 }
269 
270 std::vector<FieldState> block_solve(const std::vector<WeakForm*>& residual_evals,
271  const std::vector<std::vector<size_t>> block_indices, const FieldState& shape_disp,
272  const std::vector<std::vector<FieldState>>& states,
273  const std::vector<std::vector<FieldState>>& params, const TimeInfo& time_info,
274  const DifferentiableBlockSolver* solver,
275  const std::vector<const BoundaryConditionManager*>& bc_managers)
276 {
278  size_t num_rows_ = residual_evals.size();
279 
280  SLIC_ERROR_IF(num_rows_ != block_indices.size(), "Block indices size not consistent with number of residual rows");
281  SLIC_ERROR_IF(num_rows_ != states.size(),
282  "Number of state input vectors not consistent with number of residual rows");
283  SLIC_ERROR_IF(num_rows_ != params.size(),
284  "Number of parameter input vectors not consistent with number of residual rows");
285  SLIC_ERROR_IF(num_rows_ != bc_managers.size(),
286  "Number of boundary condition manager not consistent with number of residual rows");
287 
288  for (size_t r = 0; r < num_rows_; ++r) {
289  SLIC_ERROR_IF(num_rows_ != block_indices[r].size(), "All block index rows must have the same number of columns");
290  }
291 
292  std::vector<size_t> num_state_inputs;
293  std::vector<gretl::StateBase> allFields;
294  for (auto& ss : states) {
295  num_state_inputs.push_back(ss.size());
296  for (auto& s : ss) {
297  allFields.push_back(s);
298  }
299  }
300  std::vector<size_t> num_param_inputs;
301  for (auto& ps : params) {
302  num_param_inputs.push_back(ps.size());
303  for (auto& p : ps) {
304  allFields.push_back(p);
305  }
306  }
307  allFields.push_back(shape_disp);
308  struct ZeroDualVectors {
309  std::vector<FEDualPtr> operator()(const std::vector<FEFieldPtr>& fs)
310  {
311  std::vector<FEDualPtr> ds(fs.size());
312  for (size_t i = 0; i < fs.size(); ++i) {
313  ds[i] = std::make_shared<FiniteElementDual>(fs[i]->space(), fs[i]->name() + "_dual");
314  }
315  return ds;
316  }
317  };
318 
319  FieldVecState sol =
320  shape_disp.create_state<std::vector<FEFieldPtr>, std::vector<FEDualPtr>>(allFields, ZeroDualVectors());
321  sol.set_eval([=](const gretl::UpstreamStates& upstreams, gretl::DownstreamState& downstream) {
322  SMITH_MARK_BEGIN("solve forward");
323  const size_t num_rows = num_state_inputs.size();
324  std::vector<std::vector<FEFieldPtr>> input_fields(num_rows);
325  SLIC_ERROR_IF(num_rows != num_param_inputs.size(), "row count for params and states are inconsistent");
326 
327  // The order of inputs in upstreams is:
328  // states of residual 0, states of residual 1, ... , params of residual 0, params of residual 1, ...
329  size_t field_count = 0;
330  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
331  for (size_t state_i = 0; state_i < num_state_inputs[row_i]; ++state_i) {
332  input_fields[row_i].push_back(upstreams[field_count++].get<FEFieldPtr>());
333  }
334  }
335  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
336  for (size_t param_i = 0; param_i < num_param_inputs[row_i]; ++param_i) {
337  input_fields[row_i].push_back(upstreams[field_count++].get<FEFieldPtr>());
338  }
339  }
340 
341  std::vector<FEFieldPtr> diagonal_fields(num_rows);
342  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
343  size_t prime_unknown_row_i = block_indices[row_i][row_i];
344  SLIC_ERROR_IF(prime_unknown_row_i == invalid_block_index,
345  "The primary unknown field (field index for block_index[n][n], must not be invalid)");
346  diagonal_fields[row_i] = std::make_shared<FiniteElementState>(*input_fields[row_i][prime_unknown_row_i]);
347  }
348 
349  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
350  FEFieldPtr primal_field_row_i = diagonal_fields[row_i];
351  applyBoundaryConditions(time_info.time(), bc_managers[row_i], primal_field_row_i, nullptr);
352  }
353 
354  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
355  for (size_t col_j = 0; col_j < num_rows; ++col_j) {
356  size_t prime_unknown_ij = block_indices[row_i][col_j];
357  if (prime_unknown_ij != invalid_block_index) {
358  input_fields[row_i][block_indices[row_i][col_j]] = diagonal_fields[col_j];
359  }
360  }
361  }
362 
363  const FEFieldPtr shape_disp_ptr = upstreams[field_count].get<FEFieldPtr>();
364 
365  auto eval_residuals = [=](const std::vector<FEFieldPtr>& unknowns) {
366  SLIC_ERROR_IF(unknowns.size() != num_rows,
367  "block solver unknowns size must match the number or residuals in block_solve");
368  std::vector<mfem::Vector> residuals(num_rows);
369 
370  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
371  FEFieldPtr primal_field_row_i = diagonal_fields[row_i];
372  *primal_field_row_i = *unknowns[row_i];
373  applyBoundaryConditions(time_info.time(), bc_managers[row_i], primal_field_row_i, nullptr);
374  }
375  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
376  residuals[row_i] = residual_evals[row_i]->residual(time_info, shape_disp_ptr.get(),
377  getConstFieldPointers(input_fields[row_i]));
378  residuals[row_i].SetSubVector(bc_managers[row_i]->allEssentialTrueDofs(), 0.0);
379  }
380  return residuals;
381  };
382 
383  auto eval_jacobians = [=](const std::vector<FEFieldPtr>& unknowns) {
384  SLIC_ERROR_IF(unknowns.size() != num_rows,
385  "block solver unknown size must match the number or residuals in block_solve");
386  std::vector<std::vector<std::unique_ptr<mfem::HypreParMatrix>>> jacobians(num_rows);
387 
388  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
389  FEFieldPtr primal_field_row_i = diagonal_fields[row_i];
390  *primal_field_row_i = *unknowns[row_i];
391  applyBoundaryConditions(time_info.time(), bc_managers[row_i], primal_field_row_i, nullptr);
392  }
393 
394  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
395  std::vector<FEFieldPtr> row_field_inputs = input_fields[row_i];
396  std::vector<double> tangent_weights(row_field_inputs.size(), 0.0);
397  for (size_t col_j = 0; col_j < num_rows; ++col_j) {
398  size_t field_index_to_diff = block_indices[row_i][col_j];
399  if (field_index_to_diff != invalid_block_index) {
400  tangent_weights[field_index_to_diff] = 1.0;
401  auto jac_ij = residual_evals[row_i]->jacobian(time_info, shape_disp_ptr.get(),
402  getConstFieldPointers(row_field_inputs), tangent_weights);
403  jacobians[row_i].emplace_back(std::move(jac_ij));
404  tangent_weights[field_index_to_diff] = 0.0;
405  } else {
406  jacobians[row_i].emplace_back(nullptr);
407  }
408  }
409  }
410 
411  // Apply BCs to the block system
412  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
413  if (jacobians[row_i][row_i]) {
414  jacobians[row_i][row_i]->EliminateBC(bc_managers[row_i]->allEssentialTrueDofs(),
415  mfem::Operator::DiagonalPolicy::DIAG_ONE);
416  }
417  for (size_t col_j = 0; col_j < num_rows; ++col_j) {
418  if (col_j != row_i) {
419  if (jacobians[row_i][col_j]) {
420  jacobians[row_i][col_j]->EliminateRows(bc_managers[row_i]->allEssentialTrueDofs());
421  }
422  if (jacobians[col_j][row_i]) {
423  mfem::HypreParMatrix* Jji =
424  jacobians[col_j][row_i]->EliminateCols(bc_managers[row_i]->allEssentialTrueDofs());
425  delete Jji;
426  }
427  }
428  }
429  }
430  return jacobians;
431  };
432 
433  diagonal_fields = solver->solve(diagonal_fields, eval_residuals, eval_jacobians);
434 
435  downstream.set<std::vector<FEFieldPtr>, std::vector<FEDualPtr>>(diagonal_fields);
436 
437  SMITH_MARK_END("solve forward");
438  });
439 
440  sol.set_vjp([=](gretl::UpstreamStates& upstreams, const gretl::DownstreamState& downstream) {
441  SMITH_MARK_BEGIN("solve reverse");
442  const std::vector<FEFieldPtr> s = downstream.get<std::vector<FEFieldPtr>>(); // get the final solution
443  const std::vector<FEDualPtr> s_dual =
444  downstream.get_dual<std::vector<FEDualPtr>, std::vector<FEFieldPtr>>(); // get the dual load
445 
446  const size_t num_rows = num_state_inputs.size();
447  SLIC_ERROR_IF(s_dual.size() != num_rows,
448  "block solver vjp downstream size must match the number or residuals in block_solve");
449 
450  std::vector<std::vector<FEFieldPtr>> input_fields(num_rows);
451  size_t field_count = 0;
452  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
453  for (size_t state_i = 0; state_i < num_state_inputs[row_i]; ++state_i) {
454  input_fields[row_i].push_back(upstreams[field_count++].get<FEFieldPtr>());
455  }
456  }
457  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
458  for (size_t param_i = 0; param_i < num_param_inputs[row_i]; ++param_i) {
459  input_fields[row_i].push_back(upstreams[field_count++].get<FEFieldPtr>());
460  }
461  }
462 
463  // if the field is a primal variable we solved before,
464  // make a copy so we don't accidentally override the original copy
465  std::vector<FEFieldPtr> diagonal_fields(num_rows);
466  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
467  diagonal_fields[row_i] = std::make_shared<FiniteElementState>(*input_fields[row_i][block_indices[row_i][row_i]]);
468  *diagonal_fields[row_i] = *s[row_i];
469  }
470 
471  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
472  for (size_t col_j = 0; col_j < num_rows; ++col_j) {
473  input_fields[row_i][block_indices[row_i][col_j]] = diagonal_fields[col_j];
474  }
475  }
476 
477  const FEFieldPtr shape_disp_ptr = upstreams[field_count].get<FEFieldPtr>();
478 
479  // I'm not sure this will be the right timestamp to apply boundary condition during backward propagation
480  // Need to double check for time-dependent boundary conditions
481  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
482  FEFieldPtr primal_field_row_i = diagonal_fields[row_i];
483  applyBoundaryConditions(time_info.time(), bc_managers[row_i], primal_field_row_i, nullptr);
484  }
485 
486  solver->clearMemory();
487 
488  std::vector<std::vector<std::unique_ptr<mfem::HypreParMatrix>>> jacobians(num_rows);
489  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
490  std::vector<FEFieldPtr> row_field_inputs = input_fields[row_i];
491  std::vector<double> tangent_weights(row_field_inputs.size(), 0.0);
492  for (size_t col_j = 0; col_j < num_rows; ++col_j) {
493  size_t field_index_to_diff = block_indices[row_i][col_j];
494  tangent_weights[field_index_to_diff] = 1.0;
495  auto jac_ij = residual_evals[row_i]->jacobian(time_info, shape_disp_ptr.get(),
496  getConstFieldPointers(row_field_inputs), tangent_weights);
497  jacobians[row_i].emplace_back(std::move(jac_ij));
498  tangent_weights[field_index_to_diff] = 0.0;
499  }
500  }
501 
502  // Apply BCs to the block system
503  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
504  s_dual[row_i]->SetSubVector(bc_managers[row_i]->allEssentialTrueDofs(), 0.0);
505 
506  mfem::HypreParMatrix* Jii =
507  jacobians[row_i][row_i]->EliminateRowsCols(bc_managers[row_i]->allEssentialTrueDofs());
508  delete Jii;
509  for (size_t col_j = 0; col_j < num_rows; ++col_j) {
510  if (col_j != row_i) {
511  jacobians[row_i][col_j]->EliminateRows(bc_managers[row_i]->allEssentialTrueDofs());
512  mfem::HypreParMatrix* Jji =
513  jacobians[col_j][row_i]->EliminateCols(bc_managers[row_i]->allEssentialTrueDofs());
514  delete Jji;
515  }
516  }
517  }
518 
519  // Take the transpose of the block system
520  std::vector<std::vector<std::unique_ptr<mfem::HypreParMatrix>>> jacobians_T(num_rows);
521  for (size_t col_j = 0; col_j < num_rows; ++col_j) {
522  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
523  jacobians_T[col_j].emplace_back(std::unique_ptr<mfem::HypreParMatrix>(jacobians[row_i][col_j]->Transpose()));
524  }
525  }
526  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
527  for (size_t col_j = 0; col_j < num_rows; ++col_j) {
528  jacobians[row_i][col_j].reset();
529  }
530  }
531 
532  std::vector<FEFieldPtr> adjoint_fields(num_rows);
533  adjoint_fields = solver->solveAdjoint(s_dual, jacobians_T);
534  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
535  *adjoint_fields[row_i] *= -1.0;
536  }
537 
538  // Update sensitivities
539  std::vector<std::vector<FEDualPtr>> field_sensitivities(num_rows);
540  FEDualPtr shape_disp_sensitivity = upstreams[field_count].get_dual<FEDualPtr, FEFieldPtr>();
541  size_t dual_index = 0;
542  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
543  for (size_t state_i = 0; state_i < num_state_inputs[row_i]; ++state_i) {
544  field_sensitivities[row_i].push_back(upstreams[dual_index++].get_dual<FEDualPtr, FEFieldPtr>());
545  }
546  }
547  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
548  for (size_t param_i = 0; param_i < num_param_inputs[row_i]; ++param_i) {
549  field_sensitivities[row_i].push_back(upstreams[dual_index++].get_dual<FEDualPtr, FEFieldPtr>());
550  }
551  }
552  SLIC_ERROR_IF(field_count != dual_index, "Number of sensitivities must equal to number of upstreams");
553 
554  // No sensitivity needed for primal fields
555  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
556  for (size_t col_j = 0; col_j < num_rows; ++col_j) {
557  field_sensitivities[row_i][block_indices[row_i][col_j]] = nullptr;
558  }
559  }
560 
561  for (size_t row_i = 0; row_i < num_rows; ++row_i) {
562  residual_evals[row_i]->vjp(time_info, shape_disp_ptr.get(), getConstFieldPointers(input_fields[row_i]), {},
563  adjoint_fields[row_i].get(), shape_disp_sensitivity.get(),
564  getFieldPointers(field_sensitivities[row_i]), {});
565  }
566 
567  SMITH_MARK_END("solve reverse");
568  });
569 
570  sol.finalize();
571 
572  std::vector<FieldState> results;
573  for (size_t i = 0; i < num_rows_; ++i) {
574  FieldState s = gretl::create_state<FEFieldPtr, FEDualPtr>(
576  [i](const std::vector<FEFieldPtr>& sols) {
577  auto state_copy = std::make_shared<FiniteElementState>(sols[i]->space(), sols[i]->name());
578  *state_copy = *sols[i];
579  return state_copy;
580  },
581  [i](const std::vector<FEFieldPtr>&, const FEFieldPtr&, std::vector<FEDualPtr>& sols_,
582  const FEDualPtr& output_) { *sols_[i] += *output_; },
583  sol);
584 
585  results.emplace_back(s);
586  }
587 
588  return results;
589 }
590 
591 } // namespace smith
This file contains the declaration of the boundary condition manager class.
A container for the boundary condition information relating to a specific physics module.
const mfem::Array< int > & allEssentialTrueDofs() const
Returns all the true degrees of freedom associated with all the essential BCs.
std::vector< BoundaryCondition > & essentials()
Accessor for the essential BC objects.
Abstract interface to DifferentiableBlockSolver interface. Each differentiable block solve should pro...
virtual std::vector< FieldPtr > solveAdjoint(const std::vector< DualPtr > &u_bars, std::vector< std::vector< MatrixPtr >> &jacobian_transposed) const =0
Solve the (linear) adjoint set of equations with a vector of FiniteElementState as unknown.
virtual std::vector< FieldPtr > solve(const std::vector< FieldPtr > &u_guesses, std::function< std::vector< mfem::Vector >(const std::vector< FieldPtr > &)> residuals, std::function< std::vector< std::vector< MatrixPtr >>(const std::vector< FieldPtr > &)> jacobians) const =0
Solve a set of equations with a vector of FiniteElementState as unknown.
virtual void clearMemory() const
Interface option to clear memory between solves to avoid high-water mark memory usage.
Abstract interface to DifferentiableSolver interface. Each differentiable solve should provide both i...
virtual std::shared_ptr< smith::FiniteElementState > solve(const smith::FiniteElementState &u_guess, std::function< mfem::Vector(const smith::FiniteElementState &)> equation, std::function< std::unique_ptr< mfem::HypreParMatrix >(const smith::FiniteElementState &)> jacobian) const =0
Solve a set of equations with a FiniteElementState as unknown.
virtual std::shared_ptr< smith::FiniteElementState > solveAdjoint(const smith::FiniteElementDual &u_bar, std::unique_ptr< mfem::HypreParMatrix > jacobian_transposed) const =0
Solve the (linear) adjoint set of equations with a FiniteElementState as unknown.
virtual void clearMemory() const
Interface option to clear memory between solves to avoid high-water mark memory usage.
A generic class for setting Dirichlet boundary conditions on arbitrary physics.
const smith::BoundaryConditionManager & getBoundaryConditionManager() const
Return the smith BoundaryConditionManager.
Class for encapsulating the critical MFEM components of a primal finite element field.
Abstract WeakForm class.
Definition: weak_form.hpp:36
virtual std::unique_ptr< mfem::HypreParMatrix > jacobian(TimeInfo time_info, ConstFieldPtr shape_disp, const std::vector< ConstFieldPtr > &fields, const std::vector< double > &field_argument_tangents, const std::vector< ConstQuadratureFieldPtr > &quad_fields={}) const =0
Derivative of the residual with respect to specified field arguments: sum_j d{r}/d{fields}_j * argume...
virtual mfem::Vector residual(TimeInfo time_info, ConstFieldPtr shape_disp, const std::vector< ConstFieldPtr > &fields, const std::vector< ConstQuadratureFieldPtr > &quad_fields={}) const =0
Virtual interface for computing the residual vector of a weak form.
virtual void vjp(TimeInfo time_info, ConstFieldPtr shape_disp, const std::vector< ConstFieldPtr > &fields, const std::vector< ConstQuadratureFieldPtr > &quad_fields, ConstFieldPtr v_field, DualFieldPtr vjp_shape_disp_sensitivity, const std::vector< DualFieldPtr > &vjp_sensitivities, const std::vector< QuadratureFieldPtr > &vjp_quadrature_sensivities) const =0
Vector-Jacobian product; accumulates (+=) into the existing values of the vjp sensitivity fields.
This file contains the declaration of the DifferentiableSolver interface.
Contains DirichletBoundaryConditions class for interaction with the differentiable solve interfaces.
Defines common types and helper functions for using the residual and scalar_objective classes.
Accelerator functionality.
Definition: smith.cpp:36
constexpr T & get(variant< T0, T1 > &v)
Returns the variant member of specified type.
Definition: variant.hpp:338
std::shared_ptr< FiniteElementState > FEFieldPtr
typedef
Definition: field_state.hpp:20
gretl::State< std::vector< FEFieldPtr >, std::vector< FEDualPtr > > FieldVecState
typedef
Definition: field_state.hpp:24
void applyBoundaryConditions(double time, const smith::BoundaryConditionManager *bc_manager, smith::FEFieldPtr &primal_field, const smith::FEFieldPtr &bc_field_ptr)
Apply the essential boundary conditions to the primal field.
std::vector< FiniteElementState * > getFieldPointers(std::vector< FieldState > &states, std::vector< FieldState > params={})
Get a vector of FieldPtr or DualFieldPtr from a vector of FieldState.
gretl::State< FEFieldPtr, FEDualPtr > FieldState
typedef
Definition: field_state.hpp:22
constexpr SMITH_HOST_DEVICE int size(const tensor< T, n... > &)
returns the total number of stored values in a tensor
Definition: tensor.hpp:1939
std::shared_ptr< FiniteElementDual > FEDualPtr
typedef
Definition: field_state.hpp:21
std::vector< const FiniteElementState * > getConstFieldPointers(const std::vector< FieldState > &states, const std::vector< FieldState > &params={})
Get a vector of ConstFieldPtr or ConstDualFieldPtr from a vector of FieldState.
std::vector< FieldState > block_solve(const std::vector< WeakForm * > &residual_evals, const std::vector< std::vector< size_t >> block_indices, const FieldState &shape_disp, const std::vector< std::vector< FieldState >> &states, const std::vector< std::vector< FieldState >> &params, const TimeInfo &time_info, const DifferentiableBlockSolver *solver, const std::vector< const BoundaryConditionManager * > &bc_managers)
Solve a block nonlinear system of equations as defined by the vector of weak form.
mfem::ParFiniteElementSpace & space(FieldState field)
Get the space from the primal field of a field states.
FieldState nonlinearSolve(const WeakForm *residual_eval, const FieldState &shape_disp, const std::vector< FieldState > &states, const std::vector< FieldState > &params, const std::vector< double > &state_update_weights, size_t primal_solve_state_index, size_t dirichlet_state_index, const TimeInfo &time_info, const DifferentiableSolver *solver, const BoundaryConditionManager *bc_manager, const FieldState *bc_field=nullptr)
Solve a nonlinear system of equations as defined by the weak form.
FieldState solve(const WeakForm &weak_form, const FieldState &shape_disp, const std::vector< FieldState > &states, const std::vector< FieldState > &params, const TimeInfo &time_info, const DifferentiableSolver &solver, const DirichletBoundaryConditions &bcs, size_t unknown_state_index)
Solve a nonlinear system of equations as defined by the weak form, assuming that the field indexed by...
Methods for solving systems of equations as given by WeakForms. Tracks these operations on the gretl ...
#define SMITH_MARK_FUNCTION
Definition: profiling.hpp:90
#define SMITH_MARK_BEGIN(name)
Definition: profiling.hpp:94
#define SMITH_MARK_END(name)
Definition: profiling.hpp:95
struct storing time and timestep information
Definition: common.hpp:18
double time() const
accessor for the current time
Definition: common.hpp:26
functor which takes a std::shared_ptr<FiniteElementState>, and returns a zero-valued std::shared_ptr<...
Definition: field_state.hpp:29
Specifies interface for evaluating weak form residuals and their gradients.