MOTION  0.01
Framework for mixed-protocol multi-party computation
arithmetic_gmw_gate.h
Go to the documentation of this file.
1 // MIT License
2 //
3 // Copyright (c) 2019 Oleksandr Tkachenko, Lennart Braun
4 // Cryptography and Privacy Engineering Group (ENCRYPTO)
5 // TU Darmstadt, Germany
6 //
7 // Permission is hereby granted, free of charge, to any person obtaining a copy
8 // of this software and associated documentation files (the "Software"), to deal
9 // in the Software without restriction, including without limitation the rights
10 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 // copies of the Software, and to permit persons to whom the Software is
12 // furnished to do so, subject to the following conditions:
13 //
14 // The above copyright notice and this permission notice shall be included in all
15 // copies or substantial portions of the Software.
16 //
17 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 // SOFTWARE.
24 
25 #pragma once
26 
27 #include <fmt/format.h>
28 
29 #include "arithmetic_gmw_share.h"
31 #include "base/register.h"
39 #include "protocols/gate.h"
41 #include "utility/helpers.h"
42 #include "utility/logger.h"
44 
46 
47 //
48 //     |     <- one unsigned integer input
49 //  --------
50 //  |      |
51 //  | Gate |
52 //  |      |
53 //  --------
54 //     |     <- one SharePointer(new arithmetic_gmw::Share) output
55 //
56 
// Arithmetic GMW input gate over an unsigned integer type T.
// In EvaluateOnline the input owner sets its share to "input minus the sum of the values it draws
// for every other party", while every other party uses the values drawn from its generator for
// the owner as its share.
// NOTE(review): this listing embeds the original line numbers and skips some of them (e.g. 65,
// 71, 74, 80, 86, 90, 98-100, 166, 174), so a few statements and function headers are not
// visible here; the comments below describe only what is visible.
57 template <typename T, typename = std::enable_if_t<std::is_unsigned_v<T>>>
58 class InputGate final : public motion::InputGate {
59  using Base = motion::InputGate;
60 
61  public:
    // Copies the elements of `input` into input_ and records `input_owner` as the owner id.
62  InputGate(std::span<const T> input, std::size_t input_owner, Backend& backend)
63  : Base(backend), input_(std::vector(input.begin(), input.end())) {
64  input_owner_id_ = input_owner;
66  }
67 
    // Same as above, but takes the input vector by rvalue reference and moves it into input_.
68  InputGate(std::vector<T>&& input, std::size_t input_owner, Backend& backend)
69  : Base(backend), input_(std::move(input)) {
70  input_owner_id_ = input_owner;
72  }
73 
    // NOTE(review): the header of the function that owns the statements below (listing line 74)
    // is not visible; the cross-reference index at the end of this listing names it
    // `void InitializationHelper()`.
    // bool is rejected explicitly as element type.
75  static_assert(!std::is_same_v<T, bool>);
76 
    // Ask the register for an arithmetic sharing id, passing the number of input values.
78  arithmetic_sharing_id_ = GetRegister().NextArithmeticSharingId(input_.size());
79  if constexpr (kVerboseDebug) {
81  fmt::format("Created an arithmetic_gmw::InputGate with global id {}", gate_id_));
82  }
    // The gate has exactly one output wire, constructed from input_ and backend_.
83  output_wires_ = {std::static_pointer_cast<motion::Wire>(
84  std::make_shared<arithmetic_gmw::Wire<T>>(input_, backend_))};
    // NOTE(review): the loop body (listing line 86) is not visible here.
85  for (auto& w : output_wires_) {
87  }
88 
89  auto gate_info = fmt::format("uint{}_t type, gate id {}, owner {}", sizeof(T) * 8, gate_id_,
91  GetLogger().LogDebug(fmt::format(
92  "Allocate an arithmetic_gmw::InputGate with following properties: {}", gate_info));
93  }
94 
95  ~InputGate() final = default;
96 
    // NOTE(review): the body of EvaluateSetup (listing lines 98-100) is not visible here.
97  void EvaluateSetup() final override {
101  }
102 
103  // non-interactive input sharing based on distributed in advance randomness
104  // seeds
105  void EvaluateOnline() final override {
106  WaitSetup();
107  assert(setup_is_ready_);
108 
109  auto& communication_layer = GetCommunicationLayer();
110  auto my_id = communication_layer.GetMyId();
111  auto number_of_parties = communication_layer.GetNumberOfParties();
112 
113  std::vector<T> result;
114 
    // Owner branch: accumulate the values drawn for all other parties, then subtract from the input.
115  if (static_cast<std::size_t>(input_owner_id_) == my_id) {
116  result.resize(input_.size());
117  auto log_string = std::string("");
118  for (auto party_id = 0u; party_id < number_of_parties; ++party_id) {
119  if (party_id == my_id) {
120  continue;
121  }
    // Draw input_.size() values of type T for this sharing id from "my" generator for party_id.
122  auto& randomness_generator = GetBaseProvider().GetMyRandomnessGenerator(party_id);
123  auto randomness =
124  randomness_generator.template GetUnsigned<T>(arithmetic_sharing_id_, input_.size());
125  if constexpr (kVerboseDebug) {
    // NOTE(review): .at(0) throws if the input is empty; only reached in verbose-debug builds.
126  log_string.append(fmt::format("id#{}:{} ", party_id, randomness.at(0)));
127  }
128  for (auto j = 0u; j < result.size(); ++j) {
129  result.at(j) += randomness.at(j);
130  }
131  }
    // Owner's share = input - accumulated randomness (unsigned arithmetic, so it wraps around).
132  for (auto j = 0u; j < result.size(); ++j) {
133  result.at(j) = input_.at(j) - result.at(j);
134  }
135 
136  if constexpr (kVerboseDebug) {
137  auto s = fmt::format(
138  "My (id#{}) arithmetic input sharing for gate#{}, my input: {}, my "
139  "share: {}, expected shares of other parties: {}",
140  input_owner_id_, gate_id_, input_.at(0), result.at(0), log_string);
141  GetLogger().LogTrace(s);
142  }
143  } else {
    // Non-owner branch: the share is what "their" generator for the owner yields for this id.
144  auto& randomness_generator = GetBaseProvider().GetTheirRandomnessGenerator(input_owner_id_);
145  result = randomness_generator.template GetUnsigned<T>(arithmetic_sharing_id_, input_.size());
146 
147  if constexpr (kVerboseDebug) {
148  auto s = fmt::format(
149  "Arithmetic input sharing (gate#{}) of Party's#{} input, got a share "
150  "{} from the seed",
151  gate_id_, input_owner_id_, result.at(0));
152  GetLogger().LogTrace(s);
153  }
154  }
    // Store the computed share vector in the single output wire.
155  auto my_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(output_wires_.at(0));
156  assert(my_wire);
157  my_wire->GetMutableValues() = std::move(result);
158 
159  GetLogger().LogDebug(fmt::format("Evaluated arithmetic_gmw::InputGate with id#{}", gate_id_));
    // NOTE(review): listing lines 160-161 are not visible here.
162  }
163 
    // NOTE(review): function header (listing line 166) not visible; per the cross-reference index
    // it is `arithmetic_gmw::SharePointer<T> GetOutputAsArithmeticShare()`.
    // Wraps the output wire in a newly allocated arithmetic_gmw::Share<T>.
164  // perhaps, we should return a copy of the pointer and not move it for the
165  // case we need it multiple times
167  auto arithmetic_wire = GetOutputArithmeticWire();
168  auto result = std::make_shared<arithmetic_gmw::Share<T>>(arithmetic_wire);
169  return result;
170  }
171 
    // NOTE(review): function header (listing line 174) not visible; per the cross-reference index
    // it is `arithmetic_gmw::WirePointer<T> GetOutputArithmeticWire()`.
    // Returns output wire 0 downcast to arithmetic_gmw::Wire<T>.
172  // perhaps, we should return a copy of the pointer and not move it for the
173  // case we need it multiple times
175  auto result = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(output_wires_.at(0));
176  assert(result);
177  return result;
178  }
179 
180  private:
    // Id obtained from the register; passed to the randomness generators in EvaluateOnline.
181  std::size_t arithmetic_sharing_id_;
182 
    // Plaintext input values as given to the constructor.
183  std::vector<T> input_;
184 };
185 
// Sentinel output-owner value: OutputGate treats an owner equal to kAll as "every party obtains
// the output" (it is the default owner argument and selects the broadcast branch).
// The value is the maximum of std::int64_t stored in a std::size_t.
186 constexpr std::size_t kAll = std::numeric_limits<std::int64_t>::max();
187 
// Arithmetic GMW output gate over an unsigned integer type T.
// Parties that do not obtain the output send their share to the output owner; parties that do
// obtain it collect one share vector per party and add them up into the output wire.
// NOTE(review): this listing embeds the original line numbers and skips some of them (e.g. 195,
// 228-230, 240, 274, 351, 360, 363-364), so some statements and headers are not visible here.
188 template <typename T, typename = std::enable_if_t<std::is_unsigned_v<T>>>
189 class OutputGate final : public motion::OutputGate {
190  using Base = motion::OutputGate;
191 
192  public:
    // NOTE(review): function header (listing line 195) not visible; per the cross-reference index
    // it is `arithmetic_gmw::SharePointer<T> GetOutputAsArithmeticShare()`.
    // Wraps output wire 0 in a newly allocated arithmetic_gmw::Share<T>.
193  // perhaps, we should return a copy of the pointer and not move it for the
194  // case we need it multiple times
196  auto arithmetic_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(output_wires_.at(0));
197  assert(arithmetic_wire);
198  auto result = std::make_shared<arithmetic_gmw::Share<T>>(arithmetic_wire);
199  return result;
200  }
201 
    // Builds an output gate for `parent`; `output_owner` is a party id or kAll (the default).
    // Throws std::runtime_error if the parent wire is not of protocol kArithmeticGmw or if
    // `output_owner` is neither a valid party id nor kAll.
    // NOTE(review): `parent` is dereferenced in the initializer list before assert(parent) runs,
    // so the assert cannot catch a null parent.
202  OutputGate(const arithmetic_gmw::WirePointer<T>& parent, std::size_t output_owner = kAll)
203  : Base(parent->GetBackend()) {
204  assert(parent);
205 
206  if (parent->GetProtocol() != MpcProtocol::kArithmeticGmw) {
207  auto sharing_type = to_string(parent->GetProtocol());
208  throw(
209  std::runtime_error((fmt::format("Arithmetic output gate expects an arithmetic share, "
210  "got a share of type {}",
211  sharing_type))));
212  }
213 
214  parent_ = {parent};
215 
216  // values we need repeatedly
217  auto& communication_layer = GetCommunicationLayer();
218  auto my_id = communication_layer.GetMyId();
219  auto number_of_parties = communication_layer.GetNumberOfParties();
220 
221  if (static_cast<std::size_t>(output_owner) >= number_of_parties &&
222  static_cast<std::size_t>(output_owner) != kAll) {
223  throw std::runtime_error(
224  fmt::format("Invalid output owner: {} of {}", output_owner, number_of_parties));
225  }
226 
227  output_owner_ = output_owner;
    // This party obtains the output if it is the owner or if the owner is kAll.
231  is_my_output_ = my_id == static_cast<std::size_t>(output_owner_) ||
232  static_cast<std::size_t>(output_owner_) == kAll;
233 
    // Register the dependency between this gate and its parent wire in both directions.
234  RegisterWaitingFor(parent_.at(0)->GetWireId());
235  parent_.at(0)->RegisterWaitingGate(gate_id_);
236 
    // One output wire with as many SIMD values as the parent.
237  {
238  auto w = std::static_pointer_cast<motion::Wire>(
239  std::make_shared<arithmetic_gmw::Wire<T>>(backend_, parent->GetNumberOfSimdValues()));
241  output_wires_ = {std::move(w)};
242  }
243 
244  // Tell the DataStorages that we want to receive OutputMessages from the
245  // other parties.
246  if (is_my_output_) {
247  auto& base_provider = GetBaseProvider();
248  output_message_futures_ = base_provider.RegisterForOutputMessages(gate_id_);
249  }
250 
251  if constexpr (kDebug) {
252  auto gate_info = fmt::format("uint{}_t type, gate id {}, owner {}", sizeof(T) * 8, gate_id_,
253  output_owner_);
254  GetLogger().LogDebug(fmt::format(
255  "Allocate an arithmetic_gmw::OutputGate with following properties: {}", gate_info));
256  }
257  }
258 
    // Delegates to the wire-based constructor using the share's arithmetic wire.
    // NOTE(review): `parent` is dereferenced in the delegating call before assert(parent) runs.
259  OutputGate(const arithmetic_gmw::SharePointer<T>& parent, std::size_t output_owner)
260  : OutputGate(parent->GetArithmeticWire(), output_owner) {
261  assert(parent);
262  }
263 
    // Delegates after downcasting the generic share to arithmetic_gmw::Share<T>.
    // NOTE(review): the result of dynamic_pointer_cast is not checked; the assert tests the
    // original `parent`, not the cast result, and only after the delegation has completed.
264  OutputGate(const motion::SharePointer& parent, std::size_t output_owner)
265  : OutputGate(std::dynamic_pointer_cast<const arithmetic_gmw::Share<T>>(parent),
266  output_owner) {
267  assert(parent);
268  }
269 
270  ~OutputGate() final = default;
271 
    // NOTE(review): listing line 274 is not visible here.
272  void EvaluateSetup() final override {
273  SetSetupIsReady();
275  }
276 
    // Sends the local share where required and, if this party obtains the output, reconstructs
    // it by summing the share vectors of all parties into output wire 0.
277  void EvaluateOnline() final override {
278  // setup needs to be done first
279  WaitSetup();
280  assert(setup_is_ready_);
281 
282  // data we need repeatedly
283  auto& communication_layer = GetCommunicationLayer();
284  auto my_id = communication_layer.GetMyId();
285  auto number_of_parties = communication_layer.GetNumberOfParties();
286 
287  // note that arithmetic gates have only a single wire
288  auto arithmetic_wire = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(parent_.at(0));
289  assert(arithmetic_wire);
290  // wait for parent wire to obtain a value
291  arithmetic_wire->GetIsReadyCondition().Wait();
292  // initialize output with local share
293  auto output = arithmetic_wire->GetValues();
294 
295  // we need to send shares to one other party:
296  if (!is_my_output_) {
297  auto payload = ToByteVector(output);
298  auto output_message = motion::communication::BuildOutputMessage(gate_id_, payload);
299  communication_layer.SendMessage(output_owner_, std::move(output_message));
300  }
301  // we need to send shares to all other parties:
302  else if (output_owner_ == kAll) {
303  auto payload = ToByteVector(output);
304  auto output_message = motion::communication::BuildOutputMessage(gate_id_, payload);
305  communication_layer.BroadcastMessage(std::move(output_message));
306  }
307 
308  // we receive shares from other parties
309  if (is_my_output_) {
310  // collect shares from all parties
311  std::vector<std::vector<T>> shared_outputs;
312  shared_outputs.reserve(number_of_parties);
313 
    // shared_outputs ends up indexed by party id: the local share is inserted at position my_id.
314  for (std::size_t i = 0; i < number_of_parties; ++i) {
315  if (i == my_id) {
316  shared_outputs.push_back(output);
317  continue;
318  }
    // Block on the future registered for party i, then parse the received output message.
319  const auto output_message = output_message_futures_.at(i).get();
320  auto message = communication::GetMessage(output_message.data());
321  auto output_message_pointer = communication::GetOutputMessage(message->payload()->data());
322  assert(output_message_pointer);
323  assert(output_message_pointer->wires()->size() == 1);
324 
325  shared_outputs.push_back(
326  FromByteVector<T>(*output_message_pointer->wires()->Get(0)->payload()));
327  assert(shared_outputs[i].size() == parent_[0]->GetNumberOfSimdValues());
328  }
329 
330  // reconstruct the shared value
331  if constexpr (kVerboseDebug) {
332  // we need to copy since we have to keep shared_outputs for the debug output below
333  output = AddVectors(shared_outputs);
334  } else {
335  // we can move
336  output = AddVectors(std::move(shared_outputs));
337  }
338 
339  // set the value of the output wire
340  auto arithmetic_output_wire =
341  std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(output_wires_.at(0));
342  assert(arithmetic_output_wire);
343  arithmetic_output_wire->GetMutableValues() = output;
344 
345  if constexpr (kVerboseDebug) {
346  std::string shares{""};
347  for (auto i = 0u; i < number_of_parties; ++i) {
348  shares.append(fmt::format("id#{}:{} ", i, to_string(shared_outputs.at(i))));
349  }
350  auto result = to_string(output);
352  fmt::format("Received output shares: {} from other parties, "
353  "reconstructed result is {}",
354  shares, result));
355  }
356  }
357 
358  // we are done with this gate
359  if constexpr (kDebug) {
361  fmt::format("Evaluated arithmetic_gmw::OutputGate with id#{}", gate_id_));
362  }
    // NOTE(review): listing lines 363-364 are not visible here.
365  }
366 
367  protected:
368  // indicates whether this party obtains the output
369  bool is_my_output_ = false;
370 
    // Futures returned by RegisterForOutputMessages; read per party id in EvaluateOnline.
371  std::vector<motion::ReusableFiberFuture<std::vector<std::uint8_t>>> output_message_futures_;
372 
    // NOTE(review): not referenced anywhere in the visible code of this class - confirm it is needed.
373  std::mutex m;
374 };
375 
// Arithmetic GMW addition gate over an unsigned integer type T: the output wire receives
// RestrictAddVectors(a, b) of the two parent wires' values, computed locally.
// NOTE(review): this listing embeds the original line numbers and skips some of them (e.g. 379,
// 386-389, 400, 415, 437, 439-440, 445), so some statements and headers are not visible here.
376 template <typename T, typename = std::enable_if_t<std::is_unsigned_v<T>>>
377 class AdditionGate final : public motion::TwoGate {
378  public:
    // NOTE(review): the constructor header (listing line 379) is not visible; the body uses two
    // wire pointers named `a` and `b`.
380  : TwoGate(a->GetBackend()) {
381  parent_a_ = {std::static_pointer_cast<motion::Wire>(a)};
382  parent_b_ = {std::static_pointer_cast<motion::Wire>(b)};
383 
    // Both parents must carry the same number of SIMD values (checked in debug builds only).
384  assert(parent_a_.at(0)->GetNumberOfSimdValues() == parent_b_.at(0)->GetNumberOfSimdValues());
385 
388 
390 
    // Register the dependency between this gate and both parent wires in both directions.
391  RegisterWaitingFor(parent_a_.at(0)->GetWireId());
392  parent_a_.at(0)->RegisterWaitingGate(gate_id_);
393 
394  RegisterWaitingFor(parent_b_.at(0)->GetWireId());
395  parent_b_.at(0)->RegisterWaitingGate(gate_id_);
396 
    // One output wire with as many SIMD values as parent a.
397  {
398  auto w = std::static_pointer_cast<motion::Wire>(
399  std::make_shared<arithmetic_gmw::Wire<T>>(backend_, a->GetNumberOfSimdValues()));
401  output_wires_ = {std::move(w)};
402  }
403 
404  auto gate_info =
405  fmt::format("uint{}_t type, gate id {}, parents: {}, {}", sizeof(T) * 8, gate_id_,
406  parent_a_.at(0)->GetWireId(), parent_b_.at(0)->GetWireId());
407  GetLogger().LogDebug(fmt::format(
408  "Created an arithmetic_gmw::AdditionGate with following properties: {}", gate_info));
409  }
410 
411  ~AdditionGate() final = default;
412 
413  void EvaluateSetup() final override {
414  SetSetupIsReady();
416  }
417 
    // Waits for both parents, then writes their combined values to the output wire.
418  void EvaluateOnline() final override {
419  WaitSetup();
420  assert(setup_is_ready_);
421 
422  parent_a_.at(0)->GetIsReadyCondition().Wait();
423  parent_b_.at(0)->GetIsReadyCondition().Wait();
424 
425  auto wire_a = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(parent_a_.at(0));
426  auto wire_b = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(parent_b_.at(0));
427 
428  assert(wire_a);
429  assert(wire_b);
430 
    // presumably an element-wise sum; RestrictAddVectors is defined outside this listing - confirm.
431  std::vector<T> output;
432  output = RestrictAddVectors(wire_a->GetValues(), wire_b->GetValues());
433 
    // NOTE(review): unlike the parents, this downcast result is not asserted before use.
434  auto arithmetic_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(output_wires_.at(0));
435  arithmetic_wire->GetMutableValues() = std::move(output);
436 
438  fmt::format("Evaluated arithmetic_gmw::AdditionGate with id#{}", gate_id_));
441  }
442 
    // NOTE(review): function header (listing line 445) not visible; per the cross-reference index
    // it is `arithmetic_gmw::SharePointer<T> GetOutputAsArithmeticShare()`.
    // Wraps output wire 0 in a newly allocated arithmetic_gmw::Share<T>.
443  // perhaps, we should return a copy of the pointer and not move it for the
444  // case we need it multiple times
446  auto arithmetic_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(output_wires_.at(0));
447  assert(arithmetic_wire);
448  auto result = std::make_shared<arithmetic_gmw::Share<T>>(arithmetic_wire);
449  return result;
450  }
451 
452  AdditionGate() = delete;
453 
    // NOTE(review): this deletes construction from a Gate&; it does not have the signature of a
    // copy constructor of AdditionGate - confirm that is the intent.
454  AdditionGate(Gate&) = delete;
455 };
456 
// Arithmetic GMW subtraction gate over an unsigned integer type T: the output wire receives
// SubVectors(a, b) of the two parent wires' values, computed locally.
// NOTE(review): this listing embeds the original line numbers and skips some of them (e.g. 460,
// 467-470, 481, 496, 517, 519-520, 525), so some statements and headers are not visible here.
457 template <typename T, typename = std::enable_if_t<std::is_unsigned_v<T>>>
458 class SubtractionGate final : public motion::TwoGate {
459  public:
    // NOTE(review): the constructor header (listing line 460) is not visible; per the
    // cross-reference index it is
    // `SubtractionGate(const arithmetic_gmw::WirePointer<T>& a, const arithmetic_gmw::WirePointer<T>& b)`.
461  : TwoGate(a->GetBackend()) {
462  parent_a_ = {std::static_pointer_cast<motion::Wire>(a)};
463  parent_b_ = {std::static_pointer_cast<motion::Wire>(b)};
464 
    // Both parents must carry the same number of SIMD values (checked in debug builds only).
465  assert(parent_a_.at(0)->GetNumberOfSimdValues() == parent_b_.at(0)->GetNumberOfSimdValues());
466 
469 
471 
    // Register the dependency between this gate and both parent wires in both directions.
472  RegisterWaitingFor(parent_a_.at(0)->GetWireId());
473  parent_a_.at(0)->RegisterWaitingGate(gate_id_);
474 
475  RegisterWaitingFor(parent_b_.at(0)->GetWireId());
476  parent_b_.at(0)->RegisterWaitingGate(gate_id_);
477 
    // One output wire with as many SIMD values as parent a.
478  {
479  auto w = std::static_pointer_cast<motion::Wire>(
480  std::make_shared<arithmetic_gmw::Wire<T>>(backend_, a->GetNumberOfSimdValues()));
482  output_wires_ = {std::move(w)};
483  }
484 
485  auto gate_info =
486  fmt::format("uint{}_t type, gate id {}, parents: {}, {}", sizeof(T) * 8, gate_id_,
487  parent_a_.at(0)->GetWireId(), parent_b_.at(0)->GetWireId());
488  GetLogger().LogDebug(fmt::format(
489  "Created an arithmetic_gmw::SubtractionGate with following properties: {}", gate_info));
490  }
491 
492  ~SubtractionGate() final = default;
493 
494  void EvaluateSetup() final override {
495  SetSetupIsReady();
497  }
498 
    // Waits for both parents, then writes SubVectors(a, b) to the output wire.
499  void EvaluateOnline() final override {
500  WaitSetup();
501  assert(setup_is_ready_);
502 
503  parent_a_.at(0)->GetIsReadyCondition().Wait();
504  parent_b_.at(0)->GetIsReadyCondition().Wait();
505 
506  auto wire_a = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(parent_a_.at(0));
507  auto wire_b = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(parent_b_.at(0));
508 
509  assert(wire_a);
510  assert(wire_b);
511 
    // presumably an element-wise a - b; SubVectors is defined outside this listing - confirm.
512  std::vector<T> output = SubVectors(wire_a->GetValues(), wire_b->GetValues());
513 
    // NOTE(review): unlike the parents, this downcast result is not asserted before use.
514  auto arithmetic_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(output_wires_.at(0));
515  arithmetic_wire->GetMutableValues() = std::move(output);
516 
518  fmt::format("Evaluated arithmetic_gmw::SubtractionGate with id#{}", gate_id_));
521  }
522 
    // NOTE(review): function header (listing line 525) not visible; per the cross-reference index
    // it is `arithmetic_gmw::SharePointer<T> GetOutputAsArithmeticShare()`.
    // Wraps output wire 0 in a newly allocated arithmetic_gmw::Share<T>.
523  // perhaps, we should return a copy of the pointer and not move it for the
524  // case we need it multiple times
526  auto arithmetic_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(output_wires_.at(0));
527  assert(arithmetic_wire);
528  auto result = std::make_shared<arithmetic_gmw::Share<T>>(arithmetic_wire);
529  return result;
530  }
531 
532  SubtractionGate() = delete;
533 
    // NOTE(review): this deletes construction from a Gate&; it does not have the signature of a
    // copy constructor of SubtractionGate - confirm that is the intent.
534  SubtractionGate(Gate&) = delete;
535 };
536 
// Arithmetic GMW multiplication gate over an unsigned integer type T.
// Uses per-SIMD-value triples (a, b, c) obtained from the MT provider: the wires d = a + x and
// e = b + y are opened through two internal OutputGates (default owner kAll), and the output
// share is c + d * y_share + e * x_share, with one party additionally subtracting e * d.
// NOTE(review): this listing embeds the original line numbers and skips some of them (e.g.
// 540-541, 548-549, 552, 554, 562, 573, 591, 669, 671-672, 677, 689), so some statements,
// headers and member declarations (d_, e_) are not visible here.
537 template <typename T, typename = std::enable_if_t<std::is_unsigned_v<T>>>
538 class MultiplicationGate final : public motion::TwoGate {
539  public:
    // NOTE(review): the constructor header (listing lines 540-541) is not visible; the body uses
    // two wire pointers named `a` and `b`.
542  : TwoGate(a->GetBackend()) {
543  parent_a_ = {std::static_pointer_cast<motion::Wire>(a)};
544  parent_b_ = {std::static_pointer_cast<motion::Wire>(b)};
545 
    // Both parents must carry the same number of SIMD values (checked in debug builds only).
546  assert(parent_a_.at(0)->GetNumberOfSimdValues() == parent_b_.at(0)->GetNumberOfSimdValues());
547 
550 
    // Internal wires holding the masked inputs; they are filled in EvaluateOnline.
551  d_ = std::make_shared<arithmetic_gmw::Wire<T>>(backend_, a->GetNumberOfSimdValues());
553  e_ = std::make_shared<arithmetic_gmw::Wire<T>>(backend_, a->GetNumberOfSimdValues());
555 
    // Output gates with the default owner (kAll), i.e. every party obtains the opened d and e.
556  d_output_ = std::make_shared<OutputGate<T>>(d_);
557  e_output_ = std::make_shared<OutputGate<T>>(e_);
558 
559  GetRegister().RegisterNextGate(d_output_);
560  GetRegister().RegisterNextGate(e_output_);
561 
563 
    // Register the dependency between this gate and both parent wires in both directions.
564  RegisterWaitingFor(parent_a_.at(0)->GetWireId());
565  parent_a_.at(0)->RegisterWaitingGate(gate_id_);
566 
567  RegisterWaitingFor(parent_b_.at(0)->GetWireId());
568  parent_b_.at(0)->RegisterWaitingGate(gate_id_);
569 
    // One output wire with as many SIMD values as parent a.
570  {
571  auto w = std::static_pointer_cast<motion::Wire>(
572  std::make_shared<arithmetic_gmw::Wire<T>>(backend_, a->GetNumberOfSimdValues()));
574  output_wires_ = {std::move(w)};
575  }
576 
    // Request one triple per SIMD value; the returned offset is used to index the triple vectors.
577  number_of_mts_ = parent_a_.at(0)->GetNumberOfSimdValues();
578  mt_offset_ = GetMtProvider().template RequestArithmeticMts<T>(number_of_mts_);
579 
580  auto gate_info =
581  fmt::format("uint{}_t type, gate id {}, parents: {}, {}", sizeof(T) * 8, gate_id_,
582  parent_a_.at(0)->GetWireId(), parent_b_.at(0)->GetWireId());
583  GetLogger().LogDebug(fmt::format(
584  "Created an arithmetic_gmw::MultiplicationGate with following properties: {}", gate_info));
585  }
586 
587  ~MultiplicationGate() final = default;
588 
589  void EvaluateSetup() final override {
590  SetSetupIsReady();
592  }
593 
    // Computes this party's share of the product; see the class comment for the formula.
594  void EvaluateOnline() final override {
595  WaitSetup();
596  assert(setup_is_ready_);
597  parent_a_.at(0)->GetIsReadyCondition().Wait();
598  parent_b_.at(0)->GetIsReadyCondition().Wait();
599 
    // Wait until the MT provider has finished before reading the triples.
600  auto& mt_provider = GetMtProvider();
601  mt_provider.WaitFinished();
602  const auto& mts = mt_provider.template GetIntegerAll<T>();
603  {
    // d_ := a[mt_offset_ ...] + x, element-wise, then mark the wire as finished.
604  const auto x = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(parent_a_.at(0));
605  assert(x);
606  d_->GetMutableValues() = std::vector<T>(
607  mts.a.begin() + mt_offset_, mts.a.begin() + mt_offset_ + x->GetNumberOfSimdValues());
608  T* __restrict__ d_v = d_->GetMutableValues().data();
609  const T* __restrict__ x_v = x->GetValues().data();
610  const auto number_of_simd_values{x->GetNumberOfSimdValues()};
611 
612  std::transform(x_v, x_v + number_of_simd_values, d_v, d_v,
613  [](const T& a, const T& b) { return a + b; });
614  d_->SetOnlineFinished();
615 
    // e_ := b[mt_offset_ ...] + y, element-wise, then mark the wire as finished.
616  const auto y = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(parent_b_.at(0));
617  assert(y);
618  e_->GetMutableValues() = std::vector<T>(
619  mts.b.begin() + mt_offset_, mts.b.begin() + mt_offset_ + x->GetNumberOfSimdValues());
620  T* __restrict__ e_v = e_->GetMutableValues().data();
621  const T* __restrict__ y_v = y->GetValues().data();
622  std::transform(y_v, y_v + number_of_simd_values, e_v, e_v,
623  [](const T& a, const T& b) { return a + b; });
624  e_->SetOnlineFinished();
625  }
626 
    // Wait for the two internal output gates to finish opening d and e.
627  d_output_->WaitOnline();
628  e_output_->WaitOnline();
629 
630  const auto& d_clear = d_output_->GetOutputWires().at(0);
631  const auto& e_clear = e_output_->GetOutputWires().at(0);
632 
633  d_clear->GetIsReadyCondition().Wait();
634  e_clear->GetIsReadyCondition().Wait();
635 
636  const auto d_w = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(d_clear);
637  const auto x_i_w = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(parent_a_.at(0));
638  const auto e_w = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(e_clear);
639  const auto y_i_w = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(parent_b_.at(0));
640 
641  assert(d_w);
642  assert(x_i_w);
643  assert(e_w);
644  assert(y_i_w);
645 
    // Start the output from c[mt_offset_ ...]; the correction terms are added in the loops below.
646  auto output = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(output_wires_.at(0));
647  assert(output);
648  output->GetMutableValues() =
649  std::vector<T>(mts.c.begin() + mt_offset_,
650  mts.c.begin() + mt_offset_ + parent_a_.at(0)->GetNumberOfSimdValues());
651 
652  const T* __restrict__ d{d_w->GetValues().data()};
653  const T* __restrict__ s_x{x_i_w->GetValues().data()};
654  const T* __restrict__ e{e_w->GetValues().data()};
655  const T* __restrict__ s_y{y_i_w->GetValues().data()};
656  T* __restrict__ output_pointer{output->GetMutableValues().data()};
657 
    // Exactly one party (id == gate_id_ mod number of parties) also subtracts the term e * d.
658  if (GetCommunicationLayer().GetMyId() ==
659  (gate_id_ % GetCommunicationLayer().GetNumberOfParties())) {
660  for (auto i = 0ull; i < output->GetNumberOfSimdValues(); ++i) {
661  output_pointer[i] += (d[i] * s_y[i]) + (e[i] * s_x[i]) - (e[i] * d[i]);
662  }
663  } else {
664  for (auto i = 0ull; i < output->GetNumberOfSimdValues(); ++i) {
665  output_pointer[i] += (d[i] * s_y[i]) + (e[i] * s_x[i]);
666  }
667  }
668 
670  fmt::format("Evaluated arithmetic_gmw::MultiplicationGate with id#{}", gate_id_));
673  }
674 
    // NOTE(review): function header (listing line 677) not visible; presumably
    // `arithmetic_gmw::SharePointer<T> GetOutputAsArithmeticShare()` as in the sibling gates - confirm.
    // Wraps output wire 0 in a newly allocated arithmetic_gmw::Share<T>.
675  // perhaps, we should return a copy of the pointer and not move it for the
676  // case we need it multiple times
678  auto arithmetic_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(output_wires_.at(0));
679  assert(arithmetic_wire);
680  auto result = std::make_shared<arithmetic_gmw::Share<T>>(arithmetic_wire);
681  return result;
682  }
683 
684  MultiplicationGate() = delete;
685 
    // NOTE(review): this deletes construction from a Gate&; it does not have the signature of a
    // copy constructor of MultiplicationGate - confirm that is the intent.
686  MultiplicationGate(Gate&) = delete;
687 
688  private:
    // NOTE(review): the declaration of d_ and e_ (listing line 689) is not visible here.
    // Internal output gates that open d_ and e_.
690  std::shared_ptr<OutputGate<T>> d_output_, e_output_;
691 
    // Number of requested triples and the offset returned by RequestArithmeticMts.
692  std::size_t number_of_mts_, mt_offset_;
693 };
694 
// Arithmetic GMW squaring gate over an unsigned integer type T.
// Uses per-SIMD-value pairs (a, c) obtained from the SP provider: the wire d = a + x is opened
// through an internal OutputGate (default owner kAll), and the output share is
// c + 2 * d * x_share, with one party additionally subtracting d * d.
// NOTE(review): this listing embeds the original line numbers and skips some of them (e.g.
// 701-702, 705, 711, 719, 736, 793-794, 799, 811), so some statements, headers and member
// declarations (d_) are not visible here.
695 template <typename T, typename = std::enable_if_t<std::is_unsigned_v<T>>>
696 class SquareGate final : public motion::OneGate {
697  public:
    // Builds a squaring gate on wire `a`.
    // NOTE(review): the single-argument constructor is not marked explicit.
698  SquareGate(const arithmetic_gmw::WirePointer<T>& a) : OneGate(a->GetBackend()) {
699  parent_ = {std::static_pointer_cast<motion::Wire>(a)};
700 
703 
    // Internal wire holding the masked input; it is filled in EvaluateOnline.
704  d_ = std::make_shared<arithmetic_gmw::Wire<T>>(backend_, a->GetNumberOfSimdValues());
706 
    // Output gate with the default owner (kAll), i.e. every party obtains the opened d.
707  d_output_ = std::make_shared<OutputGate<T>>(d_);
708 
709  GetRegister().RegisterNextGate(d_output_);
710 
712 
    // Register the dependency between this gate and its parent wire in both directions.
713  RegisterWaitingFor(parent_.at(0)->GetWireId());
714  parent_.at(0)->RegisterWaitingGate(gate_id_);
715 
    // One output wire with as many SIMD values as the parent.
716  {
717  auto w = std::static_pointer_cast<motion::Wire>(
718  std::make_shared<arithmetic_gmw::Wire<T>>(backend_, a->GetNumberOfSimdValues()));
720  output_wires_ = {std::move(w)};
721  }
722 
    // Request one pair per SIMD value; the returned offset is used to index the pair vectors.
723  number_of_sps_ = parent_.at(0)->GetNumberOfSimdValues();
724  sp_offset_ = GetSpProvider().template RequestSps<T>(number_of_sps_);
725 
726  auto gate_info = fmt::format("uint{}_t type, gate id {}, parent: {}", sizeof(T) * 8, gate_id_,
727  parent_.at(0)->GetWireId());
728  GetLogger().LogDebug(fmt::format(
729  "Created an arithmetic_gmw::SquareGate with following properties: {}", gate_info));
730  }
731 
732  ~SquareGate() final = default;
733 
734  void EvaluateSetup() final override {
735  SetSetupIsReady();
737  }
738 
    // Computes this party's share of the square; see the class comment for the formula.
739  void EvaluateOnline() final override {
740  WaitSetup();
741  assert(setup_is_ready_);
742  parent_.at(0)->GetIsReadyCondition().Wait();
743 
    // Wait until the SP provider has finished before reading the pairs.
744  auto& sp_provider = GetSpProvider();
745  sp_provider.WaitFinished();
746  const auto& sps = sp_provider.template GetSpsAll<T>();
747  {
    // d_ := a[sp_offset_ ...] + x, element-wise, then mark the wire as finished.
748  const auto x = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(parent_.at(0));
749  assert(x);
750  d_->GetMutableValues() = std::vector<T>(
751  sps.a.begin() + sp_offset_, sps.a.begin() + sp_offset_ + x->GetNumberOfSimdValues());
752  const auto number_of_simd_values{d_->GetNumberOfSimdValues()};
753  T* __restrict__ d_v{d_->GetMutableValues().data()};
754  const T* __restrict__ x_v{x->GetValues().data()};
755  std::transform(x_v, x_v + number_of_simd_values, d_v, d_v,
756  [](const T& a, const T& b) { return a + b; });
757  d_->SetOnlineFinished();
758  }
759 
    // Wait for the internal output gate to finish opening d.
760  d_output_->WaitOnline();
761 
762  const auto& d_clear = d_output_->GetOutputWires().at(0);
763 
764  d_clear->GetIsReadyCondition().Wait();
765 
766  const auto d_w = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(d_clear);
767  const auto x_i_w = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(parent_.at(0));
768 
769  assert(d_w);
770  assert(x_i_w);
771 
    // Start the output from c[sp_offset_ ...]; the correction terms are added in the loops below.
772  auto output = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(output_wires_.at(0));
773  assert(output);
774  output->GetMutableValues() =
775  std::vector<T>(sps.c.begin() + sp_offset_,
776  sps.c.begin() + sp_offset_ + parent_.at(0)->GetNumberOfSimdValues());
777 
778  const T* __restrict__ d{d_w->GetValues().data()};
779  const T* __restrict__ s_x{x_i_w->GetValues().data()};
780  T* __restrict__ output_pointer{output->GetMutableValues().data()};
    // Exactly one party (id == gate_id_ mod number of parties) also subtracts the term d * d.
781  if (GetCommunicationLayer().GetMyId() ==
782  (gate_id_ % GetCommunicationLayer().GetNumberOfParties())) {
783  for (auto i = 0ull; i < output->GetNumberOfSimdValues(); ++i) {
784  output_pointer[i] += 2 * (d[i] * s_x[i]) - (d[i] * d[i]);
785  }
786  } else {
787  for (auto i = 0ull; i < output->GetNumberOfSimdValues(); ++i) {
788  output_pointer[i] += 2 * (d[i] * s_x[i]);
789  }
790  }
791 
792  GetLogger().LogDebug(fmt::format("Evaluated arithmetic_gmw::SquareGate with id#{}", gate_id_));
    // NOTE(review): listing lines 793-794 are not visible here.
795  }
796 
    // NOTE(review): function header (listing line 799) not visible; presumably
    // `arithmetic_gmw::SharePointer<T> GetOutputAsArithmeticShare()` as in the sibling gates - confirm.
    // Wraps output wire 0 in a newly allocated arithmetic_gmw::Share<T>.
797  // perhaps, we should return a copy of the pointer and not move it for the
798  // case we need it multiple times
800  auto arithmetic_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(output_wires_.at(0));
801  assert(arithmetic_wire);
802  auto result = std::make_shared<arithmetic_gmw::Share<T>>(arithmetic_wire);
803  return result;
804  }
805 
806  SquareGate() = delete;
807 
    // NOTE(review): this deletes construction from a Gate&; it does not have the signature of a
    // copy constructor of SquareGate - confirm that is the intent.
808  SquareGate(Gate&) = delete;
809 
810  private:
    // NOTE(review): the declaration of d_ (listing line 811) is not visible here.
    // Internal output gate that opens d_.
812  std::shared_ptr<OutputGate<T>> d_output_;
813 
    // Number of requested pairs and the offset returned by RequestSps.
814  std::size_t number_of_sps_, sp_offset_;
815 };
816 
817 } // namespace encrypto::motion::proto::arithmetic_gmw
encrypto::motion::Logger::LogTrace
void LogTrace(const std::string &message)
Definition: logger.cpp:110
output_message.h
helpers.h
encrypto::motion::proto::arithmetic_gmw::AdditionGate
Definition: arithmetic_gmw_gate.h:377
encrypto::motion::Gate::gate_type_
GateType gate_type_
Definition: gate.h:105
encrypto::motion::proto::arithmetic_gmw::OutputGate::~OutputGate
~OutputGate() final=default
output_message_generated.h
encrypto::motion::proto::arithmetic_gmw::AdditionGate::~AdditionGate
~AdditionGate() final=default
motion_base_provider.h
encrypto::motion::proto::arithmetic_gmw::Share
Definition: arithmetic_gmw_share.h:37
encrypto::motion::proto::arithmetic_gmw::OutputGate::GetOutputAsArithmeticShare
arithmetic_gmw::SharePointer< T > GetOutputAsArithmeticShare()
Definition: arithmetic_gmw_gate.h:195
encrypto::motion::Gate::GetCommunicationLayer
communication::CommunicationLayer & GetCommunicationLayer()
Definition: gate.cpp:92
encrypto::motion::proto::arithmetic_gmw::MultiplicationGate::EvaluateOnline
void EvaluateOnline() final override
Definition: arithmetic_gmw_gate.h:594
encrypto::motion::Gate::output_wires_
std::vector< WirePointer > output_wires_
Definition: gate.h:100
encrypto::motion::proto::arithmetic_gmw::SubtractionGate::EvaluateOnline
void EvaluateOnline() final override
Definition: arithmetic_gmw_gate.h:499
encrypto::motion::TwoGate
Definition: gate.h:218
fiber_condition.h
encrypto::motion::proto::arithmetic_gmw::MultiplicationGate
Definition: arithmetic_gmw_gate.h:538
encrypto::motion::TwoGate::parent_b_
std::vector< WirePointer > parent_b_
Definition: gate.h:221
encrypto::motion::proto::arithmetic_gmw::SquareGate::SquareGate
SquareGate()=delete
encrypto::motion::Gate::GetLogger
Logger & GetLogger()
Definition: gate.cpp:100
encrypto::motion::proto::arithmetic_gmw::OutputGate
Definition: arithmetic_gmw_gate.h:189
encrypto::motion::proto::arithmetic_gmw::SubtractionGate::~SubtractionGate
~SubtractionGate() final=default
encrypto::motion::proto::arithmetic_gmw::SubtractionGate::GetOutputAsArithmeticShare
arithmetic_gmw::SharePointer< T > GetOutputAsArithmeticShare()
Definition: arithmetic_gmw_gate.h:525
encrypto::motion::communication::GetMessage
const encrypto::motion::communication::Message * GetMessage(const void *buf)
Definition: message_generated.h:146
reusable_future.h
encrypto::motion::AddVectors
std::vector< T > AddVectors(const std::vector< T > &a, const std::vector< T > &b)
Adds each element in a and b and returns the result.
Definition: helpers.h:100
encrypto::motion::InputGate
Definition: gate.h:170
encrypto::motion::proto::arithmetic_gmw::AdditionGate::EvaluateSetup
void EvaluateSetup() final override
Definition: arithmetic_gmw_gate.h:413
encrypto::motion::proto::arithmetic_gmw::MultiplicationGate::MultiplicationGate
MultiplicationGate()=delete
encrypto::motion::Gate::GetSpProvider
SpProvider & GetSpProvider()
Definition: gate.cpp:106
encrypto::motion::OutputGate::output_owner_
std::int64_t output_owner_
Definition: gate.h:203
encrypto::motion::proto::arithmetic_gmw::InputGate::InitializationHelper
void InitializationHelper()
Definition: arithmetic_gmw_gate.h:74
encrypto::motion::proto::arithmetic_gmw::InputGate::InputGate
InputGate(std::vector< T > &&input, std::size_t input_owner, Backend &backend)
Definition: arithmetic_gmw_gate.h:68
encrypto::motion::OneGate::parent_
std::vector< WirePointer > parent_
Definition: gate.h:155
encrypto::motion::proto::arithmetic_gmw::InputGate::GetOutputAsArithmeticShare
arithmetic_gmw::SharePointer< T > GetOutputAsArithmeticShare()
Definition: arithmetic_gmw_gate.h:166
encrypto::motion::proto::arithmetic_gmw::OutputGate::output_message_futures_
std::vector< motion::ReusableFiberFuture< std::vector< std::uint8_t > > > output_message_futures_
Definition: arithmetic_gmw_gate.h:371
encrypto::motion::proto::arithmetic_gmw::InputGate::EvaluateOnline
void EvaluateOnline() final override
Definition: arithmetic_gmw_gate.h:105
encrypto::motion::GateType::kNonInteractive
@ kNonInteractive
encrypto::motion::proto::arithmetic_gmw::kAll
constexpr std::size_t kAll
Definition: arithmetic_gmw_gate.h:186
encrypto::motion::Logger::LogDebug
void LogDebug(const std::string &message)
Definition: logger.cpp:142
encrypto::motion::Gate::requires_online_interaction_
std::atomic< bool > requires_online_interaction_
Definition: gate.h:108
encrypto::motion::Gate::GetBaseProvider
BaseProvider & GetBaseProvider()
Definition: gate.cpp:102
encrypto::motion::proto::arithmetic_gmw::InputGate::~InputGate
~InputGate() final=default
encrypto::motion::Gate::GetRegister
Register & GetRegister()
Definition: gate.cpp:96
encrypto::motion::proto::arithmetic_gmw::OutputGate::OutputGate
OutputGate(const arithmetic_gmw::SharePointer< T > &parent, std::size_t output_owner)
Definition: arithmetic_gmw_gate.h:259
encrypto::motion::proto::arithmetic_gmw::OutputGate::EvaluateOnline
void EvaluateOnline() final override
Definition: arithmetic_gmw_gate.h:277
encrypto::motion::proto::arithmetic_gmw::SubtractionGate::SubtractionGate
SubtractionGate(const arithmetic_gmw::WirePointer< T > &a, const arithmetic_gmw::WirePointer< T > &b)
Definition: arithmetic_gmw_gate.h:460
encrypto::motion::TwoGate::parent_a_
std::vector< WirePointer > parent_a_
Definition: gate.h:220
encrypto::motion::Register::IncrementEvaluatedGatesOnlineCounter
void IncrementEvaluatedGatesOnlineCounter()
Definition: register.cpp:125
encrypto::motion::Backend
Definition: backend.h:88
encrypto::motion::BaseProvider::GetMyRandomnessGenerator
primitives::SharingRandomnessGenerator & GetMyRandomnessGenerator(std::size_t party_id)
Definition: motion_base_provider.h:57
encrypto::motion::Gate::RegisterWaitingFor
void RegisterWaitingFor(std::size_t wire_id)
Definition: gate.cpp:36
encrypto::motion::proto::arithmetic_gmw::MultiplicationGate::~MultiplicationGate
~MultiplicationGate() final=default
communication_layer.h
encrypto::motion::proto::arithmetic_gmw::AdditionGate::GetOutputAsArithmeticShare
arithmetic_gmw::SharePointer< T > GetOutputAsArithmeticShare()
Definition: arithmetic_gmw_gate.h:445
encrypto::motion::proto::arithmetic_gmw::InputGate::GetOutputArithmeticWire
arithmetic_gmw::WirePointer< T > GetOutputArithmeticWire()
Definition: arithmetic_gmw_gate.h:174
encrypto::motion::Gate::SetOnlineIsReady
void SetOnlineIsReady()
Definition: gate.cpp:54
encrypto::motion::communication::BuildOutputMessage
flatbuffers::FlatBufferBuilder BuildOutputMessage(std::size_t gate_id, std::vector< std::uint8_t > wire_payload)
Definition: output_message.cpp:35
encrypto::motion::proto::arithmetic_gmw::SquareGate::SquareGate
SquareGate(const arithmetic_gmw::WirePointer< T > &a)
Definition: arithmetic_gmw_gate.h:698
encrypto::motion::Gate::backend_
Backend & backend_
Definition: gate.h:101
encrypto::motion::RestrictAddVectors
std::vector< T > RestrictAddVectors(const std::vector< T > &a, const std::vector< T > &b)
Adds the vectors a and b element-wise and returns the resulting vector. It is assumed that the vectors do not overlap.
Definition: helpers.h:197
encrypto::motion::proto::arithmetic_gmw::MultiplicationGate::GetOutputAsArithmeticShare
arithmetic_gmw::SharePointer< T > GetOutputAsArithmeticShare()
Definition: arithmetic_gmw_gate.h:677
encrypto::motion::Gate::SetSetupIsReady
void SetSetupIsReady()
Definition: gate.cpp:46
encrypto::motion::proto::arithmetic_gmw::AdditionGate::EvaluateOnline
void EvaluateOnline() final override
Definition: arithmetic_gmw_gate.h:418
encrypto::motion::Register::IncrementEvaluatedGatesSetupCounter
void IncrementEvaluatedGatesSetupCounter()
Definition: register.cpp:114
encrypto::motion::proto::arithmetic_gmw::OutputGate::OutputGate
OutputGate(const arithmetic_gmw::WirePointer< T > &parent, std::size_t output_owner=kAll)
Definition: arithmetic_gmw_gate.h:202
encrypto::motion::proto::arithmetic_gmw::AdditionGate::AdditionGate
AdditionGate(const arithmetic_gmw::WirePointer< T > &a, const arithmetic_gmw::WirePointer< T > &b)
Definition: arithmetic_gmw_gate.h:379
encrypto::motion::OneGate
Definition: gate.h:148
encrypto::motion::proto::arithmetic_gmw::OutputGate::m
std::mutex m
Definition: arithmetic_gmw_gate.h:373
encrypto::motion::proto::arithmetic_gmw::SquareGate::GetOutputAsArithmeticShare
arithmetic_gmw::SharePointer< T > GetOutputAsArithmeticShare()
Definition: arithmetic_gmw_gate.h:799
encrypto::motion::proto::arithmetic_gmw::SquareGate::EvaluateSetup
void EvaluateSetup() final override
Definition: arithmetic_gmw_gate.h:734
encrypto::motion::communication::GetOutputMessage
const encrypto::motion::communication::OutputMessage * GetOutputMessage(const void *buf)
Definition: output_message_generated.h:136
encrypto::motion::proto::arithmetic_gmw::SubtractionGate::EvaluateSetup
void EvaluateSetup() final override
Definition: arithmetic_gmw_gate.h:494
register.h
mt_provider.h
encrypto::motion::proto::arithmetic_gmw::AdditionGate::AdditionGate
AdditionGate()=delete
logger.h
arithmetic_gmw_share.h
encrypto::motion::BaseProvider::GetTheirRandomnessGenerator
primitives::SharingRandomnessGenerator & GetTheirRandomnessGenerator(std::size_t party_id)
Definition: motion_base_provider.h:60
encrypto::motion::proto::arithmetic_gmw::MultiplicationGate::MultiplicationGate
MultiplicationGate(const arithmetic_gmw::WirePointer< T > &a, const arithmetic_gmw::WirePointer< T > &b)
Definition: arithmetic_gmw_gate.h:540
encrypto::motion::SharePointer
std::shared_ptr< Share > SharePointer
Definition: conversion_gate.h:49
encrypto::motion::Gate::gate_id_
std::int64_t gate_id_
Definition: gate.h:102
encrypto::motion::proto::arithmetic_gmw::OutputGate::OutputGate
OutputGate(const motion::SharePointer &parent, std::size_t output_owner)
Definition: arithmetic_gmw_gate.h:264
encrypto::motion::OutputGate
Definition: gate.h:194
encrypto::motion::ToByteVector
std::vector< std::uint8_t > ToByteVector(const std::vector< UnsignedIntegralType > &values)
Converts a vector of unsigned integral values to a vector of uint8_t.
Definition: helpers.h:57
sharing_randomness_generator.h
encrypto::motion::Gate
Definition: gate.h:67
encrypto::motion::proto::arithmetic_gmw::SharePointer
std::shared_ptr< Share< T > > SharePointer
Definition: arithmetic_gmw_share.h:156
encrypto::motion::proto::arithmetic_gmw::SquareGate
Definition: arithmetic_gmw_gate.h:696
encrypto::motion::Register::RegisterNextGate
void RegisterNextGate(GatePointer gate)
Definition: register.cpp:77
encrypto::motion::Gate::setup_is_ready_
std::atomic< bool > setup_is_ready_
Definition: gate.h:106
encrypto::motion::proto::arithmetic_gmw::SubtractionGate::SubtractionGate
SubtractionGate()=delete
encrypto::motion::kVerboseDebug
constexpr bool kVerboseDebug
Definition: constants.h:50
encrypto::motion::proto::arithmetic_gmw::SquareGate::~SquareGate
~SquareGate() final=default
encrypto::motion::proto::arithmetic_gmw::SubtractionGate
Definition: arithmetic_gmw_gate.h:458
encrypto::motion::Gate::GetMtProvider
MtProvider & GetMtProvider()
Definition: gate.cpp:104
encrypto::motion::proto::arithmetic_gmw::InputGate
Definition: arithmetic_gmw_gate.h:58
gate.h
d
static const fe d
Definition: mycurve25519_tables.h:30
encrypto::motion::SubVectors
std::vector< T > SubVectors(const std::vector< T > &a, const std::vector< T > &b)
Computes the element-wise difference of the vectors a and b and returns the resulting vector.
Definition: helpers.h:120
encrypto::motion::MpcProtocol::kArithmeticGmw
@ kArithmeticGmw
encrypto::motion::proto::arithmetic_gmw::MultiplicationGate::EvaluateSetup
void EvaluateSetup() final override
Definition: arithmetic_gmw_gate.h:589
encrypto::motion::Gate::WaitSetup
void WaitSetup() const
Definition: gate.cpp:68
encrypto::motion::proto::arithmetic_gmw::OutputGate::EvaluateSetup
void EvaluateSetup() final override
Definition: arithmetic_gmw_gate.h:272
encrypto::motion::Register::RegisterNextWire
void RegisterNextWire(WirePointer wire)
Definition: register.h:78
encrypto::motion::proto::arithmetic_gmw::InputGate::EvaluateSetup
void EvaluateSetup() final override
Definition: arithmetic_gmw_gate.h:97
encrypto::motion::BaseProvider::WaitForSetup
void WaitForSetup() const
Definition: motion_base_provider.cpp:183
encrypto::motion::proto::arithmetic_gmw::InputGate::InputGate
InputGate(std::span< const T > input, std::size_t input_owner, Backend &backend)
Definition: arithmetic_gmw_gate.h:62
encrypto::motion::GateType::kInteractive
@ kInteractive
encrypto::motion::kDebug
constexpr bool kDebug
Definition: config.h:36
encrypto::motion::Register::NextArithmeticSharingId
std::size_t NextArithmeticSharingId(std::size_t number_of_parallel_values)
Definition: register.cpp:63
geninput.default
default
Definition: geninput.py:149
encrypto::motion::proto::arithmetic_gmw
Definition: arithmetic_gmw_gate.h:45
message_generated.h
encrypto::motion::proto::arithmetic_gmw::OutputGate::is_my_output_
bool is_my_output_
Definition: arithmetic_gmw_gate.h:369
encrypto::motion::to_string
std::string to_string(std::vector< T > values)
Returns a string representation of the std::vector values.
Definition: helpers.h:455
encrypto::motion::proto::arithmetic_gmw::SquareGate::EvaluateOnline
void EvaluateOnline() final override
Definition: arithmetic_gmw_gate.h:739
encrypto::motion::proto::arithmetic_gmw::Wire
Definition: arithmetic_gmw_wire.h:33
encrypto::motion::proto::arithmetic_gmw::WirePointer
std::shared_ptr< Wire< T > > WirePointer
Definition: arithmetic_gmw_wire.h:68
encrypto::motion::Register::NextGateId
std::size_t NextGateId() noexcept
Definition: register.cpp:53
encrypto::motion::InputGate::input_owner_id_
std::int64_t input_owner_id_
Definition: gate.h:179
sp_provider.h