27 #include <fmt/format.h>
57 template <
typename T,
typename = std::enable_if_t<std::is_
unsigned_v<T>>>
63 :
Base(backend), input_(std::vector(input.begin(), input.end())) {
69 :
Base(backend), input_(std::move(input)) {
75 static_assert(!std::is_same_v<T, bool>);
81 fmt::format(
"Created an arithmetic_gmw::InputGate with global id {}",
gate_id_));
89 auto gate_info = fmt::format(
"uint{}_t type, gate id {}, owner {}",
sizeof(T) * 8,
gate_id_,
92 "Allocate an arithmetic_gmw::InputGate with following properties: {}", gate_info));
110 auto my_id = communication_layer.GetMyId();
111 auto number_of_parties = communication_layer.GetNumberOfParties();
113 std::vector<T> result;
116 result.resize(input_.size());
117 auto log_string = std::string(
"");
118 for (
auto party_id = 0u; party_id < number_of_parties; ++party_id) {
119 if (party_id == my_id) {
124 randomness_generator.template GetUnsigned<T>(arithmetic_sharing_id_, input_.size());
126 log_string.append(fmt::format(
"id#{}:{} ", party_id, randomness.at(0)));
128 for (
auto j = 0u; j < result.size(); ++j) {
129 result.at(j) += randomness.at(j);
132 for (
auto j = 0u; j < result.size(); ++j) {
133 result.at(j) = input_.at(j) - result.at(j);
137 auto s = fmt::format(
138 "My (id#{}) arithmetic input sharing for gate#{}, my input: {}, my "
139 "share: {}, expected shares of other parties: {}",
145 result = randomness_generator.template GetUnsigned<T>(arithmetic_sharing_id_, input_.size());
148 auto s = fmt::format(
149 "Arithmetic input sharing (gate#{}) of Party's#{} input, got a share "
155 auto my_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(
output_wires_.at(0));
157 my_wire->GetMutableValues() = std::move(result);
168 auto result = std::make_shared<arithmetic_gmw::Share<T>>(arithmetic_wire);
175 auto result = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(
output_wires_.at(0));
181 std::size_t arithmetic_sharing_id_;
183 std::vector<T> input_;
// Sentinel output-owner id. The OutputGate owner validation in this file
// rejects any owner that is >= the number of parties unless it equals kAll,
// so kAll is the one out-of-range owner value that is accepted.
// NOTE(review): presumably this selects "every party" as the output owner -
// confirm against the OutputGate constructor.
// NOTE(review): the value is the int64_t maximum rather than the size_t
// maximum, presumably so that it is also representable in a signed 64-bit
// owner id - confirm against the type of output_owner.
//
// Declared `inline` so that all translation units including this file share
// a single definition instead of each getting an internal-linkage copy.
inline constexpr std::size_t kAll = std::numeric_limits<std::int64_t>::max();
188 template <
typename T,
typename = std::enable_if_t<std::is_
unsigned_v<T>>>
196 auto arithmetic_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(
output_wires_.at(0));
197 assert(arithmetic_wire);
198 auto result = std::make_shared<arithmetic_gmw::Share<T>>(arithmetic_wire);
203 :
Base(parent->GetBackend()) {
207 auto sharing_type =
to_string(parent->GetProtocol());
209 std::runtime_error((fmt::format(
"Arithmetic output gate expects an arithmetic share, "
210 "got a share of type {}",
218 auto my_id = communication_layer.GetMyId();
219 auto number_of_parties = communication_layer.GetNumberOfParties();
221 if (
static_cast<std::size_t
>(output_owner) >= number_of_parties &&
222 static_cast<std::size_t
>(output_owner) !=
kAll) {
223 throw std::runtime_error(
224 fmt::format(
"Invalid output owner: {} of {}", output_owner, number_of_parties));
238 auto w = std::static_pointer_cast<motion::Wire>(
252 auto gate_info = fmt::format(
"uint{}_t type, gate id {}, owner {}",
sizeof(T) * 8,
gate_id_,
255 "Allocate an arithmetic_gmw::OutputGate with following properties: {}", gate_info));
260 :
OutputGate(parent->GetArithmeticWire(), output_owner) {
265 :
OutputGate(std::dynamic_pointer_cast<const arithmetic_gmw::
Share<T>>(parent),
284 auto my_id = communication_layer.GetMyId();
285 auto number_of_parties = communication_layer.GetNumberOfParties();
288 auto arithmetic_wire = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(
parent_.at(0));
289 assert(arithmetic_wire);
291 arithmetic_wire->GetIsReadyCondition().Wait();
293 auto output = arithmetic_wire->GetValues();
299 communication_layer.SendMessage(
output_owner_, std::move(output_message));
305 communication_layer.BroadcastMessage(std::move(output_message));
311 std::vector<std::vector<T>> shared_outputs;
312 shared_outputs.reserve(number_of_parties);
314 for (std::size_t i = 0; i < number_of_parties; ++i) {
316 shared_outputs.push_back(output);
322 assert(output_message_pointer);
323 assert(output_message_pointer->wires()->size() == 1);
325 shared_outputs.push_back(
326 FromByteVector<T>(*output_message_pointer->wires()->Get(0)->payload()));
327 assert(shared_outputs[i].size() ==
parent_[0]->GetNumberOfSimdValues());
336 output =
AddVectors(std::move(shared_outputs));
340 auto arithmetic_output_wire =
341 std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(
output_wires_.at(0));
342 assert(arithmetic_output_wire);
343 arithmetic_output_wire->GetMutableValues() = output;
346 std::string shares{
""};
347 for (
auto i = 0u; i < number_of_parties; ++i) {
348 shares.append(fmt::format(
"id#{}:{} ", i,
to_string(shared_outputs.at(i))));
352 fmt::format(
"Received output shares: {} from other parties, "
353 "reconstructed result is {}",
361 fmt::format(
"Evaluated arithmetic_gmw::OutputGate with id#{}",
gate_id_));
376 template <
typename T,
typename = std::enable_if_t<std::is_
unsigned_v<T>>>
381 parent_a_ = {std::static_pointer_cast<motion::Wire>(a)};
382 parent_b_ = {std::static_pointer_cast<motion::Wire>(b)};
384 assert(
parent_a_.at(0)->GetNumberOfSimdValues() ==
parent_b_.at(0)->GetNumberOfSimdValues());
398 auto w = std::static_pointer_cast<motion::Wire>(
405 fmt::format(
"uint{}_t type, gate id {}, parents: {}, {}",
sizeof(T) * 8,
gate_id_,
408 "Created an arithmetic_gmw::AdditionGate with following properties: {}", gate_info));
422 parent_a_.at(0)->GetIsReadyCondition().Wait();
423 parent_b_.at(0)->GetIsReadyCondition().Wait();
425 auto wire_a = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(
parent_a_.at(0));
426 auto wire_b = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(
parent_b_.at(0));
431 std::vector<T> output;
434 auto arithmetic_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(
output_wires_.at(0));
435 arithmetic_wire->GetMutableValues() = std::move(output);
438 fmt::format(
"Evaluated arithmetic_gmw::AdditionGate with id#{}",
gate_id_));
446 auto arithmetic_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(
output_wires_.at(0));
447 assert(arithmetic_wire);
448 auto result = std::make_shared<arithmetic_gmw::Share<T>>(arithmetic_wire);
457 template <
typename T,
typename = std::enable_if_t<std::is_
unsigned_v<T>>>
462 parent_a_ = {std::static_pointer_cast<motion::Wire>(a)};
463 parent_b_ = {std::static_pointer_cast<motion::Wire>(b)};
465 assert(
parent_a_.at(0)->GetNumberOfSimdValues() ==
parent_b_.at(0)->GetNumberOfSimdValues());
479 auto w = std::static_pointer_cast<motion::Wire>(
486 fmt::format(
"uint{}_t type, gate id {}, parents: {}, {}",
sizeof(T) * 8,
gate_id_,
489 "Created an arithmetic_gmw::SubtractionGate with following properties: {}", gate_info));
503 parent_a_.at(0)->GetIsReadyCondition().Wait();
504 parent_b_.at(0)->GetIsReadyCondition().Wait();
506 auto wire_a = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(
parent_a_.at(0));
507 auto wire_b = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(
parent_b_.at(0));
512 std::vector<T> output =
SubVectors(wire_a->GetValues(), wire_b->GetValues());
514 auto arithmetic_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(
output_wires_.at(0));
515 arithmetic_wire->GetMutableValues() = std::move(output);
518 fmt::format(
"Evaluated arithmetic_gmw::SubtractionGate with id#{}",
gate_id_));
526 auto arithmetic_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(
output_wires_.at(0));
527 assert(arithmetic_wire);
528 auto result = std::make_shared<arithmetic_gmw::Share<T>>(arithmetic_wire);
537 template <
typename T,
typename = std::enable_if_t<std::is_
unsigned_v<T>>>
543 parent_a_ = {std::static_pointer_cast<motion::Wire>(a)};
544 parent_b_ = {std::static_pointer_cast<motion::Wire>(b)};
546 assert(
parent_a_.at(0)->GetNumberOfSimdValues() ==
parent_b_.at(0)->GetNumberOfSimdValues());
551 d_ = std::make_shared<arithmetic_gmw::Wire<T>>(
backend_, a->GetNumberOfSimdValues());
553 e_ = std::make_shared<arithmetic_gmw::Wire<T>>(
backend_, a->GetNumberOfSimdValues());
556 d_output_ = std::make_shared<OutputGate<T>>(d_);
557 e_output_ = std::make_shared<OutputGate<T>>(e_);
571 auto w = std::static_pointer_cast<motion::Wire>(
577 number_of_mts_ =
parent_a_.at(0)->GetNumberOfSimdValues();
578 mt_offset_ =
GetMtProvider().template RequestArithmeticMts<T>(number_of_mts_);
581 fmt::format(
"uint{}_t type, gate id {}, parents: {}, {}",
sizeof(T) * 8,
gate_id_,
584 "Created an arithmetic_gmw::MultiplicationGate with following properties: {}", gate_info));
597 parent_a_.at(0)->GetIsReadyCondition().Wait();
598 parent_b_.at(0)->GetIsReadyCondition().Wait();
601 mt_provider.WaitFinished();
602 const auto& mts = mt_provider.template GetIntegerAll<T>();
604 const auto x = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(
parent_a_.at(0));
606 d_->GetMutableValues() = std::vector<T>(
607 mts.a.begin() + mt_offset_, mts.a.begin() + mt_offset_ + x->GetNumberOfSimdValues());
608 T* __restrict__ d_v = d_->GetMutableValues().data();
609 const T* __restrict__ x_v = x->GetValues().data();
610 const auto number_of_simd_values{x->GetNumberOfSimdValues()};
612 std::transform(x_v, x_v + number_of_simd_values, d_v, d_v,
613 [](
const T& a,
const T& b) {
return a + b; });
614 d_->SetOnlineFinished();
616 const auto y = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(
parent_b_.at(0));
618 e_->GetMutableValues() = std::vector<T>(
619 mts.b.begin() + mt_offset_, mts.b.begin() + mt_offset_ + x->GetNumberOfSimdValues());
620 T* __restrict__ e_v = e_->GetMutableValues().data();
621 const T* __restrict__ y_v = y->GetValues().data();
622 std::transform(y_v, y_v + number_of_simd_values, e_v, e_v,
623 [](
const T& a,
const T& b) {
return a + b; });
624 e_->SetOnlineFinished();
627 d_output_->WaitOnline();
628 e_output_->WaitOnline();
630 const auto& d_clear = d_output_->GetOutputWires().at(0);
631 const auto& e_clear = e_output_->GetOutputWires().at(0);
633 d_clear->GetIsReadyCondition().Wait();
634 e_clear->GetIsReadyCondition().Wait();
636 const auto d_w = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(d_clear);
637 const auto x_i_w = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(
parent_a_.at(0));
638 const auto e_w = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(e_clear);
639 const auto y_i_w = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(
parent_b_.at(0));
646 auto output = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(
output_wires_.at(0));
648 output->GetMutableValues() =
649 std::vector<T>(mts.c.begin() + mt_offset_,
650 mts.c.begin() + mt_offset_ +
parent_a_.at(0)->GetNumberOfSimdValues());
652 const T* __restrict__
d{d_w->GetValues().data()};
653 const T* __restrict__ s_x{x_i_w->GetValues().data()};
654 const T* __restrict__ e{e_w->GetValues().data()};
655 const T* __restrict__ s_y{y_i_w->GetValues().data()};
656 T* __restrict__ output_pointer{output->GetMutableValues().data()};
660 for (
auto i = 0ull; i < output->GetNumberOfSimdValues(); ++i) {
661 output_pointer[i] += (
d[i] * s_y[i]) + (e[i] * s_x[i]) - (e[i] *
d[i]);
664 for (
auto i = 0ull; i < output->GetNumberOfSimdValues(); ++i) {
665 output_pointer[i] += (
d[i] * s_y[i]) + (e[i] * s_x[i]);
670 fmt::format(
"Evaluated arithmetic_gmw::MultiplicationGate with id#{}",
gate_id_));
678 auto arithmetic_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(
output_wires_.at(0));
679 assert(arithmetic_wire);
680 auto result = std::make_shared<arithmetic_gmw::Share<T>>(arithmetic_wire);
690 std::shared_ptr<OutputGate<T>> d_output_, e_output_;
692 std::size_t number_of_mts_, mt_offset_;
695 template <
typename T,
typename = std::enable_if_t<std::is_
unsigned_v<T>>>
699 parent_ = {std::static_pointer_cast<motion::Wire>(a)};
704 d_ = std::make_shared<arithmetic_gmw::Wire<T>>(
backend_, a->GetNumberOfSimdValues());
707 d_output_ = std::make_shared<OutputGate<T>>(d_);
717 auto w = std::static_pointer_cast<motion::Wire>(
723 number_of_sps_ =
parent_.at(0)->GetNumberOfSimdValues();
724 sp_offset_ =
GetSpProvider().template RequestSps<T>(number_of_sps_);
726 auto gate_info = fmt::format(
"uint{}_t type, gate id {}, parent: {}",
sizeof(T) * 8,
gate_id_,
729 "Created an arithmetic_gmw::SquareGate with following properties: {}", gate_info));
742 parent_.at(0)->GetIsReadyCondition().Wait();
745 sp_provider.WaitFinished();
746 const auto& sps = sp_provider.template GetSpsAll<T>();
748 const auto x = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(
parent_.at(0));
750 d_->GetMutableValues() = std::vector<T>(
751 sps.a.begin() + sp_offset_, sps.a.begin() + sp_offset_ + x->GetNumberOfSimdValues());
752 const auto number_of_simd_values{d_->GetNumberOfSimdValues()};
753 T* __restrict__ d_v{d_->GetMutableValues().data()};
754 const T* __restrict__ x_v{x->GetValues().data()};
755 std::transform(x_v, x_v + number_of_simd_values, d_v, d_v,
756 [](
const T& a,
const T& b) {
return a + b; });
757 d_->SetOnlineFinished();
760 d_output_->WaitOnline();
762 const auto& d_clear = d_output_->GetOutputWires().at(0);
764 d_clear->GetIsReadyCondition().Wait();
766 const auto d_w = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(d_clear);
767 const auto x_i_w = std::dynamic_pointer_cast<const arithmetic_gmw::Wire<T>>(
parent_.at(0));
772 auto output = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(
output_wires_.at(0));
774 output->GetMutableValues() =
775 std::vector<T>(sps.c.begin() + sp_offset_,
776 sps.c.begin() + sp_offset_ +
parent_.at(0)->GetNumberOfSimdValues());
778 const T* __restrict__
d{d_w->GetValues().data()};
779 const T* __restrict__ s_x{x_i_w->GetValues().data()};
780 T* __restrict__ output_pointer{output->GetMutableValues().data()};
783 for (
auto i = 0ull; i < output->GetNumberOfSimdValues(); ++i) {
784 output_pointer[i] += 2 * (
d[i] * s_x[i]) - (
d[i] *
d[i]);
787 for (
auto i = 0ull; i < output->GetNumberOfSimdValues(); ++i) {
788 output_pointer[i] += 2 * (
d[i] * s_x[i]);
800 auto arithmetic_wire = std::dynamic_pointer_cast<arithmetic_gmw::Wire<T>>(
output_wires_.at(0));
801 assert(arithmetic_wire);
802 auto result = std::make_shared<arithmetic_gmw::Share<T>>(arithmetic_wire);
812 std::shared_ptr<OutputGate<T>> d_output_;
814 std::size_t number_of_sps_, sp_offset_;