MOTION  0.01
Framework for mixed-protocol multi-party computation
b2a_gate.h
Go to the documentation of this file.
1 // MIT License
2 //
3 // Copyright (c) 2019 Lennart Braun
4 //
5 // Permission is hereby granted, free of charge, to any person obtaining a copy
6 // of this software and associated documentation files (the "Software"), to deal
7 // in the Software without restriction, including without limitation the rights
8 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // copies of the Software, and to permit persons to whom the Software is
10 // furnished to do so, subject to the following conditions:
11 //
12 // The above copyright notice and this permission notice shall be included in all
13 // copies or substantial portions of the Software.
14 //
15 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 // SOFTWARE.
22 
23 #pragma once
24 
25 #include <type_traits>
26 #include "base/register.h"
32 #include "protocols/gate.h"
33 #include "protocols/share.h"
34 #include "utility/constants.h"
36 #include "utility/logger.h"
37 
38 namespace encrypto::motion {
39 
40 template <typename T, typename = std::enable_if_t<std::is_unsigned_v<T>>>
41 class GmwToArithmeticGate final : public OneGate {
42  public:
43  GmwToArithmeticGate(const SharePointer& parent) : OneGate(parent->GetBackend()) {
44  parent_ = parent->GetWires();
45  const auto number_of_simd{parent->GetNumberOfSimdValues()};
46  constexpr auto bit_size = sizeof(T) * 8;
47 
48  // check that we have enough input wires to represent an element of T
49  assert(parent_.size() == bit_size);
50  for ([[maybe_unused]] const auto& wire : parent_) {
51  assert(wire->GetBitLength() == 1);
52  assert(wire->GetNumberOfSimdValues() == number_of_simd);
53  assert(wire->GetProtocol() == MpcProtocol::kBooleanGmw);
54  }
55 
58 
59  // create the output wire
60  output_wires_.emplace_back(
61  std::make_shared<proto::arithmetic_gmw::Wire<T>>(backend_, number_of_simd));
63 
64  std::vector<WirePointer> dummy_wires;
65  dummy_wires.reserve(number_of_simd);
66  for (std::size_t i = 0; i < bit_size; ++i) {
67  auto w = std::make_shared<proto::boolean_gmw::Wire>(backend_, number_of_simd);
69  dummy_wires.emplace_back(std::move(w));
70  }
71  ts_ = std::make_shared<proto::boolean_gmw::Share>(dummy_wires);
72  // also create an output gate for the ts
73  ts_output_ = std::make_shared<proto::boolean_gmw::OutputGate>(ts_);
74  GetRegister().RegisterNextGate(ts_output_);
75 
76  // register the required number of shared bits
77  number_of_sbs_ = number_of_simd * bit_size;
78  sb_offset_ = GetSbProvider().template RequestSbs<T>(number_of_sbs_);
79 
80  // register this gate
82 
83  // register this gate with the parent wires
84  for (auto& wire : parent_) {
85  RegisterWaitingFor(wire->GetWireId());
86  wire->RegisterWaitingGate(gate_id_);
87  }
88 
89  if constexpr (kDebug) {
90  auto gate_info = fmt::format("gate id {}, parent wires: ", gate_id_);
91  for (const auto& wire : parent_) gate_info.append(fmt::format("{} ", wire->GetWireId()));
92  gate_info.append(fmt::format(" output wire: {}", output_wires_.at(0)->GetWireId()));
93  GetLogger().LogDebug(fmt::format(
94  "Created a Boolean GMW to Arithmetic GMW conversion gate with following properties: {}",
95  gate_info));
96  }
97  }
98 
99  ~GmwToArithmeticGate() final = default;
100 
101  void EvaluateSetup() final {
102  SetSetupIsReady();
104  }
105 
106  void EvaluateOnline() final {
107  WaitSetup();
108  assert(setup_is_ready_);
109 
110  // wait for the parent wires to obtain their values
111  for (const auto& wire : parent_) {
112  wire->GetIsReadyCondition().Wait();
113  }
114 
115  // wait for the SbProvider to finish
116  auto& sb_provider = GetSbProvider();
117  sb_provider.WaitFinished();
118 
119  const auto number_of_simd{parent_.at(0)->GetNumberOfSimdValues()};
120  constexpr auto bit_size = sizeof(T) * 8;
121 
122  // mask the input bits with the shared bits
123  // and assign the result to t
124  const auto& sbs = sb_provider.template GetSbsAll<T>();
125  auto& ts_wires = ts_->GetMutableWires();
126  for (std::size_t wire_i = 0; wire_i < bit_size; ++wire_i) {
127  auto t_wire = std::dynamic_pointer_cast<proto::boolean_gmw::Wire>(ts_wires.at(wire_i));
128  auto parent_gmw_wire =
129  std::dynamic_pointer_cast<const proto::boolean_gmw::Wire>(parent_.at(wire_i));
130  t_wire->GetMutableValues() = parent_gmw_wire->GetValues();
131  // xor them with the shared bits
132  for (std::size_t j = 0; j < number_of_simd; ++j) {
133  auto b = t_wire->GetValues().Get(j);
134  bool sb = sbs.at(sb_offset_ + wire_i * number_of_simd + j) & 1;
135  t_wire->GetMutableValues().Set(b ^ sb, j);
136  }
137  t_wire->SetOnlineFinished();
138  }
139 
140  // reconstruct t
141  ts_output_->WaitOnline();
142  const auto& ts_clear = ts_output_->GetOutputWires();
143  std::vector<std::shared_ptr<proto::boolean_gmw::Wire>> ts_clear_b;
144  ts_clear_b.reserve(ts_clear.size());
145  std::transform(ts_clear.cbegin(), ts_clear.cend(), std::back_inserter(ts_clear_b),
146  [](auto& w) { return std::dynamic_pointer_cast<proto::boolean_gmw::Wire>(w); });
147 
148  auto output = std::dynamic_pointer_cast<proto::arithmetic_gmw::Wire<T>>(output_wires_.at(0));
149  output->GetMutableValues().resize(number_of_simd);
150  for (std::size_t j = 0; j < number_of_simd; ++j) {
151  T output_value = 0;
152  for (std::size_t wire_i = 0; wire_i < bit_size; ++wire_i) {
153  if (GetCommunicationLayer().GetMyId() == 0) {
154  T t(ts_clear_b.at(wire_i)->GetValues().Get(j)); // the masked bit
155  T r(sbs.at(sb_offset_ + wire_i * number_of_simd + j)); // the arithmetically shared bit
156  output_value += T(t + r - 2 * t * r) << wire_i;
157  } else {
158  T t(ts_clear_b.at(wire_i)->GetValues().Get(j)); // the masked bit
159  T r(sbs.at(sb_offset_ + wire_i * number_of_simd + j)); // the arithmetically shared bit
160  output_value += T(r - 2 * t * r) << wire_i;
161  }
162  }
163  output->GetMutableValues().at(j) = output_value;
164  }
165 
166  GetLogger().LogDebug(fmt::format("Evaluated B2AGate with id#{}", gate_id_));
169  }
170 
172  auto arithmetic_wire =
173  std::dynamic_pointer_cast<proto::arithmetic_gmw::Wire<T>>(output_wires_.at(0));
174  assert(arithmetic_wire);
175  auto result = std::make_shared<proto::arithmetic_gmw::Share<T>>(arithmetic_wire);
176  return result;
177  }
178 
180  return std::dynamic_pointer_cast<Share>(GetOutputAsArithmeticShare());
181  }
182 
183  GmwToArithmeticGate() = delete;
184 
185  GmwToArithmeticGate(const Gate&) = delete;
186 
187  private:
188  std::size_t number_of_sbs_;
189  std::size_t sb_offset_;
191  std::shared_ptr<proto::boolean_gmw::OutputGate> ts_output_;
192 };
193 
194 } // namespace encrypto::motion
encrypto::motion::GmwToArithmeticGate::~GmwToArithmeticGate
~GmwToArithmeticGate() final=default
encrypto::motion::Gate::GetSbProvider
SbProvider & GetSbProvider()
Definition: gate.cpp:108
encrypto::motion::Gate::gate_type_
GateType gate_type_
Definition: gate.h:105
encrypto::motion::GmwToArithmeticGate
Definition: b2a_gate.h:41
boolean_gmw_gate.h
encrypto::motion::Gate::GetCommunicationLayer
communication::CommunicationLayer & GetCommunicationLayer()
Definition: gate.cpp:92
encrypto::motion::Gate::output_wires_
std::vector< WirePointer > output_wires_
Definition: gate.h:100
fiber_condition.h
encrypto::motion::Gate::GetLogger
Logger & GetLogger()
Definition: gate.cpp:100
encrypto::motion::GmwToArithmeticGate::GmwToArithmeticGate
GmwToArithmeticGate()=delete
encrypto::motion::GmwToArithmeticGate::EvaluateOnline
void EvaluateOnline() final
Definition: b2a_gate.h:106
encrypto::motion::OneGate::parent_
std::vector< WirePointer > parent_
Definition: gate.h:155
encrypto::motion::GmwToArithmeticGate::EvaluateSetup
void EvaluateSetup() final
Definition: b2a_gate.h:101
encrypto::motion::Logger::LogDebug
void LogDebug(const std::string &message)
Definition: logger.cpp:142
encrypto::motion::Gate::requires_online_interaction_
std::atomic< bool > requires_online_interaction_
Definition: gate.h:108
encrypto::motion::Gate::GetRegister
Register & GetRegister()
Definition: gate.cpp:96
encrypto::motion::proto::boolean_gmw::SharePointer
std::shared_ptr< Share > SharePointer
Definition: backend.h:46
encrypto::motion::MpcProtocol::kBooleanGmw
@ kBooleanGmw
encrypto::motion::Register::IncrementEvaluatedGatesOnlineCounter
void IncrementEvaluatedGatesOnlineCounter()
Definition: register.cpp:125
encrypto::motion::GmwToArithmeticGate::GetOutputAsShare
const SharePointer GetOutputAsShare() const
Definition: b2a_gate.h:179
encrypto::motion::Gate::RegisterWaitingFor
void RegisterWaitingFor(std::size_t wire_id)
Definition: gate.cpp:36
boolean_gmw_wire.h
encrypto::motion::Gate::SetOnlineIsReady
void SetOnlineIsReady()
Definition: gate.cpp:54
encrypto::motion::Gate::backend_
Backend & backend_
Definition: gate.h:101
encrypto::motion::Gate::SetSetupIsReady
void SetSetupIsReady()
Definition: gate.cpp:46
encrypto::motion::Register::IncrementEvaluatedGatesSetupCounter
void IncrementEvaluatedGatesSetupCounter()
Definition: register.cpp:114
encrypto::motion::OneGate
Definition: gate.h:148
register.h
logger.h
encrypto::motion
Definition: algorithm_description.cpp:35
arithmetic_gmw_share.h
encrypto::motion::SharePointer
std::shared_ptr< Share > SharePointer
Definition: conversion_gate.h:49
encrypto::motion::Gate::gate_id_
std::int64_t gate_id_
Definition: gate.h:102
encrypto::motion::communication::CommunicationLayer::GetMyId
std::size_t GetMyId() const
Definition: communication_layer.h:66
encrypto::motion::Gate
Definition: gate.h:67
encrypto::motion::GmwToArithmeticGate::GmwToArithmeticGate
GmwToArithmeticGate(const SharePointer &parent)
Definition: b2a_gate.h:43
encrypto::motion::proto::arithmetic_gmw::SharePointer
std::shared_ptr< Share< T > > SharePointer
Definition: arithmetic_gmw_share.h:156
encrypto::motion::Register::RegisterNextGate
void RegisterNextGate(GatePointer gate)
Definition: register.cpp:77
sb_provider.h
encrypto::motion::GmwToArithmeticGate::GetOutputAsArithmeticShare
const proto::arithmetic_gmw::SharePointer< T > GetOutputAsArithmeticShare() const
Definition: b2a_gate.h:171
encrypto::motion::Gate::setup_is_ready_
std::atomic< bool > setup_is_ready_
Definition: gate.h:106
share.h
gate.h
encrypto::motion::Gate::WaitSetup
void WaitSetup() const
Definition: gate.cpp:68
encrypto::motion::Register::RegisterNextWire
void RegisterNextWire(WirePointer wire)
Definition: register.h:78
constants.h
encrypto::motion::GateType::kInteractive
@ kInteractive
encrypto::motion::kDebug
constexpr bool kDebug
Definition: config.h:36
geninput.default
default
Definition: geninput.py:149
encrypto::motion::proto::arithmetic_gmw::Wire
Definition: arithmetic_gmw_wire.h:33
boolean_gmw_share.h
encrypto::motion::Register::NextGateId
std::size_t NextGateId() noexcept
Definition: register.cpp:53