Commit e39aa1f5 authored by Jack Doerner

Added broader ORAM tests, discovered bugs, started to fix them.

parent 8cab76c2
......@@ -5,6 +5,7 @@
struct bitpropagator_offline {
size_t size;
size_t blockmultiple;
size_t startlevel;
size_t endlevel;
void * Z;
......@@ -17,29 +18,28 @@ struct bitpropagator_offline {
omp_lock_t * locks;
};
/*
 * Begin an offline propagation pass.
 *
 * Copies the (1 << bpo->startlevel) top-level blocks, BLOCKSIZE bytes each,
 * from `blocks` into bpo->level_data, then acquires every one of the
 * (bpo->endlevel - bpo->startlevel) per-level locks in bpo->locks.
 *
 * bpo    - offline propagator state; level_data and locks must already be
 *          allocated and the locks initialised.
 * blocks - source buffer holding at least (1 << startlevel) * BLOCKSIZE bytes.
 */
void bitpropagator_offline_start(bitpropagator_offline * bpo, uint8_t * blocks) {
	memcpy(bpo->level_data, blocks, (1ll<<bpo->startlevel) * BLOCKSIZE);

	/* size_t counter: the bound is a difference of two size_t fields. */
	for (size_t ii = 0; ii < (bpo->endlevel - bpo->startlevel); ii++) {
		omp_set_lock(&bpo->locks[ii]);
	}
}
/*
 * Publish the correction block and advice bits for one tree level.
 *
 * The data is stored at slot (level - bpo->startlevel - 1): BLOCKSIZE bytes
 * of `Z` go into bpo->Z, the two advice bits into bpo->advicebits_l /
 * bpo->advicebits_r, and finally that slot's lock in bpo->locks is released.
 *
 * bpo          - offline propagator state.
 * Z            - BLOCKSIZE-byte correction block for this level.
 * advicebit_l  - advice bit stored in advicebits_l[slot].
 * advicebit_r  - advice bit stored in advicebits_r[slot].
 * level        - tree level being pushed. No range check is performed; the
 *                slot index is computed in size_t arithmetic, so it wraps if
 *                level <= bpo->startlevel.
 */
void bitpropagator_offline_push_Z(bitpropagator_offline * bpo, uint8_t * Z, bool advicebit_l, bool advicebit_r, size_t level) {
	size_t slot = level - bpo->startlevel - 1;

	/* bpo->Z is an untyped buffer; address it as bytes rather than
	 * subscripting a void pointer (a GNU-only extension). */
	memcpy((uint8_t *)bpo->Z + slot * BLOCKSIZE, Z, BLOCKSIZE);
	bpo->advicebits_l[slot] = advicebit_l;
	bpo->advicebits_r[slot] = advicebit_r;

	/* Release last, after the slot's data has been written. */
	omp_unset_lock(&bpo->locks[slot]);
}
void bitpropagator_offline_readblockvector(void * local_output, void * local_bit_output, bitpropagator_offline * bpo) {
void bitpropagator_offline_readblockvector(uint8_t * local_output, bool * local_bit_output, bitpropagator_offline * bpo) {
#pragma omp parallel
{
size_t thislevel = bpo->startlevel;
size_t thislevelblocks = (1ll<<bpo->startlevel);
size_t nextlevelblocks = (bpo->size + (1ll<<(bpo->endlevel - thislevel -1)) - 1) / (1ll<<(bpo->endlevel - thislevel -1));
size_t expansion_stride;
uint64_t* a = (uint64_t *)bpo->level_data;
uint8_t* a2 = (uint8_t *)bpo->level_data;
......@@ -54,14 +54,21 @@ void bitpropagator_offline_readblockvector(void * local_output, void * local_bit
bool * t_bits;
#pragma omp for
for (size_t ii = 0; ii < 2*(nextlevelblocks/4); ii+=2) {
offline_prf_quad(&b2[ii*2*BLOCKSIZE], &b2[(ii*2+1)*BLOCKSIZE], &b2[(ii*2+2)*BLOCKSIZE], &b2[(ii*2+3)*BLOCKSIZE], &a2[ii*BLOCKSIZE], &a2[ii*BLOCKSIZE], &a2[(ii+1)*BLOCKSIZE], &a2[(ii+1)*BLOCKSIZE], bpo->keyL, bpo->keyR, bpo->keyL, bpo->keyR);
for (size_t ii = 0; ii < 4*(nextlevelblocks/8); ii+=4) {
offline_prf_oct(&b2[ii*2*BLOCKSIZE], &b2[(ii*2+1)*BLOCKSIZE], &b2[(ii*2+2)*BLOCKSIZE], &b2[(ii*2+3)*BLOCKSIZE],
&b2[(ii*2+4)*BLOCKSIZE], &b2[(ii*2+5)*BLOCKSIZE], &b2[(ii*2+6)*BLOCKSIZE], &b2[(ii*2+7)*BLOCKSIZE],
&a2[ii*BLOCKSIZE], &a2[ii*BLOCKSIZE], &a2[(ii+1)*BLOCKSIZE], &a2[(ii+1)*BLOCKSIZE],
&a2[(ii+2)*BLOCKSIZE], &a2[(ii+2)*BLOCKSIZE],&a2[(ii+3)*BLOCKSIZE], &a2[(ii+3)*BLOCKSIZE],
bpo->keyL, bpo->keyR, bpo->keyL, bpo->keyR,
bpo->keyL, bpo->keyR, bpo->keyL, bpo->keyR);
a_bits[ii] = a2[ii*BLOCKSIZE] & 1;
a_bits[ii+1] = a2[(ii+1)*BLOCKSIZE] & 1;
a_bits[ii+1] = a2[(ii+1)*BLOCKSIZE] & 1;
a_bits[ii+2] = a2[(ii+2)*BLOCKSIZE] & 1;
a_bits[ii+3] = a2[(ii+3)*BLOCKSIZE] & 1;
}
#pragma omp for
for (size_t ii = 2*(nextlevelblocks/4); ii < thislevelblocks; ii++) {
for (size_t ii = 4*(nextlevelblocks/8); ii < thislevelblocks; ii++) {
if ((ii+1)*2 <= nextlevelblocks) {
offline_prf(&b2[ii*2*BLOCKSIZE], &a2[ii*BLOCKSIZE], bpo->keyL);
offline_prf(&b2[(ii*2+1)*BLOCKSIZE], &a2[ii*BLOCKSIZE], bpo->keyR);
......@@ -88,6 +95,12 @@ void bitpropagator_offline_readblockvector(void * local_output, void * local_bit
b2 = a2; b = a; b_bits = a_bits;
a2 = t2; a = t; a_bits = t_bits;
if (thislevel == bpo->endlevel - 1 && b2 == local_output) {
expansion_stride = (BLOCKSIZE * bpo->blockmultiple);
} else {
expansion_stride = BLOCKSIZE;
}
#pragma omp for
for (size_t ii = 0; ii < 4*(nextlevelblocks/8); ii+=4) {
a_bits[ii] = (a2[ii*BLOCKSIZE] & 1) ^ (b_bits[ii/2] & advicebit_l);
......@@ -110,15 +123,15 @@ void bitpropagator_offline_readblockvector(void * local_output, void * local_bit
}
}
offline_prf_oct(&b2[ii*2*BLOCKSIZE], &b2[(ii*2+1)*BLOCKSIZE], &b2[(ii*2+2)*BLOCKSIZE], &b2[(ii*2+3)*BLOCKSIZE],
&b2[(ii*2+4)*BLOCKSIZE], &b2[(ii*2+5)*BLOCKSIZE], &b2[(ii*2+6)*BLOCKSIZE], &b2[(ii*2+7)*BLOCKSIZE],
offline_prf_oct(&b2[ii*2*expansion_stride], &b2[(ii*2+1)*expansion_stride], &b2[(ii*2+2)*expansion_stride], &b2[(ii*2+3)*expansion_stride],
&b2[(ii*2+4)*expansion_stride], &b2[(ii*2+5)*expansion_stride], &b2[(ii*2+6)*expansion_stride], &b2[(ii*2+7)*expansion_stride],
&a2[ii*BLOCKSIZE], &a2[ii*BLOCKSIZE], &a2[(ii+1)*BLOCKSIZE], &a2[(ii+1)*BLOCKSIZE],
&a2[(ii+2)*BLOCKSIZE], &a2[(ii+2)*BLOCKSIZE], &a2[(ii+3)*BLOCKSIZE], &a2[(ii+3)*BLOCKSIZE],
bpo->keyL, bpo->keyR, bpo->keyL, bpo->keyR,
bpo->keyL, bpo->keyR, bpo->keyL, bpo->keyR);
}
#pragma omp for
#pragma omp single
for (size_t ii = 4*(nextlevelblocks/8); ii < thislevelblocks ; ii++) {
if (ii%2 == 0) {
......@@ -135,10 +148,10 @@ void bitpropagator_offline_readblockvector(void * local_output, void * local_bit
}
if ((ii+1)*2 <= nextlevelblocks) {
offline_prf(&b2[ii*2*BLOCKSIZE], &a2[ii*BLOCKSIZE], bpo->keyL);
offline_prf(&b2[(ii*2+1)*BLOCKSIZE], &a2[ii*BLOCKSIZE], bpo->keyR);
offline_prf(&b2[ii*2*expansion_stride], &a2[ii*BLOCKSIZE], bpo->keyL);
offline_prf(&b2[(ii*2+1)*expansion_stride], &a2[ii*BLOCKSIZE], bpo->keyR);
} else if (ii*2+1 <= nextlevelblocks) {
offline_prf(&b2[ii*2*BLOCKSIZE], &a2[ii*BLOCKSIZE], bpo->keyL);
offline_prf(&b2[ii*2*expansion_stride], &a2[ii*BLOCKSIZE], bpo->keyL);
}
}
}
......@@ -169,10 +182,10 @@ void bitpropagator_offline_readblockvector(void * local_output, void * local_bit
if (b_bits[ii/2]) {
#pragma omp simd aligned(b,a,z:16)
for (uint8_t jj = 0; jj < BLOCKSIZE/sizeof(uint64_t); jj++) {
b[ii*(BLOCKSIZE/sizeof(uint64_t))+jj] = a[ii*(BLOCKSIZE/sizeof(uint64_t))+jj] ^ z[jj];
b[ii*((BLOCKSIZE * bpo->blockmultiple)/sizeof(uint64_t))+jj] = a[ii*(BLOCKSIZE/sizeof(uint64_t))+jj] ^ z[jj];
}
} else {
memcpy(&b[ii*(BLOCKSIZE/sizeof(uint64_t))], &a[ii*(BLOCKSIZE/sizeof(uint64_t))], BLOCKSIZE);
memcpy(&b[ii*((BLOCKSIZE * bpo->blockmultiple)/sizeof(uint64_t))], &a[ii*(BLOCKSIZE/sizeof(uint64_t))], BLOCKSIZE);
}
} else {
if (ii%2 == 0) {
......@@ -184,7 +197,7 @@ void bitpropagator_offline_readblockvector(void * local_output, void * local_bit
if (b_bits[ii/2]) {
#pragma omp simd aligned(b,a,z:16)
for (uint8_t jj = 0; jj < BLOCKSIZE/sizeof(uint64_t); jj++) {
a[ii*(BLOCKSIZE/sizeof(uint64_t))+jj] ^= z[jj];
a[ii*((BLOCKSIZE * bpo->blockmultiple)/sizeof(uint64_t))+jj] ^= z[jj];
}
}
}
......@@ -192,6 +205,53 @@ void bitpropagator_offline_readblockvector(void * local_output, void * local_bit
if (b == local_output) memcpy(b_bits, a_bits, thislevelblocks*sizeof(bool));
if (bpo->blockmultiple > 1) {
#pragma omp for
for (size_t ii = 0; ii < 8*(thislevelblocks/8); ii+=8) {
for (size_t jj = 1; jj < bpo->blockmultiple; jj++) {
// Note to self: this is ridiculous. Define a macro.
// Further note to self: actually, the non-encapsulation of offline_prf_oct and offline_prf
// is just as ridiculous, if not more so. There has to be a better way. TODO: find it.
offline_prf_oct(
&local_output[(ii+0) * (BLOCKSIZE*bpo->blockmultiple) + (jj * BLOCKSIZE)],
&local_output[(ii+1) * (BLOCKSIZE*bpo->blockmultiple) + (jj * BLOCKSIZE)],
&local_output[(ii+2) * (BLOCKSIZE*bpo->blockmultiple) + (jj * BLOCKSIZE)],
&local_output[(ii+3) * (BLOCKSIZE*bpo->blockmultiple) + (jj * BLOCKSIZE)],
&local_output[(ii+4) * (BLOCKSIZE*bpo->blockmultiple) + (jj * BLOCKSIZE)],
&local_output[(ii+5) * (BLOCKSIZE*bpo->blockmultiple) + (jj * BLOCKSIZE)],
&local_output[(ii+6) * (BLOCKSIZE*bpo->blockmultiple) + (jj * BLOCKSIZE)],
&local_output[(ii+7) * (BLOCKSIZE*bpo->blockmultiple) + (jj * BLOCKSIZE)],
&local_output[(ii+0) * (BLOCKSIZE*bpo->blockmultiple) + ((jj-1) * BLOCKSIZE)],
&local_output[(ii+1) * (BLOCKSIZE*bpo->blockmultiple) + ((jj-1) * BLOCKSIZE)],
&local_output[(ii+2) * (BLOCKSIZE*bpo->blockmultiple) + ((jj-1) * BLOCKSIZE)],
&local_output[(ii+3) * (BLOCKSIZE*bpo->blockmultiple) + ((jj-1) * BLOCKSIZE)],
&local_output[(ii+4) * (BLOCKSIZE*bpo->blockmultiple) + ((jj-1) * BLOCKSIZE)],
&local_output[(ii+5) * (BLOCKSIZE*bpo->blockmultiple) + ((jj-1) * BLOCKSIZE)],
&local_output[(ii+6) * (BLOCKSIZE*bpo->blockmultiple) + ((jj-1) * BLOCKSIZE)],
&local_output[(ii+7) * (BLOCKSIZE*bpo->blockmultiple) + ((jj-1) * BLOCKSIZE)],
bpo->keyL,
bpo->keyL,
bpo->keyL,
bpo->keyL,
bpo->keyL,
bpo->keyL,
bpo->keyL,
bpo->keyL
);
}
}
#pragma omp single
for (size_t ii = 8*(thislevelblocks/8); ii < thislevelblocks ; ii++) {
for (size_t jj = 1; jj < bpo->blockmultiple; jj++) {
offline_prf(
&local_output[(ii) * (BLOCKSIZE*bpo->blockmultiple) + (jj * BLOCKSIZE)],
&local_output[(ii) * (BLOCKSIZE*bpo->blockmultiple) + ((jj-1) * BLOCKSIZE)],
bpo->keyL
);
}
}
}
}
for (int ii = 0; ii < (bpo->endlevel - bpo->startlevel); ii++) {
......@@ -226,10 +286,11 @@ void bitpropagator_offline_parallelizer(void* bp, bitpropagator_offline * bpo, v
}
}
bitpropagator_offline * bitpropagator_offline_new(size_t size, size_t startlevel, uint8_t * keyL, uint8_t * keyR) {
bitpropagator_offline * bitpropagator_offline_new(size_t size, size_t blockmultiple, size_t startlevel, uint8_t * keyL, uint8_t * keyR) {
offline_expand_init();
bitpropagator_offline * bpo = malloc(sizeof(bitpropagator_offline));
bpo->size = size;
bpo->blockmultiple = blockmultiple;
bpo->startlevel = startlevel;
bpo->endlevel = LOG2LL(size) + (((1 << LOG2LL(size)) < size)? 1:0);
posix_memalign(&bpo->level_data,16,(1ll<<bpo->endlevel) * BLOCKSIZE);
......
......@@ -8,13 +8,13 @@ typedef void (* bp_traverser_fn)(void *, void *);
typedef void (* bp_pusher_fn)(void *, void *, void *);
typedef void (* facb_fn)(void *, void*);
void bitpropagator_offline_start(bitpropagator_offline * bpo, void * blocks);
void bitpropagator_offline_push_Z(bitpropagator_offline * bpo, void * Z, bool advicebit_l, bool advicebit_r, size_t level);
void bitpropagator_offline_readblockvector(void * local_output, void* local_bit_output, bitpropagator_offline * bpo);
void bitpropagator_offline_start(bitpropagator_offline * bpo, uint8_t * blocks);
void bitpropagator_offline_push_Z(bitpropagator_offline * bpo, uint8_t * Z, bool advicebit_l, bool advicebit_r, size_t level);
void bitpropagator_offline_readblockvector(uint8_t * local_output, bool* local_bit_output, bitpropagator_offline * bpo);
void bitpropagator_offline_parallelizer(void* bp, bitpropagator_offline * bpo, void* indexp, void * local_output, void* local_bit_output, void* pd, bp_traverser_fn fn, bp_pusher_fn fn2, facb_fn cbfn, void* cbpass);
bitpropagator_offline * bitpropagator_offline_new(size_t size, size_t startlevel, uint8_t * keyL, uint8_t * keyR);
bitpropagator_offline * bitpropagator_offline_new(size_t size, size_t blockmultiple, size_t startlevel, uint8_t * keyL, uint8_t * keyR);
void bitpropagator_offline_free(bitpropagator_offline * bpo);
#endif
\ No newline at end of file
......@@ -9,6 +9,7 @@ struct bitpropagator {
uint32_t startlevel;
uint32_t endlevel;
size_t size;
size_t blockmultiple;
obliv uint8_t * toplevel;
uint8_t * toplevel_local;
obliv uint8_t * blockzero;
......@@ -119,6 +120,10 @@ void bitpropagator_traverselevels(bitpropagator * bp, obliv size_t * indexp) {
control_bit_A = control_bit_A_next;
control_bit_B = control_bit_B_next;
}
for (size_t ii = 1; ii < bp->blockmultiple; ii++) {
online_prf_double(&bp->activeblock_A[BLOCKSIZE * ii], &bp->activeblock_B[BLOCKSIZE * ii], &bp->activeblock_A[BLOCKSIZE * (ii-1)], &bp->activeblock_B[BLOCKSIZE * (ii-1)], bp->keyL, bp->keyL);
}
}
void bitpropagator_getblockvector(obliv uint8_t * activeblock_delta, uint8_t * local_output, bool * local_bit_output, bitpropagator * bp, obliv size_t index) {
......@@ -159,8 +164,8 @@ void bitpropagator_getblockvector_with_callback(obliv uint8_t * activeblock_delt
cleanupProtocol(&pd2);
//write output
ocCopyN(&ocCopyChar, activeblock_delta, bp->activeblock_A, BLOCKSIZE);
for (size_t ii = 0; ii < BLOCKSIZE/sizeof(uint64_t); ii++) {
ocCopyN(&ocCopyChar, activeblock_delta, bp->activeblock_A, BLOCKSIZE*bp->blockmultiple);
for (size_t ii = 0; ii < (BLOCKSIZE*bp->blockmultiple)/sizeof(uint64_t); ii++) {
((obliv uint64_t *)activeblock_delta)[ii] ^= ((obliv uint64_t *)bp->activeblock_B)[ii];
}
}
......@@ -170,10 +175,11 @@ void bitpropagator_getadvice(obliv bool * advicebits, obliv uint8_t * blocks_A,
advicebits[1] = ((obliv bool *)blocks_A)[BLOCKSIZE*8] ^ ((obliv bool *)blocks_B)[BLOCKSIZE*8] ^ rightblock;
}
bitpropagator * bitpropagator_new(size_t size, uint32_t startlevel) {
bitpropagator * bitpropagator_new(size_t size, size_t blockmultiple, uint32_t startlevel) {
online_expand_init();
bitpropagator * bp = malloc(sizeof(bitpropagator));
bp->size = size;
bp->blockmultiple = blockmultiple;
bp->endlevel = LOG2LL(size) + (((1ll << LOG2LL(size)) < size)? 1:0);
bp->startlevel = MIN(startlevel,bp->endlevel-1);
bp->toplevel = calloc((1ll << bp->startlevel) + 1, BLOCKSIZE * sizeof(obliv uint8_t));
......@@ -181,8 +187,8 @@ bitpropagator * bitpropagator_new(size_t size, uint32_t startlevel) {
bp->blockzero = calloc(1, BLOCKSIZE * sizeof(obliv uint8_t));
bp->blocktemp_A = calloc((1ll << bp->startlevel), BLOCKSIZE * sizeof(obliv uint8_t));
bp->blocktemp_B = calloc((1ll << bp->startlevel), BLOCKSIZE * sizeof(obliv uint8_t));
bp->activeblock_A = calloc(1, BLOCKSIZE * sizeof(obliv uint8_t));
bp->activeblock_B = calloc(1, BLOCKSIZE * sizeof(obliv uint8_t));
bp->activeblock_A = calloc(1, BLOCKSIZE * blockmultiple * sizeof(obliv uint8_t));
bp->activeblock_B = calloc(1, BLOCKSIZE * blockmultiple * sizeof(obliv uint8_t));
bp->expanded_A = calloc(2, BLOCKSIZE * sizeof(obliv uint8_t));
bp->expanded_B = calloc(2, BLOCKSIZE * sizeof(obliv uint8_t));
bp->Z = calloc((bp->endlevel - bp->startlevel), BLOCKSIZE * sizeof(obliv uint8_t));
......@@ -202,7 +208,7 @@ bitpropagator * bitpropagator_new(size_t size, uint32_t startlevel) {
}
online_prf_keyschedule_double(&bp->keyL, &bp->keyR, keyL, keyR);
bp->bpo = bitpropagator_offline_new(size, bp->startlevel, keyL, keyR);
bp->bpo = bitpropagator_offline_new(size, blockmultiple, bp->startlevel, keyL, keyR);
free(keyL);
free(keyR);
......
......@@ -9,7 +9,7 @@ typedef void (* facb_fn)(void *, void*);
void bitpropagator_getblockvector(obliv uint8_t * activeblock_delta, uint8_t * local_output, bool * local_bit_output, bitpropagator * bp, obliv size_t index);
void bitpropagator_getblockvector_with_callback(obliv uint8_t * activeblock_delta, uint8_t * local_output, bool * local_bit_output, bitpropagator * bp, obliv size_t index, facb_fn cbfn, void* cbpass);
void bitpropagator_getadvice(obliv bool * advicebits, obliv uint8_t * blocks_A, obliv uint8_t * blocks_B, obliv bool rightblock);
bitpropagator * bitpropagator_new(size_t size, uint32_t truncated_levels);
bitpropagator * bitpropagator_new(size_t size, size_t blockmultiple, uint32_t truncated_levels);
void bitpropagator_free(bitpropagator * bp);
#endif
\ No newline at end of file
This diff is collapsed.
......@@ -13,7 +13,7 @@ void bitpropagator_cprg_offline_finalize(uint8_t * accumulator, uint8_t * z, boo
void bitpropagator_cprg_offline_parallelizer(void* bp, void* indexp, void* blockdelta, void * local_output, void * local_bit_output, void* pd, bp_cprg_traverser_fn fn, facb_fn cbfn, void* cbpass);
bitpropagator_cprg_offline * bitpropagator_cprg_offline_new(size_t size, uint8_t * keyL, uint8_t * keyR);
bitpropagator_cprg_offline * bitpropagator_cprg_offline_new(size_t size, size_t blockmultiple, uint8_t * keyL, uint8_t * keyR);
void bitpropagator_cprg_offline_free(bitpropagator_cprg_offline * bpo);
#endif
\ No newline at end of file
......@@ -10,6 +10,7 @@ struct bitpropagator_cprg {
uint32_t startlevel;
uint32_t endlevel;
size_t size;
size_t blockmultiple;
obliv uint8_t * diff_L;
obliv uint8_t * diff_R;
obliv uint8_t * Z;
......@@ -59,7 +60,7 @@ void bitpropagator_cprg_traverselevels(obliv uint8_t * active_block_delta, uint8
}
}
ocFromSharedCharN(ocCurrentProto(), active_block_delta, bp->L_local, BLOCKSIZE);
ocFromSharedCharN(ocCurrentProto(), active_block_delta, bp->L_local, BLOCKSIZE*bp->blockmultiple);
}
void bitpropagator_cprg_getblockvector(obliv uint8_t * active_block_delta, uint8_t * local_output, bool * local_bit_output, bitpropagator_cprg * bp, obliv size_t index) {
......@@ -85,9 +86,10 @@ void bitpropagator_cprg_getadvice(obliv bool * advicebits, obliv uint8_t * diff_
advicebits[1] = ((obliv bool *)diff_R)[0] ^ rightblock;
}
bitpropagator_cprg * bitpropagator_cprg_new(size_t size) {
bitpropagator_cprg * bitpropagator_cprg_new(size_t size, size_t blockmultiple) {
bitpropagator_cprg * bp = malloc(sizeof(bitpropagator_cprg));
bp->size = size;
bp->blockmultiple = blockmultiple;
bp->startlevel = 0;
bp->endlevel = LOG2LL(size) + (((1ll << LOG2LL(size)) < size)? 1:0);
bp->diff_L = calloc(1, BLOCKSIZE * sizeof(obliv uint8_t));
......@@ -95,7 +97,7 @@ bitpropagator_cprg * bitpropagator_cprg_new(size_t size) {
bp->Z = calloc(1, BLOCKSIZE * sizeof(obliv uint8_t));
bp->advicebits = calloc(1, 2*sizeof(obliv bool));
bp->Z_local = malloc(BLOCKSIZE);
floram_pma(&bp->L_local, 16, BLOCKSIZE);
floram_pma(&bp->L_local, 16, BLOCKSIZE*blockmultiple);
floram_pma(&bp->R_local, 16, BLOCKSIZE);
//Generator chooses keys so that we don't incur round trips. Since we're semi-honest I guess it's OK? Should probably change it anyway.
......@@ -110,7 +112,7 @@ bitpropagator_cprg * bitpropagator_cprg_new(size_t size) {
for (size_t ii=0; ii< KEYSIZE/sizeof(uint64_t);ii++) ((uint64_t *)keyL)[ii] = ocBroadcastLLong(NULL,1);
for (size_t ii=0; ii< KEYSIZE/sizeof(uint64_t);ii++) ((uint64_t *)keyR)[ii] = ocBroadcastLLong(NULL,1);
}
bp->bpo = bitpropagator_cprg_offline_new(size, keyL, keyR);
bp->bpo = bitpropagator_cprg_offline_new(size, bp->blockmultiple, keyL, keyR);
free(keyL);
free(keyR);
......
......@@ -8,7 +8,7 @@ typedef struct bitpropagator_cprg bitpropagator_cprg;
void bitpropagator_cprg_getblockvector(obliv uint8_t * active_block_delta, uint8_t * local_output, bool * local_bit_output, bitpropagator_cprg * bp, obliv size_t index);
void bitpropagator_cprg_getblockvector_with_callback(obliv uint8_t * active_block_delta, uint8_t * local_output, bool * local_bit_output, bitpropagator_cprg * bp, obliv size_t index, facb_fn cbfn, void* cbpass);
void bitpropagator_cprg_getadvice(obliv bool * advicebits, obliv uint8_t * diff_L, obliv uint8_t * diff_R, obliv bool rightblock);
bitpropagator_cprg * bitpropagator_cprg_new(size_t size);
bitpropagator_cprg * bitpropagator_cprg_new(size_t size, size_t blockmultiple);
void bitpropagator_cprg_free(bitpropagator_cprg * bp);
#endif
\ No newline at end of file
......@@ -9,7 +9,8 @@
struct floram {
OcCopy* cpy;
OcCopy blockcpy;
OcCopy memblockcpy;
size_t memblocksize;
scanwrom* wrom;
scanrom* rom;
void* bitpropagator;
......@@ -58,12 +59,12 @@ void floram_scan_callback(facb_pass * input, ProtocolDesc *pd) {
obliv size_t subblockid = input->subblockid;
obliv bool found = false;
if (ram->progress > 0) {
ocCopy(&ram->blockcpy, element(&ram->blockcpy, ram->stash, ram->progress), ram->stash);
ocCopy(&ram->memblockcpy, element(&ram->memblockcpy, ram->stash, ram->progress), ram->stash);
ram->stashi[ram->progress] = ram->stashi[0];
ram->stashi[0] = -1;
for (size_t ii = 1; ii <= ram->progress; ii ++) {
obliv if (blockid == ram->stashi[ii]) {
ocCopy(&ram->blockcpy, ram->stash, element(&ram->blockcpy, ram->stash, ii));
ocCopy(&ram->memblockcpy, ram->stash, element(&ram->memblockcpy, ram->stash, ii));
ram->stashi[0] = ram->stashi[ii];
ram->stashi[ii] = -1;
found = true;
......@@ -102,13 +103,13 @@ void floram_apply(floram* ram, void* data, floram_block_access_function fn, obli
ram->stashi[0] = blockid;
}
ocCopy(&ram->blockcpy, ram->blocktemp, ram->stash);
ocCopy(&ram->memblockcpy, ram->blocktemp, ram->stash);
for (uint32_t jj = 0; jj < ram->elementsperblock; jj ++) {
obliv if (subblockid == jj) fn(ram->cpy, element(ram->cpy, ram->stash, jj), data);
}
scanwrom_write_with_blockvector(ram->wrom, ram->activeblock_delta, ram->blockvector_local, ram->bitvector_local, element(&ram->blockcpy, ram->stash, 0), ram->blocktemp);
scanwrom_write_with_blockvector(ram->wrom, ram->activeblock_delta, ram->blockvector_local, ram->bitvector_local, element(&ram->memblockcpy, ram->stash, 0), ram->blocktemp);
~obliv() {
ram->progress++;
if (ram->progress == ram->period) floram_refresh(ram);
......@@ -142,7 +143,7 @@ void floram_apply_public(floram* ram, void* data, floram_block_access_function f
if (subblockid == jj) fn(ram->cpy, element(ram->cpy, ram->blocktemp, jj), data);
}
for (size_t jj = 0; jj < ram->progress; jj++) {
obliv if (blockid == ram->stashi[jj]) ocCopy(&ram->blockcpy, element(&ram->blockcpy, ram->stash, jj), ram->blocktemp);
obliv if (blockid == ram->stashi[jj]) ocCopy(&ram->memblockcpy, element(&ram->memblockcpy, ram->stash, jj), ram->blocktemp);
}
~obliv() scanwrom_write(ram->wrom, ram->blocktemp, blockid);
}
......@@ -159,7 +160,7 @@ floram* floram_new(OcCopy* cpy, void* data, size_t n, bool cprg, bool from_share
ram->cprg = cprg;
size_t elementsize = cpy->eltsize/sizeof(obliv uint8_t);
if (elementsize >= BLOCKSIZE/2) {
if (elementsize > BLOCKSIZE/2) {
ram->blockcount = n;
ram->blockmultiple = ((elementsize / BLOCKSIZE) + (elementsize%BLOCKSIZE?1:0));
ram->elementsperblock = 1;
......@@ -169,31 +170,30 @@ floram* floram_new(OcCopy* cpy, void* data, size_t n, bool cprg, bool from_share
ram->blockcount = (n/ram->elementsperblock) + (n%ram->elementsperblock?1:0);
}
ram->blockcpy=ocCopyCharN(ram->blockmultiple * BLOCKSIZE);
ram->memblocksize = BLOCKSIZE * ram->blockmultiple;
ram->memblockcpy=ocCopyCharN(ram->memblocksize);
if (cprg) {
ram->bgb = bitpropagator_cprg_getblockvector;
ram->bgbc = bitpropagator_cprg_getblockvector_with_callback;
ram->bf = bitpropagator_cprg_free;
ram->bitpropagator = bitpropagator_cprg_new(ram->blockcount);
ram->bitpropagator = bitpropagator_cprg_new(ram->blockcount, ram->blockmultiple);
} else {
ram->bgb = bitpropagator_getblockvector;
ram->bgbc = bitpropagator_getblockvector_with_callback;
ram->bf = bitpropagator_free;
ram->bitpropagator = bitpropagator_new(ram->blockcount, MIN(5, LOG2LL(ram->blockcount)));
ram->bitpropagator = bitpropagator_new(ram->blockcount, ram->blockmultiple, MIN(5, LOG2LL(ram->blockcount)));
}
floram_pma(&ram->blockvector_local, 16, ram->blockcount * BLOCKSIZE);
floram_pma(&ram->blockvector_local, 16, ram->blockcount * ram->memblocksize);
floram_pma(&ram->bitvector_local, 16, ram->blockcount * sizeof(bool));
ram->blocktemp_local = malloc(ram->blockmultiple * BLOCKSIZE);
ram->activeblock_delta = calloc(1, BLOCKSIZE * sizeof(obliv uint8_t));
//ram->period = (uint32_t)ceil(sqrt(ram->blockcount));
ram->blocktemp_local = malloc(ram->memblocksize);
ram->activeblock_delta = calloc(1, ram->memblocksize * sizeof(obliv uint8_t));
// Based on B = 128*b; c = B*p/2+n*b/p = 64*p+n/p; dc/dp = 64-n/p^2; dc/dp = 0 when p = sqrt(n)/8
ram->period = (uint32_t)ceil(sqrt(ram->blockcount)/(8));
ram->blocktemp = calloc(1, ram->blockcpy.eltsize);
ram->stash = calloc(ram->period, ram->blockcpy.eltsize);
ram->period = (uint32_t)ceil(sqrt(ram->blockcount)/8);
ram->blocktemp = calloc(1, ram->memblockcpy.eltsize);
ram->stash = calloc(ram->period, ram->memblockcpy.eltsize);
ram->stashi = calloc(ram->period, sizeof(obliv size_t));
for (size_t ii = 0; ii < ram->period; ii++) {
ram->stashi[ii] = -1;
......@@ -207,15 +207,24 @@ floram* floram_new(OcCopy* cpy, void* data, size_t n, bool cprg, bool from_share
//Now fill the scanrom with data, if there is data with which to fill it
if (data != NULL) {
uint8_t * loadtemp_local;
floram_pma(&loadtemp_local, 16, ram->memblocksize*ram->blockcount);
if (from_shares) {
ram->wrom = scanwrom_new(&ram->blockcpy, ram->blockcount);
scanwrom_write_xor_shares(ram->wrom, data, 0, ram->blockcount);
size_t blockid;
for (size_t ii = 0; ii < ram->size; ii++) {
blockid = ii / ram->elementsperblock;
} else {
size_t subblocksize = ram->elementsperblock*(cpy->eltsize/sizeof(obliv uint8_t));
uint8_t * loadtemp_local;
floram_pma(&loadtemp_local, 16, ram->blockcount * ram->blockmultiple * BLOCKSIZE);
memcpy(&loadtemp_local[blockid * ram->memblocksize], &data[blockid * subblocksize], subblocksize);
}
ram->wrom = scanwrom_new(ram->memblocksize, ram->blockcount);
scanwrom_write_xor_shares(ram->wrom, loadtemp_local, 0, ram->blockcount);
} else {
size_t blockid, subblockid;
for (size_t ii = 0; ii < ram->size; ii++) {
......@@ -224,24 +233,25 @@ floram* floram_new(OcCopy* cpy, void* data, size_t n, bool cprg, bool from_share
size_t elosize = cpy->eltsize/sizeof(obliv uint8_t);
ocToSharedCharN(ocCurrentProto(), &loadtemp_local[blockid * BLOCKSIZE * ram->blockmultiple + subblockid*elosize], ((obliv uint8_t *)element(cpy,data,ii)), elosize);
ocToSharedCharN(ocCurrentProto(), &loadtemp_local[blockid * ram->memblocksize + subblockid*elosize], ((obliv uint8_t *)element(cpy,data,ii)), elosize);
}
ram->wrom = scanwrom_new(&ram->blockcpy, ram->blockcount);
ram->wrom = scanwrom_new(ram->memblocksize, ram->blockcount);
scanwrom_write_xor_shares(ram->wrom, loadtemp_local, 0, ram->blockcount);
free(loadtemp_local);
}
ram->rom = scanrom_new(&ram->blockcpy, ram->blockcount, ram->rom_key_half);
free(loadtemp_local);
ram->rom = scanrom_new(ram->memblocksize, ram->blockcount, ram->rom_key_half);
scanrom_import_from_scanwrom(ram->rom, ram->wrom);
} else {
ram->wrom = scanwrom_new(&ram->blockcpy, ram->blockcount);
ram->wrom = scanwrom_new(ram->memblocksize, ram->blockcount);
scanwrom_clear(ram->wrom);
ram->rom = scanrom_new(&ram->blockcpy, ram->blockcount, ram->rom_key_half);
ram->rom = scanrom_new(ram->memblocksize, ram->blockcount, ram->rom_key_half);
scanrom_clear(ram->rom);
}
......
This diff is collapsed.
......@@ -3,25 +3,28 @@
#include <wmmintrin.h>
#include <tmmintrin.h>
/*
 * XOR together every block of `local_data` whose selector bit is set.
 *
 * `local_data` is treated as `blockcount` consecutive blocks of
 * `memblocksize` bytes. For each 64-bit word position jj inside a block,
 * the output word is the XOR of word jj of every block ii with
 * bitvector[ii] == true (zero when no bit is set).
 *
 * data         - output buffer of memblocksize bytes.
 * local_data   - input blocks, blockcount * memblocksize bytes.
 * bitvector    - blockcount selector flags.
 * memblocksize - bytes per block. Only whole 64-bit words are combined; if
 *                memblocksize is not a multiple of sizeof(uint64_t) the
 *                trailing bytes of `data` are left zero by the memset.
 * blockcount   - number of blocks.
 *
 * The buffers are accessed as uint64_t, and the OpenMP `aligned` clause
 * declares local_data and bitvector to be 16-byte aligned, so callers must
 * supply suitably aligned memory.
 */
void scanrom_read_with_bitvector_offline(uint8_t * data, uint8_t * local_data, bool * bitvector, size_t memblocksize, size_t blockcount) {
	const size_t words_per_block = memblocksize / sizeof(uint64_t);
	uint64_t * d = (uint64_t *)local_data;
	bool * b = bitvector;

	memset(data, 0, memblocksize);

	for (size_t jj = 0; jj < words_per_block; jj++) {
		uint64_t sum = 0;
		#pragma omp parallel for simd aligned(d,b:16) reduction(^:sum)
		for (size_t ii = 0; ii < blockcount; ii++) {
			if (b[ii]) {
				sum ^= d[ii * words_per_block + jj];
			}
		}
		((uint64_t *)data)[jj] = sum;
	}
}
void scanrom_encrypt_offline(uint8_t * out, uint8_t * in, uint8_t* key, size_t index, size_t len) {
void scanrom_encrypt_offline(uint8_t * out, uint8_t * in, uint8_t* key, size_t index, size_t blockmultiple, size_t blockcount) {
#define OCTBLOCK(II,JJ) &o[(II+JJ)*BLOCKSIZE]
uint8_t * o = out;
uint8_t * i = in;
......@@ -32,7 +35,7 @@ void scanrom_encrypt_offline(uint8_t * out, uint8_t * in, uint8_t* key, size_t i
if (in == NULL) {
size_t ii;
#pragma omp parallel for
for (ii = index/BLOCKSIZE; ii < ((index+len) / BLOCKSIZE)-((index+len) / BLOCKSIZE)%8; ii+= 8) {
for (ii = index*blockmultiple; ii < ((index+blockcount)*blockmultiple)-((index+blockcount)*blockmultiple)%8; ii+= 8) {
mr1 = _mm_shuffle_epi8 (_mm_set_epi64((__m64)ii,(__m64)0l), mask);
mr2 = _mm_shuffle_epi8 (_mm_set_epi64((__m64)ii+1,(__m64)0l), mask);
mr3 = _mm_shuffle_epi8 (_mm_set_epi64((__m64)ii+2,(__m64)0l), mask);
......@@ -41,18 +44,17 @@ void scanrom_encrypt_offline(uint8_t * out, uint8_t * in, uint8_t* key, size_t i
mr6 = _mm_shuffle_epi8 (_mm_set_epi64((__m64)ii+5,(__m64)0l), mask);
mr7 = _mm_shuffle_epi8 (_mm_set_epi64((__m64)ii+6,(__m64)0l), mask);
mr8 = _mm_shuffle_epi8 (_mm_set_epi64((__m64)ii+7,(__m64)0l), mask);
offline_prf_oct(&o[ii*BLOCKSIZE], &o[(ii+1)*BLOCKSIZE], &o[(ii+2)*BLOCKSIZE], &o[(ii+3)*BLOCKSIZE],
&o[(ii+4)*BLOCKSIZE], &o[(ii+5)*BLOCKSIZE], &o[(ii+6)*BLOCKSIZE], &o[(ii+6)*BLOCKSIZE],
offline_prf_oct(OCTBLOCK(ii,0),OCTBLOCK(ii,1),OCTBLOCK(ii,2),OCTBLOCK(ii,3),OCTBLOCK(ii,4),OCTBLOCK(ii,5),OCTBLOCK(ii,6),OCTBLOCK(ii,7),
&mr1, &mr2, &mr3, &mr4, &mr5, &mr6, &mr7, &mr8, kex, kex, kex, kex, kex, kex, kex, kex);
}
for (; ii < (index+len) / BLOCKSIZE; ii+= 1) {
for (; ii < ((index+blockcount)*blockmultiple); ii+= 1) {
__m128i mr1 = _mm_shuffle_epi8 (_mm_set_epi64((__m64)ii,(__m64)0l), mask);
offline_prf(&o[ii*BLOCKSIZE], &mr1, kex);
}
} else {
size_t ii;
#pragma omp parallel for
for (ii = index/BLOCKSIZE; ii < ((index+len) / BLOCKSIZE)-((index+len) / BLOCKSIZE)%8; ii+= 8) {
for (ii = index*blockmultiple; ii < ((index+blockcount)*blockmultiple)-((index+blockcount)*blockmultiple)%8; ii+= 8) {
mr1 = _mm_shuffle_epi8 (_mm_set_epi64((__m64)ii,(__m64)0l), mask);
mr2 = _mm_shuffle_epi8 (_mm_set_epi64((__m64)ii+1,(__m64)0l), mask);
mr3 = _mm_shuffle_epi8 (_mm_set_epi64((__m64)ii+2,(__m64)0l), mask);
......@@ -61,20 +63,19 @@ void scanrom_encrypt_offline(uint8_t * out, uint8_t * in, uint8_t* key, size_t i
mr6 = _mm_shuffle_epi8 (_mm_set_epi64((__m64)ii+5,(__m64)0l), mask);
mr7 = _mm_shuffle_epi8 (_mm_set_epi64((__m64)ii+6,(__m64)0l), mask);
mr8 = _mm_shuffle_epi8 (_mm_set_epi64((__m64)ii+7,(__m64)0l), mask);
offline_prf_oct(&o[ii*BLOCKSIZE], &o[(ii+1)*BLOCKSIZE], &o[(ii+2)*BLOCKSIZE], &o[(ii+3)*BLOCKSIZE],
&o[(ii+4)*BLOCKSIZE], &o[(ii+5)*BLOCKSIZE], &o[(ii+6)*BLOCKSIZE], &o[(ii+6)*BLOCKSIZE],
offline_prf_oct(OCTBLOCK(ii,0),OCTBLOCK(ii,1),OCTBLOCK(ii,2),OCTBLOCK(ii,3),OCTBLOCK(ii,4),OCTBLOCK(ii,5),OCTBLOCK(ii,6),OCTBLOCK(ii,7),
&mr1, &mr2, &mr3, &mr4, &mr5, &mr6, &mr7, &mr8, kex, kex, kex, kex, kex, kex, kex, kex);
#pragma omp simd aligned(o,i:16)
for (uint8_t jj=0;jj<4*BLOCKSIZE;jj++) {
o[ii*BLOCKSIZE+jj] ^= i[ii*BLOCKSIZE+jj];
#pragma omp simd aligned(o:16)
for (size_t jj = ii*BLOCKSIZE; jj < (ii+8)*BLOCKSIZE; jj++) {
o[jj] ^= i[jj];
}
}
for (; ii < (index+len) / BLOCKSIZE; ii+= 1) {
for (; ii < ((index+blockcount)*blockmultiple); ii+= 1) {
__m128i mr1 = _mm_shuffle_epi8 (_mm_set_epi64((__m64)ii,(__m64)0l), mask);
offline_prf(&o[ii*BLOCKSIZE], &mr1, kex);
#pragma omp simd aligned(o,i:16)
for (uint8_t jj=0;jj<BLOCKSIZE;jj++) {
o[ii*BLOCKSIZE+jj] ^= i[ii*BLOCKSIZE+jj];
for (size_t jj = ii*BLOCKSIZE; jj < (ii+1)*BLOCKSIZE; jj++) {
o[jj] ^= i[jj];
}
}
}
......@@ -83,7 +84,7 @@ void scanrom_encrypt_offline(uint8_t * out, uint8_t * in, uint8_t* key, size_t i
}
void scanwrom_write_with_blockvector_offline(uint8_t * local_data, uint8_t * blockvector, bool * controlbitvector, uint8_t*Zblock, bool expand, size_t fullblocksize, size_t blockcount) {
void scanwrom_write_with_blockvector_offline(uint8_t * local_data, uint8_t * blockvector, bool * controlbitvector, uint8_t*Zblock, size_t memblocksize, size_t blockcount) {
uint64_t * d = local_data;
uint64_t * b = blockvector;
uint64_t * z = Zblock;
......@@ -92,32 +93,14 @@ void scanwrom_write_with_blockvector_offline(uint8_t * local_data, uint8_t * blo
for (size_t ii = 0; ii< blockcount; ii++) {
if (controlbitvector[ii]) {
#pragma omp simd aligned(d,b,z:16)
for (size_t jj = 0; jj < fullblocksize/sizeof(uint64_t); jj++) {
d[ii * fullblocksize/sizeof(uint64_t) + jj] ^= b[ii * fullblocksize/sizeof(uint64_t) + jj] ^ z[jj];
for (size_t jj = 0; jj < memblocksize/sizeof(uint64_t); jj++) {
d[ii * memblocksize/sizeof(uint64_t) + jj] ^= b[ii * memblocksize/sizeof(uint64_t) + jj] ^ z[jj];
}
} else {
#pragma omp simd aligned(d,b:16)
for (size_t jj = 0; jj < fullblocksize/sizeof(uint64_t); jj++) {
d[ii * fullblocksize/sizeof(uint64_t) + jj] ^= b[ii * fullblocksize/sizeof(uint64_t) + jj];
for (size_t jj = 0; jj < memblocksize/sizeof(uint64_t); jj++) {
d[ii * memblocksize/sizeof(uint64_t) + jj] ^= b[ii * memblocksize/sizeof(uint64_t) + jj];
}
}
}
}
// Unfinished
//
// Invokes two callbacks from inside a two-thread OpenMP parallel region.
//   fn1  - called as fn1(data, NULL) under `omp master`.
//   fn2  - called as fn2(data, pd) from a task created inside `omp single`.
//   data - opaque pointer passed as the first argument to both callbacks.
//   pd   - opaque pointer passed as the second argument to fn2 only.
// NOTE(review): presumably intended to let fn1 and fn2 run concurrently on
// the two threads; the code is marked unfinished, so confirm before relying
// on any particular overlap or ordering between the two calls.
void scanrom_transfer_duplexer(duplexer_fn fn1, duplexer_fn fn2, void* data, void * pd) {
// Request a team of exactly two threads for this region.
#pragma omp parallel num_threads(2)
{
// Only the master thread of the team runs fn1; it receives no second argument.
#pragma omp master
fn1(data, NULL);
// One thread of the team enters this block and creates the task below.
#pragma omp single
{
// fn2 is packaged as an OpenMP task rather than called inline.
#pragma omp task
fn2(data, pd);
}
}
}
\ No newline at end of file
......@@ -4,14 +4,10 @@
typedef void (* duplexer_fn)(void *, void *);
void scanrom_transfer_duplexer(duplexer_fn fn1, duplexer_fn fn2, void* data, void * pd);
void scanrom_create_local_halfpad(uint8_t * dest, uint8_t * key, size_t size);
void scanrom_read_with_bitvector_offline(uint8_t * data, uint8_t * local_data, bool * bitvector, size_t fullblocksize, size_t blockcount);
void scanrom_encrypt_offline(uint8_t * out, uint8_t * in, uint8_t* key, size_t index, size_t len);
void scanrom_encrypt_offline(uint8_t * out, uint8_t * in, uint8_t* key, size_t index, size_t blockmultiple, size_t blockcount);
void scanwrom_write_with_blockvector_offline(uint8_t * local_data, uint8_t * blockvector, bool * bitvector, uint8_t*Zblock, bool expand, size_t fullblocksize, size_t blockcount);
void scanwrom_write_with_blockvector_offline(uint8_t * local_data, uint8_t * blockvector, bool * bitvector, uint8_t*Zblock, size_t memblocksize, size_t blockcount);
#endif
\ No newline at end of file
This diff is collapsed.
......@@ -12,7 +12,7 @@ void scanrom_set_key(scanrom* rom, uint8_t* key_local);
void scanrom_import_from_scanwrom(scanrom * rom, scanwrom * wrom);
void scanrom_clear(scanrom* rom);
scanrom* scanrom_new(OcCopy* blockcpy, size_t n, void* key_local);
scanrom* scanrom_new(size_t memblocksize, size_t n, void* key_local);
void scanrom_free(scanrom* rom);
void scanwrom_write_with_blockvector(scanwrom* rom, obliv uint8_t * active_block_pair, uint8_t * blockvector, bool* bitvector, obliv uint8_t * old_data, obliv uint8_t * new_data) obliv;
......@@ -22,7 +22,7 @@ void scanwrom_read_xor_shares(uint8_t * data, scanwrom* rom, size_t index, size_
void scanwrom_write_xor_shares(scanwrom* rom, uint8_t * data, size_t index, size_t len);
void scanwrom_clear(scanwrom* rom);
scanwrom* scanwrom_new(OcCopy* blockcpy, size_t n);
scanwrom* scanwrom_new(size_t memblocksize, size_t n);