Commit 9378a042 authored by Jack Doerner's avatar Jack Doerner

Fixed circuit ORAM

Added multithreading
parent cc72dfc9
......@@ -32,6 +32,8 @@ void lso_write(LinearScanOram * oram, obliv int index, obliv bool* data) obliv
void lso_read(LinearScanOram * oram, obliv int index, obliv bool* data) obliv {
index = bit_mask(index, oram->index_size);
for(int j = 0; j < oram->data_size; ++j)
data[j] = oram->data[0][j];
for(int i = 1; i < oram->N; ++i) {
obliv if (index == i)
for(int j = 0; j < oram->data_size; ++j)
......
......@@ -25,13 +25,10 @@ void bitpropagator_offline_start(bitpropagator_offline * bpo, void * blocks) {
}
void bitpropagator_offline_push_Z(bitpropagator_offline * bpo, void * Z, bool advicebit_l, bool advicebit_r, size_t level) {
//#pragma omp task
{
memcpy(&bpo->Z[(level- bpo->startlevel - 1)*BLOCKSIZE], Z, BLOCKSIZE);
bpo->advicebits_l[level- bpo->startlevel - 1] = advicebit_l;
bpo->advicebits_r[level- bpo->startlevel - 1] = advicebit_r;
omp_unset_lock(&bpo->locks[level- bpo->startlevel - 1]);
}
memcpy(&bpo->Z[(level- bpo->startlevel - 1)*BLOCKSIZE], Z, BLOCKSIZE);
bpo->advicebits_l[level- bpo->startlevel - 1] = advicebit_l;
bpo->advicebits_r[level- bpo->startlevel - 1] = advicebit_r;
omp_unset_lock(&bpo->locks[level- bpo->startlevel - 1]);
}
void bitpropagator_offline_readblockvector(void * local_output, void * local_bit_output, bitpropagator_offline * bpo) {
......@@ -146,17 +143,27 @@ void bitpropagator_offline_readblockvector(void * local_output, void * local_bit
}
}
void bitpropagator_offline_parallelizer(void* bp, bitpropagator_offline * bpo, void* indexp, void * local_output, void * local_bit_output, bp_traverser_fn fn) {
void bitpropagator_offline_parallelizer(void* bp, bitpropagator_offline * bpo, void* indexp, void * local_output, void * local_bit_output, void* pd, bp_traverser_fn fn, bp_pusher_fn fn2, facb_fn cbfn, void* cbpass) {
omp_set_nested(true);
#pragma omp parallel num_threads(2)
#pragma omp parallel num_threads(3)
{
//OpenMP seems to get along with obliv-c just fine, so long as obliv-c only uses the master thread.
#pragma omp master
fn(bp, indexp);
{
fn(bp, indexp);
if (*cbfn!=NULL) {
cbfn(cbpass);
}
}
#pragma omp single
{
#pragma omp task
fn2(bp, bpo, pd);
#pragma omp task
bitpropagator_offline_readblockvector(local_output, local_bit_output, bpo);
}
......
......@@ -5,12 +5,14 @@
typedef struct bitpropagator_offline bitpropagator_offline;
typedef void (* bp_traverser_fn)(void *, void *);
typedef void (* bp_pusher_fn)(void *, void *, void *);
typedef void (* facb_fn)(void *);
void bitpropagator_offline_start(bitpropagator_offline * bpo, void * blocks);
void bitpropagator_offline_push_Z(bitpropagator_offline * bpo, void * Z, bool advicebit_l, bool advicebit_r, size_t level);
void bitpropagator_offline_readblockvector(void * local_output, void* local_bit_output, bitpropagator_offline * bpo);
void bitpropagator_offline_parallelizer(void* bp, bitpropagator_offline * bpo, void* indexp, void * local_output, void* local_bit_output, bp_traverser_fn fn);
void bitpropagator_offline_parallelizer(void* bp, bitpropagator_offline * bpo, void* indexp, void * local_output, void* local_bit_output, void* pd, bp_traverser_fn fn, bp_pusher_fn fn2, facb_fn cbfn, void* cbpass);
bitpropagator_offline * bitpropagator_offline_new(size_t size, size_t startlevel);
void bitpropagator_offline_free(bitpropagator_offline * bpo);
......
......@@ -2,6 +2,8 @@
#include "bitpropagate.h"
#include "flatoram_util.oh"
#include "ackutil.h"
#include <omp.h>
#include <obliv.h>
struct bitpropagator {
uint32_t startlevel;
......@@ -18,11 +20,37 @@ struct bitpropagator {
obliv uint8_t * expanded_A;
obliv uint8_t * expanded_B;
obliv uint8_t * Z;
uint8_t * Z_local;
obliv bool * advicebits;
omp_lock_t * locks;
bitpropagator_offline * bpo;
};
void bitpropagator_Z_pusher(bitpropagator * bp, bitpropagator_offline * bpo, ProtocolDesc* pd) {
ocSetCurrentProto(pd);
bool advicebit_local_l, advicebit_local_r;
uint8_t * Z_local = malloc(BLOCKSIZE);
for (size_t ii = bp->startlevel+1; ii <= bp->endlevel; ii++) {
size_t thislevel = ii- bp->startlevel -1;
omp_set_lock(&bp->locks[thislevel]);
for (size_t jj = 0; jj < BLOCKSIZE; jj++) revealOblivChar(&Z_local[jj], bp->Z[thislevel*BLOCKSIZE+jj], 1);
revealOblivBool(&advicebit_local_l, bp->advicebits[thislevel*2+0], 1);
revealOblivBool(&advicebit_local_r, bp->advicebits[thislevel*2+1], 1);
for (size_t jj = 0; jj < BLOCKSIZE; jj++) revealOblivChar(&Z_local[jj], bp->Z[thislevel*BLOCKSIZE+jj], 2);
revealOblivBool(&advicebit_local_l, bp->advicebits[thislevel*2+0], 2);
revealOblivBool(&advicebit_local_r, bp->advicebits[thislevel*2+1], 2);
bitpropagator_offline_push_Z(bpo, Z_local, advicebit_local_l, advicebit_local_r, ii);
}
free(Z_local);
}
void bitpropagator_traverselevels(bitpropagator * bp, obliv size_t * indexp) {
obliv uint32_t levelindex;
obliv size_t index = *indexp;
obliv bool control_bit_A_next, control_bit_B_next;
......@@ -32,6 +60,8 @@ void bitpropagator_traverselevels(bitpropagator * bp, obliv size_t * indexp) {
for (size_t ii = bp->startlevel+1; ii <= bp->endlevel; ii++) {
levelindex = (index >> (bp->endlevel - ii)) & 1;
obliv uint8_t * Z = &bp->Z[(ii - bp->startlevel - 1)*BLOCKSIZE];
//first expand our active blocks into two blocks (L/R)
online_expand(bp->expanded_A, bp->activeblock_A, 2);
online_expand(bp->expanded_B, bp->activeblock_B, 2);
......@@ -40,38 +70,34 @@ void bitpropagator_traverselevels(bitpropagator * bp, obliv size_t * indexp) {
ocCopyN(&ocCopyChar, bp->activeblock_A, bp->expanded_A, BLOCKSIZE);
ocCopyN(&ocCopyChar, bp->activeblock_B, bp->expanded_B, BLOCKSIZE);
//Z = block_A XOR block_B for the silenced branch
ocCopyN(&ocCopyChar, bp->Z, &bp->expanded_A[BLOCKSIZE], BLOCKSIZE);
ocCopyN(&ocCopyChar, Z, &bp->expanded_A[BLOCKSIZE], BLOCKSIZE);
for (size_t jj = 0; jj < BLOCKSIZE; jj ++) {
bp->Z[jj] ^= bp->expanded_B[BLOCKSIZE+jj];
Z[jj] ^= bp->expanded_B[BLOCKSIZE+jj];
}
} else {
//copy the branch to be kept
ocCopyN(&ocCopyChar, bp->activeblock_A, &bp->expanded_A[BLOCKSIZE], BLOCKSIZE);
ocCopyN(&ocCopyChar, bp->activeblock_B, &bp->expanded_B[BLOCKSIZE], BLOCKSIZE);
//Z = block_A XOR block_B for the silenced branch
ocCopyN(&ocCopyChar, bp->Z, bp->expanded_A, BLOCKSIZE);
ocCopyN(&ocCopyChar, Z, bp->expanded_A, BLOCKSIZE);
for (size_t jj = 0; jj < BLOCKSIZE; jj ++) {
bp->Z[jj] ^= bp->expanded_B[jj];
Z[jj] ^= bp->expanded_B[jj];
}
}
bool advicebit_local_l, advicebit_local_r;
bitpropagator_getadvice(&advicebit_local_l, &advicebit_local_r, bp->expanded_A, bp->expanded_B, levelindex);
for (size_t jj = 0; jj < BLOCKSIZE; jj++) revealOblivChar(&bp->Z_local[jj], bp->Z[jj], 1);
for (size_t jj = 0; jj < BLOCKSIZE; jj++) revealOblivChar(&bp->Z_local[jj], bp->Z[jj], 2);
bitpropagator_getadvice(&bp->advicebits[(ii - bp->startlevel - 1)*2], bp->expanded_A, bp->expanded_B, levelindex);
bitpropagator_offline_push_Z(bp->bpo, bp->Z_local, advicebit_local_l, advicebit_local_r, ii);
omp_unset_lock(&bp->locks[ii - bp->startlevel - 1]);
control_bit_A_next = control_bit_A;
control_bit_B_next = control_bit_B;
obliv if (levelindex == 0) {
control_bit_A_next &= advicebit_local_l;
control_bit_B_next &= advicebit_local_l;
control_bit_A_next &= bp->advicebits[(ii - bp->startlevel - 1)*2];
control_bit_B_next &= bp->advicebits[(ii - bp->startlevel - 1)*2];
} else {
control_bit_A_next &= advicebit_local_r;
control_bit_B_next &= advicebit_local_r;
control_bit_A_next &= bp->advicebits[(ii - bp->startlevel - 1)*2+1];
control_bit_B_next &= bp->advicebits[(ii - bp->startlevel - 1)*2+1];
}
control_bit_A_next ^= ((obliv bool *)bp->activeblock_A)[0];
control_bit_B_next ^= ((obliv bool *)bp->activeblock_B)[0];
......@@ -94,6 +120,10 @@ void bitpropagator_traverselevels(bitpropagator * bp, obliv size_t * indexp) {
}
void bitpropagator_getblockvector(obliv uint8_t * activeblock_pair, uint8_t * local_output, bool * local_bit_output, bitpropagator * bp, obliv size_t index) {
bitpropagator_getblockvector_with_callback(activeblock_pair, local_output, local_bit_output, bp, index, NULL, NULL);
}
void bitpropagator_getblockvector_with_callback(obliv uint8_t * activeblock_pair, uint8_t * local_output, bool * local_bit_output, bitpropagator * bp, obliv size_t index, facb_fn cbfn, void* cbpass) {
//Collect a set of random blocks for the top level
get_random_bytes(bp->toplevel_local, ((1ll << bp->startlevel) + 1) * BLOCKSIZE);
......@@ -123,20 +153,19 @@ void bitpropagator_getblockvector(obliv uint8_t * activeblock_pair, uint8_t * lo
bitpropagator_offline_start(bp->bpo, bp->toplevel_local);
//This is a hack to work around the fact that openmp and obliv-c are incompatible.
bitpropagator_offline_parallelizer(bp, bp->bpo, &index, local_output, local_bit_output, bitpropagator_traverselevels);
ProtocolDesc pd2;
splitProtocol(&pd2, ocCurrentProto());
bitpropagator_offline_parallelizer(bp, bp->bpo, &index, local_output, local_bit_output, &pd2, bitpropagator_traverselevels, bitpropagator_Z_pusher, cbfn, cbpass);
cleanupProtocol(&pd2);
//write output
ocCopyN(&ocCopyChar, activeblock_pair, bp->activeblock_A, BLOCKSIZE);
ocCopyN(&ocCopyChar, &activeblock_pair[BLOCKSIZE], bp->activeblock_B, BLOCKSIZE);
}
void bitpropagator_getadvice(bool * advicebit_local_l, bool * advicebit_local_r, obliv uint8_t * blocks_A, obliv uint8_t * blocks_B, obliv bool rightblock) {
obliv bool advicebit_l = ((obliv bool *)blocks_A)[0] ^ ((obliv bool *)blocks_B)[0] ^ rightblock ^ 1;
obliv bool advicebit_r = ((obliv bool *)blocks_A)[BLOCKSIZE*8] ^ ((obliv bool *)blocks_B)[BLOCKSIZE*8] ^ rightblock;
revealOblivBool(advicebit_local_l, advicebit_l, 1);
revealOblivBool(advicebit_local_r, advicebit_r, 1);
revealOblivBool(advicebit_local_l, advicebit_l, 2);
revealOblivBool(advicebit_local_r, advicebit_r, 2);
void bitpropagator_getadvice(obliv bool * advicebits, obliv uint8_t * blocks_A, obliv uint8_t * blocks_B, obliv bool rightblock) {
advicebits[0] = ((obliv bool *)blocks_A)[0] ^ ((obliv bool *)blocks_B)[0] ^ rightblock ^ 1;
advicebits[1] = ((obliv bool *)blocks_A)[BLOCKSIZE*8] ^ ((obliv bool *)blocks_B)[BLOCKSIZE*8] ^ rightblock;
}
bitpropagator * bitpropagator_new(size_t size, uint32_t startlevel) {
......@@ -155,10 +184,16 @@ bitpropagator * bitpropagator_new(size_t size, uint32_t startlevel) {
bp->activeblock_B = calloc(1, BLOCKSIZE * sizeof(obliv uint8_t));
bp->expanded_A = calloc(2, BLOCKSIZE * sizeof(obliv uint8_t));
bp->expanded_B = calloc(2, BLOCKSIZE * sizeof(obliv uint8_t));
bp->Z = calloc(1, BLOCKSIZE * sizeof(obliv uint8_t));
bp->Z_local = malloc(1 * BLOCKSIZE);
bp->Z = calloc((bp->endlevel - bp->startlevel), BLOCKSIZE * sizeof(obliv uint8_t));
bp->advicebits = calloc((bp->endlevel - bp->startlevel), 2*sizeof(obliv bool));
bp->bpo = bitpropagator_offline_new(size, bp->startlevel);
bp->locks = malloc((bp->endlevel - bp->startlevel) * sizeof(omp_lock_t));
for (size_t ii = 0; ii < bp->endlevel - bp->startlevel; ii++) {
omp_init_lock(&bp->locks[ii]);
omp_set_lock(&bp->locks[ii]);
}
return bp;
}
......@@ -174,8 +209,14 @@ void bitpropagator_free(bitpropagator * bp) {
free(bp->expanded_A);
free(bp->expanded_B);
free(bp->Z);
free(bp->Z_local);
free(bp->advicebits);
bitpropagator_offline_free(bp->bpo);
for (int ii = 0; ii < (bp->endlevel - bp->startlevel); ii++) {
omp_destroy_lock(&bp->locks[ii]);
}
free(bp->locks);
online_expand_deinit();
free(bp);
}
\ No newline at end of file
......@@ -4,8 +4,11 @@
typedef struct bitpropagator bitpropagator;
typedef void (* facb_fn)(void *);
void bitpropagator_getblockvector(obliv uint8_t * activeblock_pair, uint8_t * local_output, bool * local_bit_output, bitpropagator * bp, obliv size_t index);
void bitpropagator_getadvice(bool * tl, bool * tr, obliv uint8_t * blocks_A, obliv uint8_t * blocks_B, obliv bool rightblock);
void bitpropagator_getblockvector_with_callback(obliv uint8_t * activeblock_pair, uint8_t * local_output, bool * local_bit_output, bitpropagator * bp, obliv size_t index, facb_fn cbfn, void* cbpass);
void bitpropagator_getadvice(obliv bool * advicebits, obliv uint8_t * blocks_A, obliv uint8_t * blocks_B, obliv bool rightblock);
bitpropagator * bitpropagator_new(size_t size, uint32_t truncated_levels);
void bitpropagator_free(bitpropagator * bp);
......
......@@ -49,11 +49,18 @@ void flatoram_refresh(flatoram* ram) {
}
}
void flatoram_apply(flatoram* ram, void* data, flatoram_block_access_function fn, obliv size_t index) obliv {
obliv size_t blockid = index / ram->elementsperblock;
obliv size_t subblockid = index % ram->elementsperblock;
typedef struct facb_pass {
flatoram * ram;
obliv size_t blockid;
obliv size_t subblockid;
obliv bool found;
} facb_pass;
flatoram_scan_callback(facb_pass * input) {
flatoram * ram = input->ram;
obliv size_t blockid = input->blockid;
obliv size_t subblockid = input->subblockid;
obliv bool found = false;
if (ram->progress > 0) {
ocCopy(&ram->blockcpy, element(&ram->blockcpy, ram->stash, ram->progress), ram->stash);
ram->stashi[ram->progress] = ram->stashi[0];
......@@ -67,10 +74,18 @@ void flatoram_apply(flatoram* ram, void* data, flatoram_block_access_function fn
}
}
}
input->found=found;
}
~obliv() bitpropagator_getblockvector(ram->activeblock_pair, ram->blockvector_local, ram->bitvector_local, ram->bitpropagator, blockid);
void flatoram_apply(flatoram* ram, void* data, flatoram_block_access_function fn, obliv size_t index) obliv {
obliv size_t blockid = index / ram->elementsperblock;
obliv size_t subblockid = index % ram->elementsperblock;
facb_pass facb_data = {.ram=ram, .blockid = blockid, .subblockid = subblockid, .found = false};
~obliv() bitpropagator_getblockvector_with_callback(ram->activeblock_pair, ram->blockvector_local, ram->bitvector_local, ram->bitpropagator, blockid, flatoram_scan_callback, &facb_data);
obliv if (found == false) {
obliv if (facb_data.found == false) {
scanrom_read_with_bitvector(ram->stash, ram->rom, blockid, ram->bitvector_local);
ram->stashi[0] = blockid;
}
......
......@@ -43,4 +43,20 @@ void scanwrom_write_with_blockvector_offline(uint8_t * local_data, uint8_t * blo
}
}
}
}
void scanrom_transfer_duplexer(duplexer_fn fn1, duplexer_fn fn2, void* data, void * pd) {
#pragma omp parallel num_threads(2)
{
#pragma omp master
fn1(data, NULL);
#pragma omp single
{
#pragma omp task
fn2(data, pd);
}
}
}
\ No newline at end of file
......@@ -2,6 +2,10 @@
#define SCANROM_H
#include "flatoram.h"
typedef void (* duplexer_fn)(void *, void *);
void scanrom_transfer_duplexer(duplexer_fn fn1, duplexer_fn fn2, void* data, void * pd);
void scanrom_create_local_halfpad(void * dest, void * key, size_t size);
void scanrom_read_with_bitvector_offline(void* data, void* local_data, bool * bitvector, size_t fullblocksize, size_t blockcount);
......
......@@ -74,7 +74,7 @@ void scanrom_write_xor_shares(scanrom* rom, obliv uint8_t * data, size_t indexin
for (; index < MIN(indexinit + len, rom->blockcount); index++) {
for (size_t ii = 0; ii < rom->fullblocksize; ii++) {
rom->local_data[index * rom->fullblocksize + ii] = ocBroadcastChar(rom->local_blocktemp[index * rom->fullblocksize + ii], 1);
rom->local_data[index * rom->fullblocksize + ii] = ocBroadcastChar(rom->local_blocktemp[index * rom->fullblocksize + ii], 2);
}
}
......@@ -82,7 +82,7 @@ void scanrom_write_xor_shares(scanrom* rom, obliv uint8_t * data, size_t indexin
for (; index < MIN(indexinit + len, rom->blockcount); index++) {
for (size_t ii = 0; ii < rom->fullblocksize; ii++) {
rom->local_data[index * rom->fullblocksize + ii] ^= ocBroadcastChar(rom->local_blocktemp[index * rom->fullblocksize + ii], 2);
rom->local_data[index * rom->fullblocksize + ii] ^= ocBroadcastChar(rom->local_blocktemp[index * rom->fullblocksize + ii], 1);
}
}
}
......@@ -171,8 +171,8 @@ void scanwrom_write_with_blockvector(scanwrom* rom, obliv uint8_t * active_block
for (size_t ii = 0; ii < rom->fullblocksize; ii++) {
rom->blocktemp[ii] ^= rom->blocktemp[rom->fullblocksize + ii] ^ rom->blocktemp[2*rom->fullblocksize + ii];
}
for (size_t ii = 0; ii < rom->fullblocksize; ii++) revealOblivChar(&rom->local_blocktemp[ii], rom->blocktemp[ii], 1);
for (size_t ii = 0; ii < rom->fullblocksize; ii++) revealOblivChar(&rom->local_blocktemp[ii], rom->blocktemp[ii], 2);
for (size_t ii = 0; ii < rom->fullblocksize; ii++) revealOblivChar(&rom->local_blocktemp[ii], rom->blocktemp[ii], 1);
scanwrom_write_with_blockvector_offline(rom->local_data, blockvector, bitvector, rom->local_blocktemp, rom->blockmultiple != 1, rom->fullblocksize, rom->blockcount);
}
......
......@@ -102,8 +102,8 @@ void test_main(void*varg) {
uint64_t tally = 0;
uint64_t tallygates = 0;
uint64_t tallybytes = 0;
obliv uint32_t * input = calloc(elsz * elct, sizeof(obliv uint32_t));
for (int kk = 0; kk < (elsz * elct); kk++) input[kk] = feedOblivInt(rand(), 1);
//obliv uint32_t * input = calloc(elsz * elct, sizeof(obliv uint32_t));
//for (int kk = 0; kk < (elsz * elct); kk++) input[kk] = feedOblivInt(rand(), 1);
oram * o = oram_new(ORAM_TYPE_AUTO, &cpy, elct);
......@@ -115,7 +115,9 @@ void test_main(void*varg) {
uint64_t startTime = current_timestamp();
int64_t rungates = - yaoGateCount();
int64_t runbytes = -tcp2PBytesSent(ocCurrentProto());
oram_write(o, &input[index_raw*elsz], index);
//oram_write(o, &input[index_raw*elsz], index);
obliv uint32_t input = feedOblivInt(rand(), 1);
oram_write(o, &input, index);
uint64_t runtime = current_timestamp() - startTime;
rungates += yaoGateCount();
runbytes += tcp2PBytesSent(ocCurrentProto());
......@@ -126,7 +128,7 @@ void test_main(void*varg) {
tallybytes += runbytes;
}
free(input);
//free(input);
oram_free(o);
fprintf(stdout, "\n");
fprintf(stderr, "Write (count:%d, size: %d): %llu microseconds avg, %llu gates avg, %llu bytes avg\n", elct, elsz, tally / samples, tallygates/samples, tallybytes/samples);
......
......@@ -47,6 +47,8 @@ void test_main(void*varg) {
oram_set_default_type(ORAM_TYPE_CIRCUIT);
} else if (strcmp(optarg,"linear") == 0) {
oram_set_default_type(ORAM_TYPE_LINEAR);
} else if (strcmp(optarg,"flat") == 0) {
oram_set_default_type(ORAM_TYPE_FLAT);
} else {
fprintf (stderr, "Invalid argument for -%c.\n", arg);
return;
......
......@@ -61,6 +61,8 @@ void test_main(void*varg) {
oram_set_default_type(ORAM_TYPE_CIRCUIT);
} else if (strcmp(optarg,"linear") == 0) {
oram_set_default_type(ORAM_TYPE_LINEAR);
} else if (strcmp(optarg,"flat") == 0) {
oram_set_default_type(ORAM_TYPE_FLAT);
} else {
fprintf (stderr, "Invalid argument for -%c.\n", arg);
return;
......
......@@ -47,6 +47,8 @@ void test_main(void*varg) {
oram_set_default_type(ORAM_TYPE_CIRCUIT);
} else if (strcmp(optarg,"linear") == 0) {
oram_set_default_type(ORAM_TYPE_LINEAR);
} else if (strcmp(optarg,"flat") == 0) {
oram_set_default_type(ORAM_TYPE_FLAT);
} else {
fprintf (stderr, "Invalid argument for -%c.\n", arg);
return;
......
......@@ -53,6 +53,8 @@ void test_main(void*varg) {
oram_set_default_type(ORAM_TYPE_CIRCUIT);
} else if (strcmp(optarg,"linear") == 0) {
oram_set_default_type(ORAM_TYPE_LINEAR);
} else if (strcmp(optarg,"flat") == 0) {
oram_set_default_type(ORAM_TYPE_FLAT);
} else {
fprintf (stderr, "Invalid argument for -%c.\n", arg);
return;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment