13 return gretl::clone_state(
15 auto next = std::make_shared<FiniteElementState>(s->space(), s->name() +
"_squared");
22 for (
int i = 0; i < sz; ++i) {
23 (*s_)[i] += 2.0 * (*s)[i] * (*next_)[i];
31 return gretl::create_state<double, double>(
32 gretl::defaultInitializeZeroDual<double, double>(),
35 A_->Add(product_, *B);
36 B_->Add(product_, *A);
43 return gretl::create_state<double, double>(
44 gretl::defaultInitializeZeroDual<double, double>(),
47 A_->Add(product_, *B);
48 B_->Add(product_, *A);
55 auto z = x.clone({x, y});
57 z.set_eval([a, b](
const gretl::UpstreamStates& upstreams, gretl::DownstreamState& downstream) {
60 FEFieldPtr Z = std::make_shared<FiniteElementState>(X->space(),
"axpby");
61 add(a, *X, b, *Y, *Z);
65 z.set_vjp([a, b](gretl::UpstreamStates& upstreams,
const gretl::DownstreamState& downstream) {
69 add(*X_dual, a, *Z_dual, *X_dual);
70 add(*Y_dual, b, *Z_dual, *Y_dual);
78 return gretl::clone_state(
79 [](
const FEFieldPtr& X) {
return std::make_shared<FiniteElementState>(X->space(),
"zero"); },
85 return gretl::clone_state(
87 FEFieldPtr Z = std::make_shared<FiniteElementState>(X->space(),
"axpby");
88 add(A, *X, B, *Y, *Z);
93 add(*X_, A, *Z_, *X_);
94 add(*Y_, B, *Z_, *Y_);
105 const std::vector<gretl::State<double>>& differentiable_weights,
106 const std::vector<FieldState>& differentiably_weighted_fields,
107 const std::vector<double>& differentiable_scale_factors)
109 SLIC_ERROR_IF(weights.size() != weighted_fields.size(),
110 "weights and the fields they are weighting do not match in size");
111 SLIC_ERROR_IF(differentiable_weights.size() != differentiably_weighted_fields.size(),
112 "differentiable weights and the fields they are weighting do not match in size");
113 SLIC_ERROR_IF(differentiable_weights.size() != differentiable_scale_factors.size(),
114 "differentiable weights and the vector of fixed scale factors do not match in size");
115 SLIC_ERROR_IF((weights.size() == 0) && (differentiable_weights.size() == 0),
116 "At least 1 weight must be passed to a weighted sum");
118 std::vector<gretl::StateBase> inputs;
119 inputs.insert(inputs.end(), weighted_fields.begin(), weighted_fields.end());
120 inputs.insert(inputs.end(), differentiable_weights.begin(), differentiable_weights.end());
121 inputs.insert(inputs.end(), differentiably_weighted_fields.begin(), differentiably_weighted_fields.end());
123 auto x = weights.size() ? weighted_fields[0] : differentiably_weighted_fields[0];
124 auto z = x.clone(inputs);
126 z.set_eval([weights, differentiable_scale_factors](
const gretl::UpstreamStates& upstreams,
127 gretl::DownstreamState& downstream) {
128 size_t num_weights = weights.size();
129 size_t num_diffable_weights = (upstreams.size() - num_weights) / 2;
131 auto X = weights.size() ? upstreams[0].get<
FEFieldPtr>()
132 : upstreams[num_weights + num_diffable_weights].get<FEFieldPtr>();
134 FEFieldPtr Z = std::make_shared<FiniteElementState>(X->space(),
"weighted_sum");
136 if (num_weights > 0) {
137 double weightOld = weights[0];
139 if (num_weights == 1) {
140 Z->Set(weightOld, *vecOld);
142 for (
size_t i = 1; i < num_weights; ++i) {
143 double weightNew = weights[i];
144 add(weightOld, *vecOld, weightNew, *upstreams[i].get<FEFieldPtr>(), *Z);
150 if (num_diffable_weights > 0) {
151 size_t start_index = 0;
152 double weightOld = 1.0;
155 if (weights.size() == 0) {
157 double scale = differentiable_scale_factors[0];
158 weightOld = scale * upstreams[num_weights].get<
double>();
159 vecOld = upstreams[num_weights + num_diffable_weights].get<
FEFieldPtr>();
160 if (num_diffable_weights == 1) {
161 Z->Set(weightOld, *vecOld);
165 for (
size_t i = start_index; i < num_diffable_weights; ++i) {
166 double scale = differentiable_scale_factors[i];
167 double weightNew = scale * upstreams[num_weights + i].get<
double>();
168 add(weightOld, *vecOld, weightNew, *upstreams[num_weights + num_diffable_weights + i].get<FEFieldPtr>(), *Z);
177 z.set_vjp([weights, differentiable_scale_factors](gretl::UpstreamStates& upstreams,
178 const gretl::DownstreamState& downstream) {
179 size_t num_weights = weights.size();
180 size_t num_diffable_weights = (upstreams.size() - num_weights) / 2;
184 for (
size_t i = 0; i < num_weights; ++i) {
186 double weight = weights[i];
187 add(*V_dual, weight, *Z_dual, *V_dual);
190 for (
size_t i = 0; i < num_diffable_weights; ++i) {
191 double& weight_dual = upstreams[num_weights + i].get_dual<double,
double>();
193 double scale = differentiable_scale_factors[i];
194 double weight = scale * upstreams[num_weights + i].get<
double>();
196 add(*V_dual, weight, *Z_dual, *V_dual);
220 const size_t num_initial_weights =
weights_.size();
223 for (
size_t n = num_initial_weights; n <
weights_.size(); ++n) {
249 return zero -= *
this;
254 return weighted_sum(weights_, weighted_fields_, differentiable_weights_, differentiably_weighted_fields_,
255 differentiable_scale_factors_);
Accelerator functionality.
gretl::State< FEDualPtr, FEFieldPtr > ReactionState
typedef
std::shared_ptr< FiniteElementState > FEFieldPtr
typedef
FieldStateWeightedSum operator+(const FieldState &x, const FieldState &y)
add two FieldStates
gretl::State< FEFieldPtr, FEDualPtr > FieldState
typedef
FieldState axpby(double a, const FieldState &x, double b, const FieldState &y)
gretl-function to compute a*x + b*y
std::shared_ptr< FiniteElementDual > FEDualPtr
typedef
FieldState square(const FieldState &state)
gretl-function to square (x^2) every component of the Field
gretl::State< double > innerProduct(const FieldState &a, const FieldState &b)
gretl-function to compute the inner product of a and b (equal to the squared l2-norm when a and b are the same field)
FieldStateWeightedSum operator*(double a, const FieldState &b)
multiply a scalar by a FieldState to get a temporary FieldStateWeightedSum, which can be cast back to a FieldState
FieldState zeroCopy(const FieldState &x)
gretl-function to make a deep-copy of a FieldState and initialize it to 0.
FieldStateWeightedSum operator-(const FieldState &x, const FieldState &y)
subtract two FieldStates
FieldState weighted_sum(const std::vector< double > &weights, const std::vector< FieldState > &weighted_fields, const std::vector< gretl::State< double >> &differentiable_weights, const std::vector< FieldState > &differentiably_weighted_fields, const std::vector< double > &differentiable_scale_factors)
compute the differentiable weighted sum of fields, weighted by both double weights,...
temporary object to register the multiplication of a gretl::State<double> with a FieldState....
FieldStateWeightedSum operator-() const
negate
FieldStateWeightedSum & operator*=(double weight)
multiply by a fixed scalar
std::vector< double > weights_
non-differentiable weights
FieldStateWeightedSum & operator+=(const FieldStateWeightedSum &b)
add another weighted sum in place
std::vector< FieldState > weighted_fields_
fields to weight by non-differentiable weights
std::vector< double > differentiable_scale_factors_
fixed (non-differentiable) scale factors multiplied into each differentiable weight
std::vector< FieldState > differentiably_weighted_fields_
fields to weight by differentiable weights
FieldStateWeightedSum & operator-=(const FieldStateWeightedSum &b)
subtract another weighted sum in place
std::vector< gretl::State< double > > differentiable_weights_
differentiable weights
A sentinel struct for eliding no-op tensor operations.