emp-toolkit
mextension_alsz.h
Go to the documentation of this file.
1 #ifndef OT_M_EXTENSION_ALSZ_H__
2 #define OT_M_EXTENSION_ALSZ_H__
3 #include "ot.h"
4 #include "co.h"
5 
9 class MOTExtension_ALSZ: public OT<MOTExtension_ALSZ> { public:
13  int l, ssp;
14 
15  block *k0, *k1, * data_open = nullptr;
16  bool *s;
17 
18  uint8_t * qT, *tT, *q = nullptr, **t, *block_s;
19  int u = 0;
20  bool setup = false;
21  bool committing = false;
23  MOTExtension_ALSZ(NetIO * io, bool committing = false, int ssp = 40): OT(io) , ssp(ssp){
24  this->l = 192;
25  u = 2;
26  this->base_ot = new OTCO(io);
27  this->s = new bool[l];
28  this->k0 = new block[l];
29  this->k1 = new block[l];
30  block_s = new uint8_t[l/8];
31  this->committing = committing;
32  }
33 
35  delete base_ot;
36  delete[] s;
37  delete[] k0;
38  delete[] k1;
39  delete[] block_s;
40  if(data_open != nullptr) {
41  delete[] data_open;
42  }
43  }
44 
45  void xor_arr (uint8_t * a, uint8_t * b, uint8_t * c, int n) {
46  if(n%16 == 0)
47  xorBlocks_arr((block*)a, (block *)b, (block*)c, n/16);
48  else {
49  uint8_t* end_a = a + n;
50  for(;a!= end_a;)
51  *(a++) = *(b++) ^ *(c++);
52  }
53  }
54 
55  block H(uint8_t* in, long id, int len) {
56  block res = zero_block();
57  for(int i = 0; i < len/16; ++i) {
58  res = xorBlocks(res, pi.H(_mm_loadl_epi64((block *)(in)), id));
59  in+=16;
60  }
61  return res;
62  }
63 
64  void bool_to_uint8(uint8_t * out, const bool*in, int len) {
65  for(int i = 0; i < len/8; ++i)
66  out[i] = 0;
67  for(int i = 0; i < len; ++i)
68  if(in[i])
69  out[i/8]|=(1<<(i%8));
70  }
71  void setup_send(block * in_k0 = nullptr, bool * in_s = nullptr){
72  setup = true;
73  if(in_s != nullptr) {
74  memcpy(k0, in_k0, l*sizeof(block));
75  memcpy(s, in_s, l);
76  bool_to_uint8(block_s, s, l);
77  return;
78  }
79  prg.random_bool(s, l);
80  base_ot->recv(k0, s, l);
81  bool_to_uint8(block_s, s, l);
82  }
83  void setup_recv(block * in_k0 = nullptr, block * in_k1 =nullptr) {
84  setup = true;
85  if(in_k0 !=nullptr) {
86  memcpy(k0, in_k0, l*sizeof(block));
87  memcpy(k1, in_k1, l*sizeof(block));
88  return;
89  }
90  prg.random_block(k0, l);
91  prg.random_block(k1, l);
92  base_ot->send(k0, k1, l);
93  setup = true;
94  }
95 
96  void ot_extension_send_pre(int length) {
97  assert(length%8==0);
98  if (length%128 !=0) length = (length/128 + 1)*128;
99 
100  q = new uint8_t[length/8*l];
101  if(!setup)setup_send();
102  setup = false;
103  if(committing) {
104  Hash::hash_once(com, s, l);
106  }
107  //get u, compute q
108  qT = new uint8_t[length/8*l];
109  uint8_t * q2 = new uint8_t[length/8*l];
110  uint8_t*tmp = new uint8_t[length/8];
111  PRG G;
112  for(int i = 0; i < l; ++i) {
113  io->recv_data(tmp, length/8);
114  G.reseed(&k0[i]);
115  G.random_data(q+(i*length/8), length/8);
116  if (s[i])
117  xor_arr(q2+(i*length/8), q+(i*length/8), tmp, length/8);
118  else
119  memcpy(q2+(i*length/8), q+(i*length/8), length/8);
120  }
121  sse_trans(qT, q2, l, length);
122  delete[] tmp;
123  delete[] q2;
124  }
125 
126  void ot_extension_recv_pre(block * data, const bool* r, int length) {
127  int old_length = length;
128  if (length%128 !=0) length = (length/128 + 1)*128;
129  if(!setup)setup_recv();
130  setup = false;
131  if(committing) {
133  }
134  uint8_t *block_r = new uint8_t[length/8];
135  bool_to_uint8(block_r, r, old_length);
136  // send u
137  t = new uint8_t*[2];
138  t[0] = new uint8_t[length/8*l];
139  t[1] = new uint8_t[length/8*l];
140  tT = new uint8_t[length/8*l];
141  uint8_t* tmp = new uint8_t[length/8];
142  PRG G;
143  for(int i = 0; i < l; ++i) {
144  G.reseed(&k0[i]);
145  G.random_data(&(t[0][i*length/8]), length/8);
146  G.reseed(&k1[i]);
147  G.random_data(t[1]+(i*length/8), length/8);
148  xor_arr(tmp, t[0]+(i*length/8), t[1]+(i*length/8), length/8);
149  xor_arr(tmp, block_r, tmp, length/8);
150  io->send_data(tmp, length/8);
151  }
152 
153  sse_trans(tT, t[0], l, length);
154 
155  delete[] tmp;
156  delete[] block_r;
157  }
158 
159  void ot_extension_send_post(const block* data0, const block* data1, int length) {
160  int old_length = length;
161  if (length%128 !=0) length = (length/128 + 1)*128;
162  // uint8_t *pad0 = new uint8_t[l/8];
163  uint8_t *pad1 = new uint8_t[l/8];
164  block pad[2];
165  for(int i = 0; i < old_length; ++i) {
166  xor_arr(pad1, qT+i*l/8, block_s, l/8);
167  pad[0] = xorBlocks( H(qT+i*l/8, i, l/8), data0[i]);
168  pad[1] = xorBlocks( H(pad1, i, l/8), data1[i]);
169  io->send_data(pad, 2*sizeof(block));
170  }
171  delete[] pad1;
172  delete[] qT;
173  }
174 
175  void ot_extension_recv_check(int length) {
176  if (length%128 !=0) length = (length/128 + 1)*128;
177  block seed; PRG prg;int beta;
178  uint8_t * tmp = new uint8_t[length/8];
179  char dgst[20];
180  for(int i = 0; i < u; ++i) {
181  io->recv_block(&seed, 1);
182  prg.reseed(&seed);
183  for(int j = 0; j < l; ++j) {
184  prg.random_data(&beta, 4);
185  beta = beta>0?beta:-1*beta;
186  beta %= l;
187  for(int k = 0; k < 2; ++k)
188  for(int l = 0; l < 2; ++l) {
189  xor_arr(tmp, t[k]+(j*length/8), t[l]+(beta*length/8), length/8);
190  Hash::hash_once(dgst, tmp, length/8);
191  io->send_data(dgst, 20);
192  }
193  }
194  }
195  delete []tmp;
196  }
197 
198  void ot_extension_recv_post(block* data, const bool* r, int length) {
199  int old_length = length;
200  data_open = new block[length];
201  if (length%128 !=0) length = (length/128 + 1)*128;
202  block res[2];
203  for(int i = 0; i < old_length; ++i) {
204  io->recv_data(res, 2*sizeof(block));
205  block tmp = H(tT+i*l/8, i, l/8);
206  if(r[i]) {
207  data[i] = xorBlocks(res[1], tmp);
208  data_open[i] = res[0];
209  } else {
210  data[i] = xorBlocks(res[0], tmp);
211  data_open[i] = res[1];
212  }
213  }
214  if(!committing) {
215  delete[] tT;
216  tT=nullptr;
217  }
218  }
219  bool ot_extension_send_check(int length) {
220  if (length%128 !=0) length = (length/128 + 1)*128;
221  bool cheat = false;
222  PRG prg, sprg; block seed;int beta;
223  char dgst[2][2][20]; char dgstchk[20];
224  uint8_t * tmp = new uint8_t[length/8];
225  for(int i = 0; i < u; ++i) {
226  prg.random_block(&seed, 1);
227  io->send_block(&seed, 1);
228  sprg.reseed(&seed);
229  for(int j = 0; j < l; ++j) {
230  sprg.random_data(&beta, 4);
231  beta = beta>0?beta:-1*beta;
232  beta %= l;
233  io->recv_data(dgst[0][0], 20);
234  io->recv_data(dgst[0][1], 20);
235  io->recv_data(dgst[1][0], 20);
236  io->recv_data(dgst[1][1], 20);
237 
238  int ind1 = s[j]? 1:0;
239  int ind2 = s[beta]? 1:0;
240  xor_arr(tmp, q+(j*length/8), q+(beta*length/8), length/8);
241  Hash::hash_once(dgstchk, tmp, length/8);
242  if (strncmp(dgstchk, dgst[ind1][ind2], 20)!=0)
243  cheat = true;
244  }
245  }
246  delete[] tmp;
247  return cheat;
248  }
249 
250  void send_impl(const block* data0, const block* data1, int length) {
251  ot_extension_send_pre(length);
252  assert(!ot_extension_send_check(length)?"T":"F");
253  delete[] q; q = nullptr;
254  ot_extension_send_post(data0, data1, length);
255  }
256 
257  void recv_impl(block* data, const bool* b, int length) {
258  ot_extension_recv_pre(data, b, length);
259  ot_extension_recv_check(length);
260  delete[] t[0];
261  delete[] t[1];
262  delete[] t;
263  ot_extension_recv_post(data, b, length);
264  }
265 
266  void open() {
267  io->send_data(s, l);
268  }
269  //return data[1-b]
270  void open(block * data, const bool * r, int length) {
271  io->recv_data(s, l);
272  char com_recv[10];
273  Hash::hash_once(com_recv, s, l);
274  if (strncmp(com_recv, com, 10)!= 0)
275  assert(false);
276  bool_to_uint8(block_s, s, l);
277  for(int i = 0; i < length; ++i) {
278  xor_arr(tT+i*l/8, tT+i*l/8, block_s, l/8);
279  block tmp = H(tT+i*l/8, i, l/8);
280  data[i] = xorBlocks(data_open[i], tmp);
281  }
282  delete[] tT;
283  delete[] data_open;
284  tT=nullptr;
285  data_open = nullptr;
286  }
287 };
289 #endif// OT_M_EXTENSION_ALSZ_H__
void bool_to_uint8(uint8_t *out, const bool *in, int len)
Definition: mextension_alsz.h:64
void send_data(const void *data, int nbyte)
Definition: io_channel.h:14
void recv_data(void *data, int nbyte)
Definition: io_channel.h:17
void open(block *data, const bool *r, int length)
Definition: mextension_alsz.h:270
void random_bool(bool *data, int length)
Definition: prg.h:57
void setup_recv(block *in_k0=nullptr, block *in_k1=nullptr)
Definition: mextension_alsz.h:83
__m128i block
Definition: block.h:8
void ot_extension_send_post(const block *data0, const block *data1, int length)
Definition: mextension_alsz.h:159
uint8_t * tT
Definition: mextension_alsz.h:18
void send(const block *data0, const block *data1, int length)
Definition: ot.h:10
void ot_extension_recv_post(block *data, const bool *r, int length)
Definition: mextension_alsz.h:198
void open()
Definition: mextension_alsz.h:266
OTCO * base_ot
Definition: mextension_alsz.h:10
bool committing
Definition: mextension_alsz.h:21
bool * s
Definition: mextension_alsz.h:16
block xorBlocks(block x, block y)
Definition: block.h:35
Definition: net_io_channel.h:22
void sse_trans(uint8_t *out, uint8_t const *inp, int nrows, int ncols)
Definition: block.h:85
PRP pi
Definition: mextension_alsz.h:12
bool setup
Definition: mextension_alsz.h:20
#define zero_block()
Definition: block.h:66
int ssp
Definition: mextension_alsz.h:13
void xor_arr(uint8_t *a, uint8_t *b, uint8_t *c, int n)
Definition: mextension_alsz.h:45
MOTExtension_ALSZ(NetIO *io, bool committing=false, int ssp=40)
Definition: mextension_alsz.h:23
block * data_open
Definition: mextension_alsz.h:15
char com[Hash::DIGEST_SIZE]
Definition: mextension_alsz.h:22
void random_block(block *data, int nblocks=1)
Definition: prg.h:75
block H(block in, uint64_t id)
Definition: prp.h:48
void reseed(const void *key, uint64_t id=0)
Definition: prg.h:41
Definition: prp.h:11
Definition: co.h:8
void ot_extension_recv_pre(block *data, const bool *r, int length)
Definition: mextension_alsz.h:126
NetIO * io
Definition: ot.h:9
int l
Definition: mextension_alsz.h:13
uint8_t * q
Definition: mextension_alsz.h:18
uint8_t * qT
Definition: mextension_alsz.h:18
int u
Definition: mextension_alsz.h:19
void recv(block *data, const bool *b, int length)
Definition: ot.h:13
Definition: prg.h:16
PRG prg
Definition: mextension_alsz.h:11
uint8_t * block_s
Definition: mextension_alsz.h:18
void recv_impl(block *data, const bool *b, int length)
Definition: mextension_alsz.h:257
void random_data(void *data, int nbytes)
Definition: prg.h:49
void send_impl(const block *data0, const block *data1, int length)
Definition: mextension_alsz.h:250
void ot_extension_recv_check(int length)
Definition: mextension_alsz.h:175
block * k0
Definition: mextension_alsz.h:15
block * k1
Definition: mextension_alsz.h:15
void send_block(const block *data, int nblock)
Definition: io_channel.h:132
~MOTExtension_ALSZ()
Definition: mextension_alsz.h:34
void ot_extension_send_pre(int length)
Definition: mextension_alsz.h:96
void recv_block(block *data, int nblock)
Definition: io_channel.h:136
Definition: mextension_alsz.h:9
Definition: ot.h:6
void setup_send(block *in_k0=nullptr, bool *in_s=nullptr)
Definition: mextension_alsz.h:71
block H(uint8_t *in, long id, int len)
Definition: mextension_alsz.h:55
static const int DIGEST_SIZE
Definition: hash.h:17
void xorBlocks_arr(block *res, const block *x, const block *y, int nblocks)
Definition: block.h:37
uint8_t ** t
Definition: mextension_alsz.h:18
bool ot_extension_send_check(int length)
Definition: mextension_alsz.h:219
static void hash_once(void *digest, const void *data, int nbyte)
Definition: hash.h:49