pushreceive.py 27.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
#!/usr/bin/python3

# Push OSTree commits to a remote repo, based on Dan Nicholson's ostree-push
#
# Copyright (C) 2015  Dan Nicholson <nicholson@endlessm.com>
# Copyright (C) 2017  Tristan Van Berkom <tristan.vanberkom@codethink.co.uk>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
import logging
22
import multiprocessing
23
import os
24
import re
25 26 27 28
import subprocess
import sys
import shutil
import tarfile
29
import tempfile
30
from enum import Enum
31 32
from urllib.parse import urlparse

33 34 35
import click
import gi

36
from .. import _signals  # nopep8
37
from .._profile import Topics, profile_start, profile_end
38 39

gi.require_version('OSTree', '1.0')
40
# pylint: disable=wrong-import-position,wrong-import-order
41 42 43
from gi.repository import GLib, Gio, OSTree  # nopep8


# Protocol version byte written into, and required in, every message header.
PROTO_VERSION = 1
# Header size in bytes: byteorder marker (1) + protocol version (1)
# + command type (1) + payload length (2).
HEADER_SIZE = 5


class PushException(Exception):
    """Raised when pushing or receiving artifacts fails."""


class PushExistsException(Exception):
    """Raised when the remote already has the ref(s) being pushed.

    The remote either has nothing to update or holds a commit for the ref
    that is not an ancestor of the local one.
    """


class PushCommandType(Enum):
    """Kinds of push protocol message.

    The member value is what travels in the command byte of a message
    header.
    """
    info = 0
    update = 1
    putobjects = 2
    status = 3
    done = 4


def python_to_msg_byteorder(python_byteorder=sys.byteorder):
    """Translate a Python byteorder name into the wire marker character.

    The marker is the first byte of every message header: 'l' for
    little-endian and 'B' for big-endian.

    Raises:
        PushException: if ``python_byteorder`` is neither 'little' nor 'big'.
    """
    for name, marker in (('little', 'l'), ('big', 'B')):
        if python_byteorder == name:
            return marker
    raise PushException('Unrecognized system byteorder {}'
                        .format(python_byteorder))


def msg_to_python_byteorder(msg_byteorder):
    """Translate a header byteorder marker back into a Python byteorder name.

    Inverse of python_to_msg_byteorder(): 'l' maps to 'little' and 'B'
    maps to 'big'.

    Raises:
        PushException: if ``msg_byteorder`` is not a known marker.
    """
    for marker, name in (('l', 'little'), ('B', 'big')):
        if msg_byteorder == marker:
            return name
    raise PushException('Unrecognized message byteorder {}'
                        .format(msg_byteorder))


def ostree_object_path(repo, obj):
    """Return the on-disk path of object ``obj`` inside ``repo``.

    Objects are sharded by the first two characters of their name:
    ``<repo>/objects/<obj[:2]>/<obj[2:]>``.
    """
    objects_dir = os.path.join(repo.get_path().get_path(), 'objects')
    shard, remainder = obj[:2], obj[2:]
    return os.path.join(objects_dir, shard, remainder)


class PushCommand(object):
    """A single protocol command: a command type plus an ``a{sv}`` argument dict.

    The arguments are validated on construction and packed into a
    GLib.Variant (``self.variant``) ready to be serialized by
    PushMessageWriter.

    Args:
        cmdtype (PushCommandType): The kind of command.
        args (dict): Mapping of str to GLib.Variant values.

    Raises:
        PushException: if ``cmdtype`` or ``args`` have the wrong type.
    """

    def __init__(self, cmdtype, args):
        self.cmdtype = cmdtype
        self.args = args
        self.validate(cmdtype, args)
        self.variant = GLib.Variant('a{sv}', args)

    @staticmethod
    def validate(command, args):
        """Check that ``command``/``args`` can be encoded as a message.

        Raises:
            PushException: describing the first problem found.
        """
        if not isinstance(command, PushCommandType):
            raise PushException('Message command must be PushCommandType')
        if not isinstance(args, dict):
            raise PushException('Message args must be dict')
        # An a{sv} vardict needs every value to already be a variant
        if not all(isinstance(value, GLib.Variant) for value in args.values()):
            raise PushException('Message args values must be '
                                'GLib.Variant')


class PushMessageWriter(object):
    """Serialize push protocol messages onto a binary file object.

    Every message is a HEADER_SIZE-byte header (byteorder marker, protocol
    version, command type, 2-byte payload size) followed by the command's
    ``a{sv}`` GVariant payload. Object transfers follow a putobjects
    message as an uncompressed tar stream.

    Args:
        file: Writable binary file object (e.g. a pipe to the peer).
        byteorder (str): Python byteorder name used for the header fields.
    """

    # The payload size travels in a 2-byte header field.
    _MAX_PAYLOAD_SIZE = 0xFFFF

    def __init__(self, file, byteorder=sys.byteorder):
        self.file = file
        self.byteorder = byteorder
        self.msg_byteorder = python_to_msg_byteorder(self.byteorder)

    def encode_header(self, cmdtype, size):
        """Return the header for a ``cmdtype`` message carrying ``size`` bytes.

        Raises:
            PushException: if ``size`` does not fit the 2-byte size field
                (previously this surfaced as an opaque OverflowError).
        """
        if size > self._MAX_PAYLOAD_SIZE:
            raise PushException('Message payload of {:d} bytes exceeds the '
                                'protocol limit of {:d} bytes'
                                .format(size, self._MAX_PAYLOAD_SIZE))
        header = self.msg_byteorder.encode() + \
            PROTO_VERSION.to_bytes(1, self.byteorder) + \
            cmdtype.value.to_bytes(1, self.byteorder) + \
            size.to_bytes(2, self.byteorder)
        return header

    def encode_message(self, command):
        """Return header + serialized payload for ``command``.

        Raises:
            PushException: if ``command`` is not a PushCommand or its payload
                is too large.
        """
        if not isinstance(command, PushCommand):
            raise PushException('Command must be PushCommand')
        data = command.variant.get_data_as_bytes()
        size = data.get_size()

        # Build the header
        header = self.encode_header(command.cmdtype, size)

        return header + data.get_data()

    def write(self, command):
        """Encode ``command``, write it and flush so the peer sees it now."""
        msg = self.encode_message(command)
        self.file.write(msg)
        self.file.flush()

    def send_hello(self):
        """Send an empty info request."""
        # The 'hello' message is used to check connectivity and discover the
        # cache's pull URL. It's actually transmitted as an empty info request.
        args = {
            'mode': GLib.Variant('i', 0),
            'refs': GLib.Variant('a{ss}', {})
        }
        command = PushCommand(PushCommandType.info, args)
        self.write(command)

    def send_info(self, repo, refs, pull_url=None):
        """Send the repo mode and the checksums of those ``refs`` present in ``repo``.

        Refs that do not resolve, or whose commit object is missing locally,
        are left out. ``pull_url`` is included only when given.
        """
        cmdtype = PushCommandType.info
        mode = repo.get_mode()

        ref_map = {}
        for ref in refs:
            _, checksum = repo.resolve_rev(ref, True)
            if checksum:
                _, has_object = repo.has_object(OSTree.ObjectType.COMMIT, checksum, None)
                if has_object:
                    ref_map[ref] = checksum

        args = {
            'mode': GLib.Variant('i', mode),
            'refs': GLib.Variant('a{ss}', ref_map)
        }

        # The server sends this so clients can discover the correct pull URL
        # for this cache without requiring end-users to specify it.
        if pull_url:
            args['pull_url'] = GLib.Variant('s', pull_url)

        command = PushCommand(cmdtype, args)
        self.write(command)

    def send_update(self, refs):
        """Send an update request; ``refs`` maps branch -> (old rev, new rev)."""
        cmdtype = PushCommandType.update
        args = {}
        for branch, revs in refs.items():
            args[branch] = GLib.Variant('(ss)', revs)
        command = PushCommand(cmdtype, args)
        self.write(command)

    def send_putobjects(self, repo, objects):
        """Send a putobjects message followed by a tar stream of ``objects``.

        Each object is read from ``repo``'s object store and stored in the
        tar under its object name.
        """
        logging.info('Sending %d objects', len(objects))

        # Send command saying we're going to send a stream of objects
        cmdtype = PushCommandType.putobjects
        command = PushCommand(cmdtype, {})
        self.write(command)

        # Open a TarFile for writing uncompressed tar to a stream
        tar = tarfile.TarFile.open(mode='w|', fileobj=self.file)
        for obj in objects:

            logging.info('Sending object %s', obj)
            objpath = ostree_object_path(repo, obj)
            stat = os.stat(objpath)

            tar_info = tarfile.TarInfo(obj)
            tar_info.mtime = stat.st_mtime
            tar_info.size = stat.st_size
            with open(objpath, 'rb') as obj_fp:
                tar.addfile(tar_info, obj_fp)

        # We're done, close the tarfile (this writes the end-of-archive
        # marker; the underlying file object stays open).
        tar.close()

    def send_status(self, result, message=''):
        """Send a status reply: ``result`` (bool) plus an optional message."""
        cmdtype = PushCommandType.status
        args = {
            'result': GLib.Variant('b', result),
            'message': GLib.Variant('s', message)
        }
        command = PushCommand(cmdtype, args)
        self.write(command)

    def send_done(self):
        """Send an argument-less done message."""
        command = PushCommand(PushCommandType.done, {})
        self.write(command)


class PushMessageReader(object):
    """Parse push protocol messages from a binary file object.

    Args:
        file: Readable binary file object (e.g. a pipe from the peer).
        byteorder (str): Python byteorder name of this host; payloads sent
            in the other byteorder are byteswapped.
        tmpdir (str): Directory into which receive_putobjects() extracts
            received objects.
    """

    # Tar members must look like ostree object names: a 64-character hex
    # checksum, a '.', and a type suffix such as 'commit', 'filez',
    # 'commitmeta', 'sig' or 'sizes2'. Anything else (slashes, '..', absolute
    # paths) could escape tmpdir on extraction, or the objects directory when
    # the receiver later moves the file into place.
    _OBJECT_NAME_RE = re.compile(r'[0-9a-f]{64}\.[a-z0-9-]+')

    def __init__(self, file, byteorder=sys.byteorder, tmpdir=None):
        self.file = file
        self.byteorder = byteorder
        self.tmpdir = tmpdir

    def decode_header(self, header):
        """Parse a HEADER_SIZE-byte header.

        Returns:
            (tuple): (byteorder name, protocol version, PushCommandType,
            payload size).

        Raises:
            PushException: for a short header, unknown byteorder marker,
                unsupported version or unknown command type.
        """
        if len(header) != HEADER_SIZE:
            raise PushException('Header is {:d} bytes, not {:d}'.format(len(header), HEADER_SIZE))
        order = msg_to_python_byteorder(chr(header[0]))
        version = int(header[1])
        if version != PROTO_VERSION:
            raise PushException('Unsupported protocol version {:d}'.format(version))
        try:
            cmdtype = PushCommandType(int(header[2]))
        except ValueError:
            raise PushException('Unknown command type {:d}'.format(header[2])) from None
        vlen = int.from_bytes(header[3:], order)
        return order, version, cmdtype, vlen

    def decode_message(self, message, size, order):
        """Deserialize an ``a{sv}`` payload, byteswapping it if needed.

        Raises:
            PushException: if ``message`` is not ``size`` bytes long.
        """
        if len(message) != size:
            raise PushException('Expected {:d} bytes, but got {:d}'.format(size, len(message)))
        data = GLib.Bytes.new(message)
        variant = GLib.Variant.new_from_bytes(GLib.VariantType.new('a{sv}'),
                                              data, False)
        if order != self.byteorder:
            variant = GLib.Variant.byteswap(variant)

        return variant

    def read(self):
        """Read one message.

        Returns:
            (tuple): (PushCommandType, GLib.Variant args), or (None, None)
            if the remote end closed the stream before a new header.
        """
        header = self.file.read(HEADER_SIZE)
        if not header:
            # Remote end quit
            return None, None
        order, _, cmdtype, size = self.decode_header(header)
        msg = self.file.read(size)
        if len(msg) != size:
            raise PushException('Did not receive full message')
        args = self.decode_message(msg, size, order)

        return cmdtype, args

    def receive(self, allowed):
        """Read one message whose type must be in ``allowed``.

        Returns:
            (tuple): (PushCommandType, unpacked args dict).

        Raises:
            PushException: on end of stream or an unexpected message type.
        """
        cmdtype, args = self.read()
        if cmdtype is None:
            raise PushException('Expected reply, got none')
        if cmdtype not in allowed:
            raise PushException('Unexpected reply type {}'.format(cmdtype.name))
        return cmdtype, args.unpack()

    def receive_info(self):
        """Receive an info message and return its args."""
        _, args = self.receive([PushCommandType.info])
        return args

    def receive_update(self):
        """Receive an update message and return its args."""
        _, args = self.receive([PushCommandType.update])
        return args

    def receive_putobjects(self, repo):
        """Extract the tar stream of objects that follows a putobjects message.

        Objects are extracted into ``self.tmpdir``; ``repo`` is not used.

        Returns:
            (list): Names of the received objects, in stream order.

        Raises:
            PushException: if a tar member is not a regular file with a
                valid object name.
        """
        received_objects = []

        # Open a TarFile for reading uncompressed tar from a stream
        tar = tarfile.TarFile.open(mode='r|', fileobj=self.file)

        # Extract every tarinfo into the temp location
        #
        # This should block while tar.next() reads the next
        # tar object from the stream.
        while True:
            filepos = tar.fileobj.tell()
            tar_info = tar.next()
            if tar_info is None:
                # End of stream marker consists of two 512 Byte blocks.
                # Current Python tarfile stops reading after the first block.
                # Read the second block as well to ensure the stream is at
                # the right position for following messages.
                if tar.fileobj.tell() - filepos < 1024:
                    tar.fileobj.read(512)
                break

            # Never trust member names from the peer.
            if not tar_info.isfile() or not self._OBJECT_NAME_RE.fullmatch(tar_info.name):
                raise PushException('Refusing to extract unexpected object {!r}'
                                    .format(tar_info.name))

            tar.extract(tar_info, self.tmpdir)
            received_objects.append(tar_info.name)

        # Finished with this stream
        tar.close()

        return received_objects

    def receive_status(self):
        """Receive a status message and return its args."""
        _, args = self.receive([PushCommandType.status])
        return args

    def receive_done(self):
        """Receive a done message and return its (empty) args."""
        _, args = self.receive([PushCommandType.done])
        return args


def parse_remote_location(remotepath):
    """Parse remote artifact cache URL that's been specified in our config.

    Three forms are accepted:

    * ``ssh://[user@]host[:port]/path`` and ``file://`` URLs,
    * scp/git style ``[user@]host:path`` remotes, recognized (like git does)
      only when a ':' appears before any '/', and
    * plain local paths.

    The old implementation passed ``scheme='file'`` to urlparse(), which
    makes the parsed scheme always non-empty, so the scp-style branch was
    unreachable: ``user@host:path`` became a local path named
    "user@host:path" and ``host:path`` was rejected as an unknown scheme.

    Args:
        remotepath (str): The configured remote location.

    Returns:
        (tuple): ``(remote_host, remote_user, remote_repo, remote_port)``.
        ``remote_host`` and ``remote_user`` are None for local repositories.

    Raises:
        PushException: for a URL scheme other than file or ssh, or an
            scp-style remote without a "host:path" part.
    """
    prefix, colon, _ = remotepath.partition(':')
    is_scp_style = (bool(colon) and bool(prefix) and
                    '://' not in remotepath and
                    '/' not in prefix and
                    prefix not in ('file', 'ssh'))

    if not is_scp_style:
        url = urlparse(remotepath, scheme='file')
        if url.scheme not in ['file', 'ssh']:
            raise PushException('Only URL schemes file and ssh are allowed, '
                                'not "{}"'.format(url.scheme))
        return url.hostname, url.username, url.path, url.port or 22

    # Scp/git style remote ([user@]hostname:path)
    parts = remotepath.split('@', 1)
    if len(parts) > 1:
        remote_user, remainder = parts
    else:
        remote_user = None
        remainder = parts[0]
    parts = remainder.split(':', 1)
    if len(parts) != 2:
        raise PushException('Remote repository "{}" does not '
                            'contain a hostname and path separated '
                            'by ":"'.format(remotepath))
    remote_host, remote_repo = parts
    # This form doesn't make it possible to specify a non-standard port.
    return remote_host, remote_user, remote_repo, 22


def ssh_commandline(remote_host, remote_user=None, remote_port=22):
    """Build the ssh argv prefix used to reach ``remote_host``.

    Returns an empty list when ``remote_host`` is None (a local repository),
    so callers can append the receiver command unconditionally. ``-l`` is
    only added for a non-empty user and ``-p`` only for a port other than 22.
    """
    if remote_host is None:
        return []
    user_opts = ['-l', remote_user] if remote_user else []
    port_opts = ['-p', str(remote_port)] if remote_port != 22 else []
    return ['ssh', *user_opts, *port_opts, remote_host]


def foo_run(func, args, stdin_fd, stdout_fd, stderr_fd):
    """Child-process entry point used by ProcessWithPipes.

    Rebinds sys.stdin, sys.stdout and sys.stderr (in that order) to text
    streams over the given file descriptors, then calls ``func(args)``.
    """
    for stream_name, fd, mode in (('stdin', stdin_fd, 'r'),
                                  ('stdout', stdout_fd, 'w'),
                                  ('stderr', stderr_fd, 'w')):
        setattr(sys, stream_name, open(fd, mode))
    func(args)


class ProcessWithPipes(object):
    """Run ``func(args)`` in a child process behind Popen-like pipes.

    Provides the subset of subprocess.Popen that the push code uses
    (``stdin``, ``stdout``, ``stderr``, ``wait()``, ``returncode``), so a
    local repository can be served by receive_main() in a child process
    instead of over ssh.

    Args:
        func: Callable run in the child as ``func(args)``.
        args: The single argument passed to ``func``.
        stderr: Optional file object whose descriptor the child writes its
            stderr to. When None, a pipe is created and exposed as
            ``self.stderr``; otherwise ``self.stderr`` is None.
    """

    def __init__(self, func, args, *, stderr=None):
        r0, w0 = os.pipe()  # parent writes -> child's stdin
        r1, w1 = os.pipe()  # child's stdout -> parent reads
        if stderr is None:
            r2, w2 = os.pipe()  # child's stderr -> parent reads
        else:
            w2 = stderr.fileno()
        self.proc = multiprocessing.Process(target=foo_run, args=(func, args, r0, w1, w2))
        self.proc.start()

        # Keep only the parent's ends of the pipes.
        self.stdin = open(w0, 'wb')
        os.close(r0)
        self.stdout = open(r1, 'rb')
        os.close(w1)
        if stderr is None:
            self.stderr = open(r2, 'rb')
            os.close(w2)
        else:
            # The child writes straight to the caller's file; nothing to read.
            # (Previously the attribute was simply missing in this case.)
            self.stderr = None

        # The eventual return code
        self.returncode = -1

    def wait(self):
        """Wait for the child to exit and return its exit code, like Popen.wait()."""
        self.proc.join()
        self.returncode = self.proc.exitcode
        return self.returncode


class OSTreePusher(object):
    """Client side of the push protocol.

    Opens the local OSTree repository and starts ``bst-artifact-receive``
    for the remote. This runs over ssh for a remote host, or through
    ProcessWithPipes for a local path. run() then pushes the objects needed
    to update the selected refs.

    Args:
        repopath (str): Local repository path, or None for the default repo.
        remotepath (str): Remote location, as accepted by parse_remote_location().
        branches (list): Refs to push, or None to push every local ref.
        verbose (bool): Pass --verbose to the receiver.
        debug (bool): Pass --debug to the receiver.
        output: File object receiving the receiver's stderr.
    """

    def __init__(self, repopath, remotepath, branches=None, verbose=False,
                 debug=False, output=None):
        self.repopath = repopath
        self.remotepath = remotepath
        self.verbose = verbose
        self.debug = debug
        self.output = output

        self.remote_host, self.remote_user, self.remote_repo, self.remote_port = \
            parse_remote_location(remotepath)

        if self.repopath is None:
            self.repo = OSTree.Repo.new_default()
        else:
            self.repo = OSTree.Repo.new(Gio.File.new_for_path(self.repopath))
        self.repo.open(None)

        # Enumerate branches to push
        if branches is None:
            _, self.refs = self.repo.list_refs(None, None)
        else:
            self.refs = {}
            for branch in branches:
                _, rev = self.repo.resolve_rev(branch, False)
                self.refs[branch] = rev

        # Start ssh (or a local receiver process for a plain path)
        ssh_cmd = ssh_commandline(self.remote_host, self.remote_user, self.remote_port)
        ssh_cmd += ['bst-artifact-receive']
        if self.verbose:
            ssh_cmd += ['--verbose']
        if self.debug:
            ssh_cmd += ['--debug']
        if not self.remote_host:
            ssh_cmd += ['--pull-url', self.remote_repo]
        ssh_cmd += [self.remote_repo]

        logging.info('Executing {}'.format(' '.join(ssh_cmd)))

        if self.remote_host:
            self.ssh = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE,
                                        stdout=subprocess.PIPE,
                                        stderr=self.output,
                                        start_new_session=True)
        else:
            # Without a host ssh_cmd has no 'ssh' prefix; [1:] drops the
            # 'bst-artifact-receive' program name so receive_main gets its args.
            self.ssh = ProcessWithPipes(receive_main, ssh_cmd[1:], stderr=self.output)

        self.writer = PushMessageWriter(self.ssh.stdin)
        self.reader = PushMessageReader(self.ssh.stdout)

    def needed_commits(self, remote, local, needed):
        """Add to ``needed`` every commit from ``local`` back to (not including) ``remote``.

        A ``remote`` of 64 zeros means the remote does not have the ref, in
        which case the whole local history is added.

        Raises:
            PushException: if local history is shallow and ends before
                reaching ``remote``.
            PushExistsException: if ``remote`` is not an ancestor of ``local``.
        """
        parent = local
        if remote == '0' * 64:
            # Nonexistent remote branch, use None for convenience
            remote = None
        while parent != remote:
            needed.add(parent)
            _, commit = self.repo.load_variant_if_exists(OSTree.ObjectType.COMMIT,
                                                         parent)
            if commit is None:
                raise PushException('Shallow history from commit {} does '
                                    'not contain remote commit {}'.format(local, remote))
            parent = OSTree.commit_get_parent(commit)
            if parent is None:
                break
        if remote is not None and parent != remote:
            # Deliberately no 'done' here: run() may still push other refs,
            # and only it knows whether every ref was rejected.
            raise PushExistsException('Remote commit {} not descendent of '
                                      'commit {}'.format(remote, local))

    def needed_objects(self, commits):
        """Return the names of every object reachable from ``commits``.

        Names are as stored in an archive-z2 repository (file objects get a
        'z' suffix). Detached commit metadata and the Endless '.sig' /
        '.sizes2' companion files are included when they exist locally.
        """
        objects = set()
        for rev in commits:
            _, reachable = self.repo.traverse_commit(rev, 0, None)
            for obj in reachable:
                objname = OSTree.object_to_string(obj[0], obj[1])
                if obj[1] == OSTree.ObjectType.FILE:
                    # Make this a filez since we're archive-z2
                    objname += 'z'
                elif obj[1] == OSTree.ObjectType.COMMIT:
                    # Add in detached metadata
                    metaobj = objname + 'meta'
                    metapath = ostree_object_path(self.repo, metaobj)
                    if os.path.exists(metapath):
                        objects.add(metaobj)

                    # Add in Endless compat files
                    for suffix in ['sig', 'sizes2']:
                        metaobj = obj[0] + '.' + suffix
                        metapath = ostree_object_path(self.repo, metaobj)
                        if os.path.exists(metapath):
                            objects.add(metaobj)
                objects.add(objname)
        return objects

    def close(self):
        """Close the receiver's stdin and return the result of waiting for it."""
        self.ssh.stdin.close()
        return self.ssh.wait()

    def run(self):
        """Push the selected refs, and the objects they need, to the remote.

        Returns:
            The value of close() once the receiver has confirmed completion.

        Raises:
            PushException: on protocol errors, a remote that is not
                archive-z2, shallow local history, or a rejected update.
            PushExistsException: if there is nothing to update, or no ref
                could be pushed because the remote already has it.
        """
        # Send info immediately
        self.writer.send_info(self.repo, list(self.refs.keys()))

        # Receive remote info
        logging.info('Receiving repository information')
        args = self.reader.receive_info()
        if args['mode'] != OSTree.RepoMode.ARCHIVE_Z2:
            # Let the receiver shut down cleanly instead of waiting for us.
            self.writer.send_done()
            raise PushException('Can only push to archive-z2 repos')
        remote_refs = args['refs']

        update_refs = {}
        for branch, rev in self.refs.items():
            remote_rev = remote_refs.get(branch, '0' * 64)
            if rev != remote_rev:
                update_refs[branch] = remote_rev, rev
        if not update_refs:
            logging.info('Nothing to update')
            self.writer.send_done()
            raise PushExistsException('Nothing to update')

        # Work out the commits each ref needs *before* asking the remote to
        # update anything. Refs whose remote commit is not an ancestor of
        # ours are dropped from the request rather than overwritten, and a
        # rejected ref no longer ends the session halfway through.
        update_refs, commits = self._collect_commits(update_refs)

        # Send update command
        logging.info('Sending update request')
        self.writer.send_update(update_refs)

        # Receive status for update request
        args = self.reader.receive_status()
        if not args['result']:
            self.writer.send_done()
            raise PushException(args['message'])

        logging.info('Enumerating objects to send')
        objects = self.needed_objects(commits)

        # Send all the objects to receiver
        self.writer.send_putobjects(self.repo, objects)

        # Inform receiver that all objects have been sent
        self.writer.send_done()

        # Wait until receiver has completed
        self.reader.receive_done()

        return self.close()

    def _collect_commits(self, update_refs):
        """Split ``update_refs`` into pushable refs and the commits they need.

        Returns:
            (tuple): (dict of pushable refs, set of commits they need).

        Raises:
            PushExistsException: after sending 'done' to the receiver, if
                none of the refs can be pushed.
        """
        pushable = {}
        commits = set()
        first_error = None
        for branch, revs in update_refs.items():
            logging.info('Updating {} {} to {}'.format(branch, revs[0], revs[1]))
            # Collect per ref so a rejected ref doesn't leak commits.
            branch_commits = set()
            try:
                self.needed_commits(revs[0], revs[1], branch_commits)
            except PushExistsException as e:
                if first_error is None:
                    first_error = e
                continue
            pushable[branch] = revs
            commits |= branch_commits

        # Re-raise PushExistsException if all refs exist already
        if not pushable:
            self.writer.send_done()
            raise first_error
        return pushable, commits


class OSTreeReceiver(object):
    """Receiving end of an OSTree push.

    Speaks the push protocol over this process's stdin/stdout. Received
    objects are extracted into a temporary directory under the repository's
    ``tmp`` directory, then renamed into the object store before the refs
    are updated.

    Args:
        repopath (str): On-disk location of the receiving repository, or
            None for the default repository.
        pull_url (str): Redirection for clients who want to pull, not push.
    """

    def __init__(self, repopath, pull_url):
        self.repopath = repopath
        self.pull_url = pull_url

        if self.repopath is None:
            self.repo = OSTree.Repo.new_default()
        else:
            self.repo = OSTree.Repo.new(Gio.File.new_for_path(self.repopath))
        self.repo.open(None)

        # Use the opened repository's own path: self.repopath is None when
        # the default repository is used, which os.path.join() rejects.
        repo_tmp = os.path.join(self.repo.get_path().get_path(), 'tmp')
        self.tmpdir = tempfile.mkdtemp(dir=repo_tmp, prefix='bst-push-')
        self.writer = PushMessageWriter(sys.stdout.buffer)
        self.reader = PushMessageReader(sys.stdin.buffer, tmpdir=self.tmpdir)

        # Set a sane umask before writing any objects
        os.umask(0o0022)

    def close(self):
        """Remove the temporary directory, flush stdout and return 0."""
        shutil.rmtree(self.tmpdir)
        sys.stdout.flush()
        return 0

    def run(self):
        """Serve one push session and return the exit code.

        The temporary directory is always cleaned up. This covers a broken
        pipe, a tarfile read error when the remote connection is closed, or
        a bug. Any exception is re-raised.
        """
        try:
            return self.do_run()
        finally:
            self.close()

    def do_run(self):
        """Run the receiving side of the protocol; returns 0."""
        # Receive remote info
        args = self.reader.receive_info()
        remote_refs = args['refs']

        # Send info back
        self.writer.send_info(self.repo, list(remote_refs.keys()),
                              pull_url=self.pull_url)

        # Wait for update or done command
        cmdtype, args = self.reader.receive([PushCommandType.update,
                                             PushCommandType.done])
        if cmdtype == PushCommandType.done:
            return 0
        update_refs = args

        profile_name = self._profile_name(update_refs)
        profile_start(Topics.ARTIFACT_RECEIVE, profile_name)
        try:
            self._apply_update(update_refs)
        finally:
            # Previously skipped on the early-return paths, leaving the
            # profile unbalanced.
            profile_end(Topics.ARTIFACT_RECEIVE, profile_name)

        return 0

    @staticmethod
    def _profile_name(update_refs):
        """Join the refs, minus their trailing '/<checksum>' part, into one name."""
        # Strip off the SHA256 sum on the right of the reference,
        # leaving the project and element name
        profile_names = {re.sub(r"/[a-z0-9]+$", '', update_ref)
                         for update_ref in update_refs}
        return '_'.join(profile_names)

    def _apply_update(self, update_refs):
        """Acknowledge ``update_refs``, receive their objects and set the refs.

        ``update_refs`` maps branch -> (old rev, new rev).

        Raises:
            PushException: if a commit named by ``update_refs`` is missing
                after the objects were received.
        """
        self.writer.send_status(True)

        # Wait for putobjects or done
        cmdtype, _ = self.reader.receive([PushCommandType.putobjects,
                                          PushCommandType.done])

        if cmdtype == PushCommandType.done:
            logging.debug('Received done before any objects, exiting')
            return

        # Receive the actual objects
        received_objects = self.reader.receive_putobjects(self.repo)

        # Ensure that pusher has sent all objects
        self.reader.receive_done()

        # Got all objects, move them to the object store. An empty list is not
        # a reason to stop early: the pusher is still waiting for our 'done'.
        for obj in received_objects:
            tmp_path = os.path.join(self.tmpdir, obj)
            obj_path = ostree_object_path(self.repo, obj)
            os.makedirs(os.path.dirname(obj_path), exist_ok=True)
            logging.debug('Renaming {} to {}'.format(tmp_path, obj_path))
            os.rename(tmp_path, obj_path)

        # Verify that we have the specified commit objects
        for branch, revs in update_refs.items():
            _, has_object = self.repo.has_object(OSTree.ObjectType.COMMIT, revs[1], None)
            if not has_object:
                raise PushException('Missing commit {} for ref {}'.format(revs[1], branch))

        # Finally, update the refs
        for branch, revs in update_refs.items():
            logging.debug('Setting ref {} to {}'.format(branch, revs[1]))
            self.repo.set_ref_immediate(None, branch, revs[1], None)

        # Inform pusher that everything is in place
        self.writer.send_done()


# initialize_push_connection()
#
# Test that we can connect to the remote bst-artifact-receive program, and
# receive the pull URL for this artifact cache.
#
# We don't want to make the user wait until the first artifact has been built
# to discover that they actually cannot push, so this should be called early.
#
# The SSH push protocol doesn't allow pulling artifacts. We don't want to
# require users to specify two URLs for a single cache, so we have the push
# protocol return the corresponding pull URL as part of the 'hello' response.
#
# Args:
#   remote: The ssh remote url to push to
#
# Returns:
#   (str): The URL that should be used for pulling from this cache.
#
# Raises:
#   PushException if there was an issue connecting to the remote.
#
def initialize_push_connection(remote):
    remote_host, remote_user, remote_repo, remote_port = parse_remote_location(remote)
    ssh_cmd = ssh_commandline(remote_host, remote_user, remote_port)

    if remote_host:
        # We need a short timeout here because if 'remote' isn't reachable at
        # all, the process will hang until the connection times out.
        ssh_cmd += ['-oConnectTimeout=3']

    ssh_cmd += ['bst-artifact-receive', remote_repo]

    if remote_host:
        ssh = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE,
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    else:
        ssh_cmd += ['--pull-url', remote_repo]
        ssh = ProcessWithPipes(receive_main, ssh_cmd[1:])

    writer = PushMessageWriter(ssh.stdin)
    reader = PushMessageReader(ssh.stdout)

    try:
        writer.send_hello()
        args = reader.receive_info()
        writer.send_done()

        if 'pull_url' in args:
            pull_url = args['pull_url']
        else:
            raise PushException(
                "Remote cache did not tell us its pull URL. This cache probably "
                "requires updating to a newer version of `bst-artifact-receive`.")
    except (PushException, ConnectionError) as protocol_error:
        # If we get a read/write error on the wire, let's first see if SSH
        # reported an error such as 'Permission denied'. If so this will be
        # much more useful to the user than the "Expected reply, got none" or
        # broken pipe sort of error the protocol code will have raised.
        try:
            # Closing stdin lets a receiver still waiting for input exit, so
            # wait() below cannot hang on it.
            ssh.stdin.close()
        except OSError:
            # Flushing into an already dead pipe can fail; that's fine here.
            pass
        ssh.wait()
        if ssh.returncode != 0:
            # Decode as UTF-8; 'unicode-escape' mangled non-ASCII text and
            # backslashes in the error output.
            ssh_error = ssh.stderr.read().decode('utf-8', errors='replace').strip()
            raise PushException("{}".format(ssh_error)) from protocol_error
        raise protocol_error

    # The receiver exits after our 'done'; release the pipes and reap it.
    ssh.stdin.close()
    ssh.wait()
    ssh.stdout.close()
    ssh.stderr.close()

    return pull_url


# push()
#
# Run the pusher in process, with logging going to the output file
#
# Args:
#   repo: The local repository path
#   remote: The ssh remote url to push to
#   branches: The refs to push
#   output: The output where logging should go
#
# Returns:
#   (bool): True if the remote was updated, False if it already existed
#           and no update was required
#
# Raises:
#   PushException if there was an error
#
def push(repo, remote, branches, output):
    logging.basicConfig(format='%(module)s: %(levelname)s: %(message)s',
                        level=logging.INFO, stream=output)

    pusher = OSTreePusher(repo, remote, branches, verbose=True, debug=False,
                          output=output)

    def terminate_push():
        pusher.close()

    # terminate_push() is registered as the terminator callback for the
    # duration of the push.
    with _signals.terminator(terminate_push):
        try:
            pusher.run()
        except ConnectionError as e:
            # Connection attempt failed or connection was terminated unexpectedly
            terminate_push()
            raise PushException("Connection failed") from e
        except PushException:
            terminate_push()
            raise
        except PushExistsException:
            # The remote already has what we would push; report "not updated"
            # instead of treating it as an error.
            logging.info("Ref {} was already present in remote {}".format(branches, remote))
            terminate_push()
            return False
        else:
            return True


@click.command(short_help="Receive pushed artifacts over ssh")
@click.option('--verbose', '-v', is_flag=True, default=False, help="Verbose mode")
@click.option('--debug', '-d', is_flag=True, default=False, help="Debug mode")
@click.option('--pull-url', type=str, required=True,
              help="Clients who try to pull over SSH will be redirected here")
@click.argument('repo', type=click.Path(file_okay=False, dir_okay=True, writable=True, exists=True))
def receive_main(verbose, debug, pull_url, repo):
    """A BuildStream sister program for receiving artifacts send to a shared artifact cache
    """
    # --debug takes precedence over --verbose; by default only warnings and
    # worse are logged (to stderr, since stdout carries the protocol).
    if debug:
        loglevel = logging.DEBUG
    elif verbose:
        loglevel = logging.INFO
    else:
        loglevel = logging.WARNING
    logging.basicConfig(format='%(module)s: %(levelname)s: %(message)s',
                        level=loglevel, stream=sys.stderr)

    # Serve a single push session over stdin/stdout.
    return OSTreeReceiver(repo, pull_url).run()