25 #include <type_traits>
40 template <
typename T,
typename = std::enable_if_t<std::is_
unsigned_v<T>>>
45 const auto number_of_simd{parent->GetNumberOfSimdValues()};
46 constexpr
auto bit_size =
sizeof(T) * 8;
49 assert(
parent_.size() == bit_size);
50 for ([[maybe_unused]]
const auto& wire :
parent_) {
51 assert(wire->GetBitLength() == 1);
52 assert(wire->GetNumberOfSimdValues() == number_of_simd);
64 std::vector<WirePointer> dummy_wires;
65 dummy_wires.reserve(number_of_simd);
66 for (std::size_t i = 0; i < bit_size; ++i) {
67 auto w = std::make_shared<proto::boolean_gmw::Wire>(
backend_, number_of_simd);
69 dummy_wires.emplace_back(std::move(w));
71 ts_ = std::make_shared<proto::boolean_gmw::Share>(dummy_wires);
73 ts_output_ = std::make_shared<proto::boolean_gmw::OutputGate>(ts_);
77 number_of_sbs_ = number_of_simd * bit_size;
78 sb_offset_ =
GetSbProvider().template RequestSbs<T>(number_of_sbs_);
90 auto gate_info = fmt::format(
"gate id {}, parent wires: ",
gate_id_);
91 for (
const auto& wire :
parent_) gate_info.append(fmt::format(
"{} ", wire->GetWireId()));
92 gate_info.append(fmt::format(
" output wire: {}",
output_wires_.at(0)->GetWireId()));
94 "Created a Boolean GMW to Arithmetic GMW conversion gate with following properties: {}",
111 for (
const auto& wire :
parent_) {
112 wire->GetIsReadyCondition().Wait();
117 sb_provider.WaitFinished();
119 const auto number_of_simd{
parent_.at(0)->GetNumberOfSimdValues()};
120 constexpr
auto bit_size =
sizeof(T) * 8;
124 const auto& sbs = sb_provider.template GetSbsAll<T>();
125 auto& ts_wires = ts_->GetMutableWires();
126 for (std::size_t wire_i = 0; wire_i < bit_size; ++wire_i) {
127 auto t_wire = std::dynamic_pointer_cast<proto::boolean_gmw::Wire>(ts_wires.at(wire_i));
128 auto parent_gmw_wire =
129 std::dynamic_pointer_cast<const proto::boolean_gmw::Wire>(
parent_.at(wire_i));
130 t_wire->GetMutableValues() = parent_gmw_wire->GetValues();
132 for (std::size_t j = 0; j < number_of_simd; ++j) {
133 auto b = t_wire->GetValues().Get(j);
134 bool sb = sbs.at(sb_offset_ + wire_i * number_of_simd + j) & 1;
135 t_wire->GetMutableValues().Set(b ^ sb, j);
137 t_wire->SetOnlineFinished();
141 ts_output_->WaitOnline();
142 const auto& ts_clear = ts_output_->GetOutputWires();
143 std::vector<std::shared_ptr<proto::boolean_gmw::Wire>> ts_clear_b;
144 ts_clear_b.reserve(ts_clear.size());
145 std::transform(ts_clear.cbegin(), ts_clear.cend(), std::back_inserter(ts_clear_b),
146 [](
auto& w) { return std::dynamic_pointer_cast<proto::boolean_gmw::Wire>(w); });
148 auto output = std::dynamic_pointer_cast<proto::arithmetic_gmw::Wire<T>>(
output_wires_.at(0));
149 output->GetMutableValues().resize(number_of_simd);
150 for (std::size_t j = 0; j < number_of_simd; ++j) {
152 for (std::size_t wire_i = 0; wire_i < bit_size; ++wire_i) {
154 T t(ts_clear_b.at(wire_i)->GetValues().Get(j));
155 T r(sbs.at(sb_offset_ + wire_i * number_of_simd + j));
156 output_value += T(t + r - 2 * t * r) << wire_i;
158 T t(ts_clear_b.at(wire_i)->GetValues().Get(j));
159 T r(sbs.at(sb_offset_ + wire_i * number_of_simd + j));
160 output_value += T(r - 2 * t * r) << wire_i;
163 output->GetMutableValues().at(j) = output_value;
172 auto arithmetic_wire =
173 std::dynamic_pointer_cast<proto::arithmetic_gmw::Wire<T>>(
output_wires_.at(0));
174 assert(arithmetic_wire);
175 auto result = std::make_shared<proto::arithmetic_gmw::Share<T>>(arithmetic_wire);
188 std::size_t number_of_sbs_;
189 std::size_t sb_offset_;
191 std::shared_ptr<proto::boolean_gmw::OutputGate> ts_output_;