Commit feec5403 authored by Sergey Smirnov

Download an input file only once when multiple concurrent task submissions reference the same file hash
parent a5631355
......@@ -874,6 +874,7 @@ class Agent(tcpserver.TCPServer):
download_start = time.time()
requests = []
files = []
wait_files = []
for inp in inputData:
if inp.get('delete', False):
result['deleteURIs'].append((inp['uri'], inp.get('auth', '')))
......@@ -900,7 +901,13 @@ class Agent(tcpserver.TCPServer):
inp['path'], inp['sha1']))
shutil.copyfile(path, dest_path)
continue
f = self.cache.putFile()
future = self.cache.waitFile(inp['sha1'])
if not future is None:
logger.debug('Waiting for %s with hash %s to be downloaded by another task' % (
inp['path'], inp['sha1']))
wait_files.append((dest_path, inp['sha1'], future))
continue
f = self.cache.putFile(inp.get('sha1', ''))
files.append((dest_path, inp.get('sha1', ''), f))
headers = {}
if 'auth' in inp:
......@@ -922,6 +929,10 @@ class Agent(tcpserver.TCPServer):
finally:
files = [(n, s, f.close()) for n, s, f in files]
after_wait = yield [w[2] for w in wait_files]
for i, r in enumerate(after_wait):
files.append((wait_files[i][0], wait_files[i][1], r))
for dest_path, shaOrig, (shaReal, path) in files:
if shaOrig and shaReal != shaOrig:
raise TaskException("Input file %s hash differs from downloaded %s != %s"
......
......@@ -17,6 +17,7 @@ import tempfile
import hashlib
import time
import operator
import tornado
class Cache(object):
def __init__(self, cacheDir, cacheSize):
......@@ -69,14 +70,25 @@ class Cache(object):
def _isCached(self, sha1):
    "Return True when sha1 is a key of self.paths, False otherwise"
    return sha1 in self.paths
def putFile(self):
def putFile(self, sha1=''):
"Return file-like object to be used by the client for writing file contents"
if not os.path.exists(self.tmpDir):
os.makedirs(self.tmpDir)
f = tempfile.NamedTemporaryFile(delete=False, dir=self.tmpDir)
self.puts[f] = hashlib.sha1()
self.puts[f] = {'hash_calc' : hashlib.sha1(),
'expected' : sha1,
'futures' : []}
return Putter(self, f)
def waitFile(self, sha1):
    """Return a future for a file that is already being written to the cache.

    Scans the uploads currently in progress (``self.puts``) for one whose
    declared ``expected`` hash equals ``sha1``. When one is found, a new
    ``tornado.concurrent.Future`` is appended to that upload's ``futures``
    list and returned, so the caller can wait on it instead of fetching
    the same content again.

    Returns ``None`` when ``sha1`` is empty or when no pending upload
    declares that hash.
    """
    if not sha1:
        # Uploads registered without a hash store '' as 'expected'; an empty
        # request must never match them.
        return None
    # Only the per-upload records are needed, not the file objects used as
    # keys. values() is available on both Python 2 and Python 3, whereas
    # iteritems() exists on Python 2 only.
    for pending in self.puts.values():
        if pending['expected'] == sha1:
            future = tornado.concurrent.Future()
            pending['futures'].append(future)
            return future
    return None
def dropFile(self, sha1):
path = self.getFile(sha1)
if path:
......@@ -90,17 +102,20 @@ class Cache(object):
def _write(self, f, data):
assert(f in self.puts)
f.write(data)
self.puts[f].update(data)
self.puts[f]['hash_calc'].update(data)
def _close(self, f):
assert(f in self.puts)
tmpName = f.name
f.close()
sha1 = self.puts[f].hexdigest()
sha1 = self.puts[f]['hash_calc'].hexdigest()
futures = self.puts[f]['futures']
del self.puts[f]
if sha1 in self.paths:
os.remove(tmpName)
self.paths[sha1]['atime'] = time.time()
for future in futures:
future.set_result((sha1, self.paths[sha1]['path']))
return sha1, self.paths[sha1]['path']
size = os.path.getsize(tmpName)
if size + self.size > self.maxSize:
......@@ -113,6 +128,8 @@ class Cache(object):
'atime' : time.time()
}
self.size += size
for future in futures:
future.set_result((sha1, path))
return (sha1, path)
def _cleanup(self, f):
......
......@@ -828,6 +828,34 @@ class AgentTests(AgentTestServer):
sendMessage(self.app.ws, 'TASK_CANCEL', str(self.taskId))
self.checkTaskStates(self.taskId, ['CANCELED'])
def testCacheConcurrentStageIn(self):
    # Submit the same hashed input twice and verify, via the agent log, that
    # a task reported waiting for the file to be downloaded by another task.
    content = '1 2 3 4\n'
    sha1 = hashlib.sha1(content).hexdigest()
    source_path = os.path.join(self.file_root, 'nums.txt')
    with open(source_path, 'wb') as src:
        src.write(content)

    inputDatum = {
        'path' : 'in0.txt',
        'uri' : self.get_url('/slow/nums.txt'),
        'sha1' : sha1
    }
    command = 'python %s -1 0 0.1 0 0 in0.txt' % DUMMY_PATH
    for _ in range(2):
        self.doSubmit(command, inputData=[inputDatum])

    # Both tasks are expected to run through the complete life cycle; each
    # check receives its own list instance, as in a literal-per-call form.
    life_cycle = ('ACCEPTED', 'STAGED_IN', 'RUNNING',
                  'COMPLETED', 'STAGED_OUT', 'DONE')
    self.checkTaskStatesSoft(self.taskId-1, list(life_cycle))
    self.checkTaskStatesSoft(self.taskId, list(life_cycle))

    # Only the last 100 log lines are searched for the wait message.
    with open('log/agent.txt', 'r') as log:
        tail = log.readlines()[-100:]
    self.assertTrue(sha1 + ' to be downloaded by' in '\n'.join(tail))
def testAuth(self):
inp = self.makeInput('input')
inputDatum = {
......
......@@ -93,5 +93,21 @@ class CacheTest(unittest.TestCase):
self.isCached(hash2)
self.notCached(hash1)
def testPutWithHash(self):
    # Put a file while declaring its hash up front, and check that a future
    # obtained through waitFile() is resolved once the put is closed.
    H = SHA['tree.tar.gz']
    self.notCached(H)
    path = 'tree.tar.gz'
    p = self.cache.putFile(H)
    # While the put is still open the cache must not report the file.
    self.assertEqual('', self.cache.getFile(H))
    future = self.cache.waitFile(H)
    self.assertTrue(future.running())
    # The fixture is presumably a gzip archive (going by its name): read it
    # as raw bytes so that no newline translation or text decoding can alter
    # the content that gets hashed.
    with open(path, 'rb') as f:
        p.write(f.read())
    h, path1 = p.close()
    self.assertEqual(h, H)
    self.isCached(H)
    # Closing the put must resolve the waiter with the same (hash, path) pair.
    self.assertTrue(future.done())
    self.assertEqual((h, path1), future.result())
# Run the unit tests when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment