Commit 6b055ffd authored by Jack Doerner

Memory footprint optimizations; appreciable constant factor reduction.

parent 657d1eee
......@@ -10,17 +10,15 @@ struct bitpropagator_offline {
void * Z;
bool * advicebits_l;
bool * advicebits_r;
void * level_data_1;
void * level_data_2;
void * level_bits_1;
void * level_bits_2;
void * level_data;
void * level_bits;
void * keyL;
void * keyR;
omp_lock_t * locks;
};
void bitpropagator_offline_start(bitpropagator_offline * bpo, void * blocks) {
memcpy(bpo->level_data_1, blocks, (1ll<<bpo->startlevel) * BLOCKSIZE);
memcpy(bpo->level_data, blocks, (1ll<<bpo->startlevel) * BLOCKSIZE);
for (int ii = 0; ii < (bpo->endlevel - bpo->startlevel); ii++) {
omp_set_lock(&bpo->locks[ii]);
}
......@@ -39,16 +37,16 @@ void bitpropagator_offline_readblockvector(void * local_output, void * local_bit
size_t thislevelblocks = (1ll<<bpo->startlevel);
size_t nextlevelblocks = (bpo->size + (1ll<<(bpo->endlevel - thislevel -1)) - 1) / (1ll<<(bpo->endlevel - thislevel -1));
uint64_t* a = (uint64_t *)bpo->level_data_1;
uint8_t* a2 = (uint8_t *)bpo->level_data_1;
uint64_t* b = (uint64_t *)bpo->level_data_2;
uint8_t* b2 = (uint8_t *)bpo->level_data_2;
uint64_t* a = (uint64_t *)bpo->level_data;
uint8_t* a2 = (uint8_t *)bpo->level_data;
uint64_t* b = (uint64_t *)local_output;
uint8_t* b2 = (uint8_t *)local_output;
uint64_t* t;
uint8_t* t2;
uint64_t* z;
bool advicebit_l, advicebit_r;
bool * a_bits = bpo->level_bits_1;
bool * b_bits = bpo->level_bits_2;
bool * a_bits = bpo->level_bits;
bool * b_bits = local_bit_output;
bool * t_bits;
#pragma omp parallel for
......@@ -103,9 +101,6 @@ void bitpropagator_offline_readblockvector(void * local_output, void * local_bit
}
}
uint64_t* c = (uint64_t *)local_output;
uint8_t* c2 = (uint8_t *)local_output;
omp_set_lock(&bpo->locks[thislevel- bpo->startlevel -1 ]);
thislevelblocks = nextlevelblocks;
......@@ -119,25 +114,41 @@ void bitpropagator_offline_readblockvector(void * local_output, void * local_bit
b2 = a2; b = a; b_bits = a_bits;
a2 = t2; a = t; a_bits = t_bits;
#pragma omp parallel for
for (size_t ii = 0; ii < thislevelblocks; ii++) {
if (b == local_output) {
#pragma omp parallel for
for (size_t ii = 0; ii < thislevelblocks; ii++) {
if (ii%2 == 0) {
a_bits[ii] = (a2[ii*BLOCKSIZE] & 1) ^ (b_bits[ii/2] & advicebit_l);
} else {
a_bits[ii] = (a2[ii*BLOCKSIZE] & 1) ^ (b_bits[ii/2] & advicebit_r);
}
if (c != NULL) {
if (b_bits[ii/2]) {
#pragma omp simd aligned(c,a,z:16)
#pragma omp simd aligned(b,a,z:16)
for (uint8_t jj = 0; jj < BLOCKSIZE/sizeof(uint64_t); jj++) {
c[ii*(BLOCKSIZE/sizeof(uint64_t))+jj] = a[ii*(BLOCKSIZE/sizeof(uint64_t))+jj] ^ z[jj];
b[ii*(BLOCKSIZE/sizeof(uint64_t))+jj] = a[ii*(BLOCKSIZE/sizeof(uint64_t))+jj] ^ z[jj];
}
} else {
memcpy(&c[ii*(BLOCKSIZE/sizeof(uint64_t))], &a[ii*(BLOCKSIZE/sizeof(uint64_t))], BLOCKSIZE);
memcpy(&b[ii*(BLOCKSIZE/sizeof(uint64_t))], &a[ii*(BLOCKSIZE/sizeof(uint64_t))], BLOCKSIZE);
}
}
if (local_bit_output != NULL) {
memcpy(b_bits, a_bits, thislevelblocks*sizeof(bool));
} else {
#pragma omp parallel for
for (size_t ii = 0; ii < thislevelblocks; ii++) {
if (ii%2 == 0) {
((bool *)local_bit_output)[ii] = (a2[ii*BLOCKSIZE] & 1) ^ (b_bits[ii/2] & advicebit_l);
a_bits[ii] = (a2[ii*BLOCKSIZE] & 1) ^ (b_bits[ii/2] & advicebit_l);
} else {
((bool *)local_bit_output)[ii] = (a2[ii*BLOCKSIZE] & 1) ^ (b_bits[ii/2] & advicebit_r);
a_bits[ii] = (a2[ii*BLOCKSIZE] & 1) ^ (b_bits[ii/2] & advicebit_r);
}
if (b_bits[ii/2]) {
#pragma omp simd aligned(b,a,z:16)
for (uint8_t jj = 0; jj < BLOCKSIZE/sizeof(uint64_t); jj++) {
a[ii*(BLOCKSIZE/sizeof(uint64_t))+jj] ^= z[jj];
}
}
}
}
......@@ -180,14 +191,12 @@ bitpropagator_offline * bitpropagator_offline_new(size_t size, size_t startlevel
bpo->size = size;
bpo->startlevel = startlevel;
bpo->endlevel = LOG2LL(size) + (((1 << LOG2LL(size)) < size)? 1:0);
posix_memalign(&bpo->level_data_1,16,(1ll<<bpo->endlevel) * BLOCKSIZE);
posix_memalign(&bpo->level_data_2,16,(1ll<<bpo->endlevel) * BLOCKSIZE);
posix_memalign(&bpo->level_data,16,(1ll<<bpo->endlevel) * BLOCKSIZE);
posix_memalign(&bpo->Z,16,(bpo->endlevel - bpo->startlevel) * BLOCKSIZE);
bpo->locks = malloc((bpo->endlevel - bpo->startlevel) * sizeof(omp_lock_t));
bpo->advicebits_l = malloc((bpo->endlevel - bpo->startlevel) * sizeof(bool));
bpo->advicebits_r = malloc((bpo->endlevel - bpo->startlevel) * sizeof(bool));
bpo->level_bits_1 = malloc(size * sizeof(bool));
bpo->level_bits_2 = malloc(size * sizeof(bool));
bpo->level_bits = malloc(size * sizeof(bool));
bpo->keyL = offline_prf_keyschedule(keyL);
bpo->keyR = offline_prf_keyschedule(keyR);
......@@ -203,10 +212,8 @@ void bitpropagator_offline_free(bitpropagator_offline * bpo) {
omp_destroy_lock(&bpo->locks[ii]);
}
offline_expand_deinit();
free(bpo->level_data_1);
free(bpo->level_data_2);
free(bpo->level_bits_1);
free(bpo->level_bits_2);
free(bpo->level_data);
free(bpo->level_bits);
free(bpo->advicebits_l);
free(bpo->advicebits_r);
free(bpo->Z);
......
......@@ -161,18 +161,14 @@ flatoram* flatoram_new(OcCopy* cpy, void* data, size_t n) {
if (data != NULL) {
obliv uint8_t * loadtemp = calloc(ram->blockcount, ram->blockcpy.eltsize);
obliv uint8_t * loadtemp2 = calloc(ram->blockcount, ram->blockcpy.eltsize);
get_random_bytes(loadtemp_local, BLOCKSIZE * ram->blockcount * ram->blockmultiple);
feedOblivCharArray(loadtemp, loadtemp_local, BLOCKSIZE * ram->blockcount * ram->blockmultiple, 1);
feedOblivCharArray(loadtemp2, loadtemp_local, BLOCKSIZE * ram->blockcount * ram->blockmultiple, 2);
feedOblivCharArray(loadtemp, loadtemp_local, BLOCKSIZE * ram->blockcount * ram->blockmultiple, 2);
for (size_t ii = 0; ii<(ram->blockmultiple * BLOCKSIZE * ram->blockcount);ii++) loadtemp[ii] ^= loadtemp2[ii];
for (size_t ii = 0; ii<(ram->blockmultiple * BLOCKSIZE * ram->blockcount);ii++) loadtemp[ii] ^= feedOblivChar(loadtemp_local[ii], 1);
for (size_t ii = 0; ii<(ram->blockmultiple * BLOCKSIZE * ram->blockcount);ii++) revealOblivChar(&(loadtemp_local[ii]), loadtemp[ii], 2);
free(loadtemp2);
size_t blockid, subblockid;
for (size_t ii = 0; ii < ram->size; ii++) {
blockid = ii / ram->elementsperblock;
......
......@@ -10,7 +10,6 @@ struct scanrom {
OcCopy * blockcpy;
uint8_t * local_data;
uint8_t * local_halfkey;
uint8_t * local_halfpad;
obliv uint8_t * halfkey_a;
obliv uint8_t * halfkey_b;
uint8_t * local_blocktemp;
......@@ -70,11 +69,12 @@ void scanrom_write_xor_shares(scanrom* rom, obliv uint8_t * data, size_t indexin
//receives one share from each party, encrypts them locally, and shares them
size_t index = indexinit;
scanrom_encrypt_offline(&rom->local_blocktemp[index * rom->fullblocksize], data, rom->local_halfkey, index, rom->fullblocksize * MIN(len, rom->blockcount - index));
scanrom_encrypt_offline(&rom->local_data[index * rom->fullblocksize], data, rom->local_halfkey, index, rom->fullblocksize * MIN(len, rom->blockcount - index));
for (; index < MIN(indexinit + len, rom->blockcount); index++) {
for (size_t ii = 0; ii < rom->fullblocksize; ii++) {
rom->local_data[index * rom->fullblocksize + ii] = ocBroadcastChar(rom->local_blocktemp[index * rom->fullblocksize + ii], 2);
if (ocCurrentParty() == 1) rom->local_data[index * rom->fullblocksize + ii] ^= ocBroadcastChar(NULL, 2);
else ocBroadcastChar(rom->local_data[index * rom->fullblocksize + ii], 2);
}
}
......@@ -82,7 +82,8 @@ void scanrom_write_xor_shares(scanrom* rom, obliv uint8_t * data, size_t indexin
for (; index < MIN(indexinit + len, rom->blockcount); index++) {
for (size_t ii = 0; ii < rom->fullblocksize; ii++) {
rom->local_data[index * rom->fullblocksize + ii] ^= ocBroadcastChar(rom->local_blocktemp[index * rom->fullblocksize + ii], 1);
if (ocCurrentParty() == 2 ) rom->local_data[index * rom->fullblocksize + ii] = ocBroadcastChar(NULL, 1);
else ocBroadcastChar(rom->local_data[index * rom->fullblocksize + ii], 1);
}
}
}
......@@ -106,10 +107,8 @@ scanrom* scanrom_new(OcCopy* blockcpy, int n, void* key_local) {
rom->blockcount = n;
rom->blockcpy = blockcpy;
flatoram_pma(&rom->local_data, 16, n * fullblocksize);
memset(rom->local_data, 0, n * fullblocksize);
flatoram_pma(&rom->local_halfpad, 16, n * fullblocksize);
rom->local_halfkey = malloc(KEYSIZE);
flatoram_pma(&rom->local_blocktemp, 16, n * fullblocksize);
flatoram_pma(&rom->local_blocktemp, 16, 2 * fullblocksize);
rom->blocktemp = calloc(fullblocksize * 3, sizeof(obliv uint8_t));
rom->halfkey_a = calloc(KEYSIZE, sizeof(obliv uint8_t));
rom->halfkey_b = calloc(KEYSIZE, sizeof(obliv uint8_t));
......@@ -121,7 +120,6 @@ void scanrom_free(scanrom* rom) {
offline_expand_deinit();
free(rom->local_data);
free(rom->local_halfkey);
free(rom->local_halfpad);
free(rom->local_blocktemp);
free(rom->blocktemp);
free(rom->halfkey_a);
......
......@@ -43,8 +43,8 @@ void test_main(void*varg) {
oram_set_default_type(ORAM_OVERRIDE);
#endif
int elct = 4;
int elsz = 1;
size_t elct = 4;
size_t elsz = 1;
int samples = 1;
args_t * args_pass = varg;
......@@ -52,13 +52,13 @@ void test_main(void*varg) {
optind = 0; // this allows us to getopt a second time
while ((arg = getopt_long(args_pass->argc, args_pass->argv, options_string, long_options, NULL)) != -1) {
if (arg == 'e') {
elct = atoi(optarg);
elct = atoll(optarg);
if (elct <= 0) {
fprintf (stderr, "Argument for -%c must be positive.\n", arg);
return;
}
} else if (arg == 's') {
elsz = atoi(optarg);
elsz = atoll(optarg);
if (elsz <= 0) {
fprintf (stderr, "Argument for -%c must be positive.\n", arg);
return;
......@@ -107,7 +107,7 @@ void test_main(void*varg) {
oram * o = oram_new(ORAM_TYPE_AUTO, &cpy, elct);
fprintf(stdout, "%d,%d", elct, elsz);
fprintf(stdout, "%lld,%lld", elct, elsz);
for (int kk = 0; kk < samples; kk++) {
uint32_t index_raw = ocBroadcastInt(rand() % elct, 2);
......@@ -131,6 +131,6 @@ void test_main(void*varg) {
//free(input);
oram_free(o);
fprintf(stdout, "\n");
fprintf(stderr, "Write (count:%d, size: %d): %llu microseconds avg, %llu gates avg, %llu bytes avg\n", elct, elsz, tally / samples, tallygates/samples, tallybytes/samples);
fprintf(stderr, "Write (count:%lld, size: %lld): %llu microseconds avg, %llu gates avg, %llu bytes avg\n", elct, elsz, tally / samples, tallygates/samples, tallybytes/samples);
}
......@@ -44,8 +44,8 @@ void test_main(void*varg) {
oram_set_default_type(ORAM_OVERRIDE);
#endif
int elct = 4;
int elsz = 1;
size_t elct = 4;
size_t elsz = 1;
int samples = 1;
args_t * args_pass = varg;
......@@ -53,13 +53,13 @@ void test_main(void*varg) {
optind = 0; // this allows us to getopt a second time
while ((arg = getopt_long(args_pass->argc, args_pass->argv, options_string, long_options, NULL)) != -1) {
if (arg == 'e') {
elct = atoi(optarg);
elct = atoll(optarg);
if (elct <= 0) {
fprintf (stderr, "Argument for -%c must be positive.\n", arg);
return;
}
} else if (arg == 's') {
elsz = atoi(optarg);
elsz = atoll(optarg);
if (elsz <= 0) {
fprintf (stderr, "Argument for -%c must be positive.\n", arg);
return;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment