Commit 96e2c62f authored by Luis Miguens Fernandez's avatar Luis Miguens Fernandez
Browse files

Merge branch 'feat/schedule_sampling2' into 'develop'

Add method to sample Schedule

See merge request !212
parents 2ce55aab d02730ad
Loading
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -34,6 +34,7 @@ Merged branches and closed issues
* Added a function to extract acquisition metadata from a schedule (#179, !180).
* Qblox ICCs - Compensated integration time for Qblox QRM IC component (!199).
* Visualization - Allow user defined axis for plotting circuit diagram (!206)
* Added method `sample_schedule` to sample a `Schedule` (!212)

0.4.0 InstrumentCoordinator and improvements to backends (2021-08-06)
---------------------------------------------------------------------
+165 −47
Original line number Diff line number Diff line
@@ -2,24 +2,68 @@
# Licensed according to the LICENCE file on the master branch
"""Functions for drawing pulse diagrams"""
from __future__ import annotations

import inspect
import logging
from typing import List, Dict, Optional
from typing_extensions import Literal
from typing import Dict, List, Optional, Tuple, Callable

import numpy as np

from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go

from plotly.subplots import make_subplots
from quantify_core.utilities.general import import_func_from_string
from typing_extensions import Literal

from quantify_scheduler.types import Schedule
from quantify_scheduler.waveforms import modulate_wave

logger = logging.getLogger(__name__)


def _populate_port_mapping(schedule, portmap: Dict[str, int], ports_length) -> None:
    """
    Dynamically add up to 8 ports to the port_map dictionary.
    """
    offset_idx: int = 0

    for t_constr in schedule.timing_constraints:
        operation = schedule.operations[t_constr["operation_repr"]]
        for pulse_info in operation["pulse_info"]:
            if offset_idx == ports_length:
                return

            port = pulse_info["port"]
            if port is None:
                continue

            if port not in portmap:
                portmap[port] = offset_idx
                offset_idx += 1


def validate_pulse_info(pulse_info, port_map, t_constr, operation):
    if pulse_info["port"] not in port_map:
        # Do not draw pulses for this port
        return False

    if pulse_info["port"] is None:
        logger.warning(
            f"Unable to sample pulse for pulse_info due to missing 'port' for "
            f"operation name={operation['name']} "
            f"id={t_constr['operation_repr']} pulse_info={pulse_info}"
        )
        return False

    if pulse_info["wf_func"] is None:
        logger.warning(
            f"Unable to sample pulse for pulse_info due to missing 'wf_func' for "
            f"operation name={operation['name']} "
            f"id={t_constr['operation_repr']} pulse_info={pulse_info}"
        )
        return False
    return True


# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
@@ -29,8 +73,8 @@ def pulse_diagram_plotly(
    fig_ch_height: float = 300,
    fig_width: float = 1000,
    modulation: Literal["off", "if", "clock"] = "off",
    modulation_if: float = 0,
    sampling_rate: int = 1e9,
    modulation_if: float = 0.0,
    sampling_rate: int = 1_000_000_000,
) -> go.Figure:
    """
    Produce a plotly visualization of the pulses used in the schedule.
@@ -63,31 +107,11 @@ def pulse_diagram_plotly(
    ports_length: int = 8
    auto_map: bool = port_list is None

    def _populate_port_mapping(portmap: Dict[str, int]) -> None:
        """
        Dynamically add up to 8 ports to the port_map dictionary.
        """
        offset_idx: int = 0

        for t_constr in schedule.timing_constraints:
            operation = schedule.operations[t_constr["operation_repr"]]
            for pulse_info in operation["pulse_info"]:
                if offset_idx == ports_length:
                    return

                port = pulse_info["port"]
                if port is None:
                    continue

                if port not in portmap:
                    portmap[port] = offset_idx
                    offset_idx += 1

    if auto_map is False:
        ports_length = len(port_list)
        port_map = dict(zip(port_list, range(len(port_list))))
    else:
        _populate_port_mapping(port_map)
        _populate_port_mapping(schedule, port_map, ports_length)
        ports_length = len(port_map)

    nrows = ports_length
@@ -106,31 +130,14 @@ def pulse_diagram_plotly(
        operation = schedule.operations[t_constr["operation_repr"]]

        for pulse_info in operation["pulse_info"]:
            if pulse_info["port"] not in port_map:
                # Do not draw pulses for this port
                continue

            if pulse_info["port"] is None:
                logger.warning(
                    f"Unable to draw pulse for pulse_info due to missing 'port' for "
                    f"operation name={operation['name']} "
                    f"id={t_constr['operation_repr']} pulse_info={pulse_info}"
                )
                continue

            if pulse_info["wf_func"] is None:
                logger.warning(
                    f"Unable to draw pulse for pulse_info due to missing 'wf_func' for "
                    f"operation name={operation['name']} "
                    f"id={t_constr['operation_repr']} pulse_info={pulse_info}"
                )
            if not validate_pulse_info(pulse_info, port_map, t_constr, operation):
                continue

            # port to map the waveform too
            port: str = pulse_info["port"]

            # function to generate waveform
            wf_func: str = import_func_from_string(pulse_info["wf_func"])
            wf_func: Callable = import_func_from_string(pulse_info["wf_func"])

            # iterate through the colors in the color map
            col_idx = (col_idx + 1) % len(colors)
@@ -227,3 +234,114 @@ def pulse_diagram_plotly(
    )

    return fig


def sample_schedule(
    schedule: Schedule,
    port_list: Optional[List[str]] = None,
    modulation: Literal["off", "if", "clock"] = "off",
    modulation_if: float = 0.0,
    sampling_rate: int = 1_000_000_000,
) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
    """Sample a schedule on discete points in time

    Parameters
    ------------
    schedule :
        The schedule to render.
    port_list :
        A list of ports to show. if set to `None` will use the first
        8 ports it encounters in the sequence.
    modulation :
        Determines if modulation is included in the visualization.
    modulation_if :
        Modulation frequency used when modulation is set to "if".
    sampling_rate :
        The time resolution used in the sampling.

    Returns
    -------
    :
        Tuple of sample times and a dicionary with the data samples for each port
    """

    port_map: Dict[str, int] = dict()
    ports_length: int = 8
    auto_map: bool = port_list is None

    if auto_map is False:
        ports_length = len(port_list)
        port_map = dict(zip(port_list, range(len(port_list))))
    else:
        _populate_port_mapping(schedule, port_map, ports_length)
        ports_length = len(port_map)

    time_window = None
    for pls_idx, t_constr in enumerate(schedule.timing_constraints):
        operation = schedule.operations[t_constr["operation_repr"]]

        for pulse_info in operation["pulse_info"]:
            if not validate_pulse_info(pulse_info, port_map, t_constr, operation):
                continue

            # times at which to evaluate waveform
            t0 = t_constr["abs_time"] + pulse_info["t0"]
            if time_window is None:
                time_window = [t0, t0 + pulse_info["duration"]]
            else:
                time_window = [
                    min(t0, time_window[0]),
                    max(t0 + pulse_info["duration"], time_window[1]),
                ]

    logger.info(f"sample_schedule: time_window {time_window}, port_map {port_map}")

    timestamps = np.arange(time_window[0], time_window[1], 1 / sampling_rate)
    waveforms = {key: np.zeros_like(timestamps) for key in port_map}

    for pls_idx, t_constr in enumerate(schedule.timing_constraints):
        operation = schedule.operations[t_constr["operation_repr"]]
        logger.debug(f"sample_schedule: {pls_idx}: {operation}")

        for pulse_info in operation["pulse_info"]:

            if not validate_pulse_info(pulse_info, port_map, t_constr, operation):
                continue

            # port to map the waveform too
            port: str = pulse_info["port"]

            # function to generate waveform
            wf_func: Callable = import_func_from_string(pulse_info["wf_func"])

            # times at which to evaluate waveform
            t0 = t_constr["abs_time"] + pulse_info["t0"]
            t1 = t0 + pulse_info["duration"]

            time_indices = np.where(np.logical_and(timestamps >= t0, timestamps <= t1))

            t = timestamps[time_indices]

            par_map = inspect.signature(wf_func).parameters
            wf_kwargs = {}
            for kwargs in par_map.keys():
                if kwargs in pulse_info.keys():
                    wf_kwargs[kwargs] = pulse_info[kwargs]

            # Calculate the numerical waveform using the wf_func
            waveform = wf_func(t=t, **wf_kwargs)

            # optionally adds some modulation
            if modulation == "clock":
                # apply modulation to the waveforms
                waveform = modulate_wave(
                    t, waveform, schedule.resources[pulse_info["clock"]]["freq"]
                )

            if modulation == "if":
                # apply modulation to the waveforms
                waveform = modulate_wave(t, waveform, modulation_if)

            waveforms[port][time_indices] = (waveform.real,)

    return timestamps, waveforms
+68 −0
Original line number Diff line number Diff line
# Repository: https://gitlab.com/quantify-os/quantify-scheduler
# Licensed according to the LICENCE file on the master branch
# pylint: disable=missing-function-docstring

import numpy as np
import pytest

from quantify_scheduler.compilation import determine_absolute_timing
from quantify_scheduler.pulse_library import SquarePulse
from quantify_scheduler.types import Schedule
from quantify_scheduler.visualization.pulse_diagram import sample_schedule


def test_sample_schedule() -> None:
    schedule = Schedule("test")
    r = SquarePulse(amp=0.2, duration=4e-9, port="SDP")
    schedule.add(r)
    rm = SquarePulse(amp=-0.2, duration=6e-9, port="T")
    schedule.add(rm, ref_pt="start")
    r = SquarePulse(amp=0.3, duration=6e-9, port="SDP")
    schedule.add(r)
    schedule.add(r)
    determine_absolute_timing(schedule=schedule)

    timestamps, waveforms = sample_schedule(schedule, sampling_rate=0.5e9)

    np.testing.assert_array_almost_equal(
        timestamps,
        np.array(
            [
                0.0e00,
                2.0e-09,
                4.0e-09,
                6.0e-09,
                8.0e-09,
                1.0e-08,
                1.2e-08,
                1.4e-08,
                1.6e-08,
            ]
        ),
    )

    np.testing.assert_array_almost_equal(
        waveforms["SDP"], np.array([0.2, 0.2, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3])
    )
    np.testing.assert_array_almost_equal(
        waveforms["T"], np.array([-0.2, -0.2, -0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    )


def test_sample_custom_port_list() -> None:
    schedule = Schedule("test")
    r = SquarePulse(amp=0.2, duration=4e-9, port="SDP")
    schedule.add(r)
    determine_absolute_timing(schedule=schedule)

    timestamps, waveforms = sample_schedule(
        schedule, sampling_rate=0.5e9, port_list=["SDP"]
    )
    assert list(waveforms.keys()) == ["SDP"]


def test_sample_empty_schedule() -> None:
    schedule = Schedule("test")

    with pytest.raises(TypeError):
        timestamps, waveforms = sample_schedule(schedule, sampling_rate=1e9)