Commit 0875655c authored by Adriaan's avatar Adriaan
Browse files

Merge branch...

Merge branch '45-y-axis-label-is-broken-in-plotly-visualization-after-resources-refactor' into 'develop'

Resolve "y-axis label is broken in plotly visualization after resources-refactor"

Closes #45

See merge request !38
parents 1b878c41 16345d40
Loading
Loading
Loading
Loading
Loading
+118 −94
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@
# Copyright (C) Qblox BV & Orange Quantum Systems Holding BV (2020-2021)
# -----------------------------------------------------------------------------
from __future__ import annotations
from typing import TYPE_CHECKING, Tuple, Union, List
from typing import TYPE_CHECKING, Tuple, Union, List, Dict, Optional, Any
import logging
import inspect
import numpy as np
@@ -25,7 +25,7 @@ if TYPE_CHECKING:
    from quantify.scheduler.types import Schedule


def new_pulse_fig(figsize=None) -> Tuple[Figure, Union[Axes, List[Axes]]]:
def new_pulse_fig(figsize: Optional[Tuple[int, int]] = None) -> Tuple[Figure, Union[Axes, List[Axes]]]:
    """
    Open a new figure and configure it to plot pulse schemes.
    """
@@ -42,7 +42,7 @@ def new_pulse_fig(figsize=None) -> Tuple[Figure, Union[Axes, List[Axes]]]:
    return fig, ax


def new_pulse_subplot(fig: 'Figure', *args, **kwargs) -> 'Axes':
def new_pulse_subplot(fig: Figure, *args, **kwargs) -> Axes:
    """
    Add a new subplot configured for plotting pulse schemes to a figure.

@@ -56,8 +56,8 @@ def new_pulse_subplot(fig: 'Figure', *args, **kwargs) -> 'Axes':
    return ax


def mwPulse(ax, pos, y_offs=0, width=1.5, amp=1, label=None, phase=0, label_height=1.3, color='C0',
            modulation='normal', **plot_kws) -> float:
def mwPulse(ax: Axes, pos: float, y_offs: float = .0,  width: float = 1.5, amp: float = 1, label: Optional[str] = None,
            phase=0, label_height: float = 1.3, color: str = 'C0', modulation: str = 'normal', **plot_kws) -> float:
    """
    Draw a microwave pulse: Gaussian envelope with modulation.
    """
@@ -82,7 +82,8 @@ def mwPulse(ax, pos, y_offs=0, width=1.5, amp=1, label=None, phase=0, label_heig
    return pos + width


def fluxPulse(ax, pos, y_offs=0, width=2.5, s=.1, amp=1.5, label=None, label_height=1.7, color='C1', **plot_kws) -> float:
def fluxPulse(ax: Axes, pos: float, y_offs: float = .0, width: float = 2.5, s: float = .1, amp: float = 1.5,
              label: Optional[str] = None, label_height: float = 1.7, color: str = 'C1', **plot_kws) -> float:
    """
    Draw a smooth flux pulse, where the rising and falling edges are given by
    Fermi-Dirac functions.
@@ -100,7 +101,8 @@ def fluxPulse(ax, pos, y_offs=0, width=2.5, s=.1, amp=1.5, label=None, label_hei
    return pos + width


def ramZPulse(ax, pos, y_offs=0, width=2.5, s=0.1, amp=1.5, sep=1.5, color='C1') -> float:
def ramZPulse(ax: Axes, pos: float, y_offs: float = .0, width: float = 2.5, s: float = .1, amp: float = 1.5,
              sep: float = 1.5, color: str = 'C1') -> float:
    """
    Draw a Ram-Z flux pulse, i.e. only part of the pulse is shaded, to indicate
    cutting off the pulse at some time.
@@ -117,8 +119,9 @@ def ramZPulse(ax, pos, y_offs=0, width=2.5, s=0.1, amp=1.5, sep=1.5, color='C1')
    return pos + width


def interval(ax, start, stop, y_offs=0, height=1.5, label=None, label_height=None, vlines=True, color='k',
             arrowstyle='<|-|>', **plot_kws) -> None:
def interval(ax: Axes, start: float, stop: float, y_offs: float = .0, height: float = 1.5, label: Optional[str] = None,
             label_height: Optional[str] = None, vlines: bool = True, color: str = 'k', arrowstyle: str = '<|-|>',
             **plot_kws) -> None:
    """
    Draw an arrow to indicate an interval.
    """
@@ -138,7 +141,8 @@ def interval(ax, start, stop, y_offs=0, height=1.5, label=None, label_height=Non
        ax.text((start + stop) / 2, label_height+y_offs, label, color=color, ha='center').set_clip_on(True)


def meter(ax, x0, y0, y_offs=0, w=1.1, h=.8, color='black', fillcolor=None) -> None:
def meter(ax: Axes, x0: float, y0: float, y_offs: float = .0, w: float = 1.1, h: float = .8, color: str = 'black',
          fillcolor: Optional[str] = None) -> None:
    """
    Draws a measurement meter on the specified position.
    """
@@ -157,7 +161,8 @@ def meter(ax, x0, y0, y_offs=0, w=1.1, h=.8, color='black', fillcolor=None) -> N
             zorder=5)


def box_text(ax, x0, y0, text='', w=1.1, h=.8, color='black', fillcolor=None, textcolor='black', fontsize=None) -> None:
def box_text(ax: Axes, x0: float, y0: float, text: str = '', w: float = 1.1, h: float = .8, color: str = 'black',
             fillcolor: Optional[str] = None, textcolor: str = 'black', fontsize: Optional[int] = None) -> None:
    """
    Draws a box filled with text at the specified position.
    """
@@ -172,12 +177,12 @@ def box_text(ax, x0, y0, text='', w=1.1, h=.8, color='black', fillcolor=None, te


def pulse_diagram_plotly(schedule: Schedule,
                         port_list: list = None,
                         port_list: Optional[List[str]] = None,
                         fig_ch_height: float = 150,
                         fig_width: float = 1000,
                         modulation_if: float = 0,
                         modulation: bool = True,
                         sampling_rate: float = 1e9
                         sampling_rate: int = 1e9
                         ) -> Figure:
    """
    Produce a plotly visualization of the pulses used in the schedule.
@@ -205,97 +210,116 @@ def pulse_diagram_plotly(schedule: Schedule,
        the plot
    """

    if port_list is None:  # determine the channel list automatically.
        auto_map = True
        offset_idx = 0
        nr_rows = 8
        port_map = {}
    else:
        auto_map = False
        nr_rows = len(port_list)
        port_map = dict(zip(port_list, range(len(port_list))))
        print(port_map)
    port_map: Dict[str, int] = dict()
    ports_length: int = 8
    auto_map: bool = True if port_list is None else False

    def _populate_port_mapping(map: Dict[str, int]) -> None:
        """
        Dynammically add up to 8 ports to the port_map dictionary.
        """
        offset_idx: int = 0

        for t_constr in schedule.timing_constraints:
            operation = schedule.operations[t_constr['operation_hash']]
            for pulse_info in operation['pulse_info']:
                if offset_idx == ports_length:
                    return

    fig = make_subplots(rows=nr_rows, cols=1, shared_xaxes=True, vertical_spacing=0.02)
    fig.update_layout(height=fig_ch_height*nr_rows, width=fig_width, title=schedule.data['name'], showlegend=False)
                port = pulse_info['port']
                if port is None:
                    continue

                if port not in port_map:
                    port_map[port] = offset_idx
                    offset_idx += 1

    if auto_map is False:
        ports_length = len(port_list)
        port_map = dict(zip(port_list, range(len(port_list))))
    else:
        _populate_port_mapping(port_map)
        ports_length = len(port_map)

    nrows = ports_length
    fig = make_subplots(rows=nrows, cols=1, shared_xaxes=True, vertical_spacing=0.02)
    fig.update_layout(height=fig_ch_height*nrows, width=fig_width,
                      title=schedule.data['name'], showlegend=False)
    colors = px.colors.qualitative.Plotly
    col_idx = 0
    col_idx: int = 0

    for pls_idx, t_constr in enumerate(schedule.timing_constraints):
        op = schedule.operations[t_constr['operation_hash']]
        operation = schedule.operations[t_constr['operation_hash']]

        for pulse_info in operation['pulse_info']:
            if pulse_info['port'] not in port_map:
                # Do not draw pulses for this port
                continue

            if pulse_info['port'] is None:
                logger.warning(
                    f"Unable to draw pulse for pulse_info due to missing 'port' for \
                        operation name={operation['name']} \
                        id={t_constr['operation_hash']} pulse_info={pulse_info}")
                continue

            if pulse_info['wf_func'] is None:
                logger.warning(
                    f"Unable to draw pulse for pulse_info due to missing 'wf_func' for \
                        operation name={operation['name']} \
                        id={t_constr['operation_hash']} pulse_info={pulse_info}")
                continue

            # port to map the waveform too
            port: Optional[str] = pulse_info['port']

        for p in op['pulse_info']:
            # function to generate waveform
            wf_func: Optional[str] = import_func_from_string(pulse_info['wf_func'])

            # iterate through the colors in the color map
            col_idx = (col_idx+1) % len(colors)

            # times at which to evaluate waveform
            t0 = t_constr['abs_time']+p['t0']
            t = np.arange(t0, t0+p['duration'], 1/sampling_rate)

            # function to generate waveform
            if p['wf_func'] is not None:
                wf_func = import_func_from_string(p['wf_func'])
            t0 = t_constr['abs_time'] + pulse_info['t0']
            t = np.arange(t0, t0+pulse_info['duration'], 1/sampling_rate)

            # select the arguments for the waveform function that are present in pulse info
            par_map = inspect.signature(wf_func).parameters
            wf_kwargs = {}
            for kw in par_map.keys():
                    if kw in p.keys():
                        wf_kwargs[kw] = p[kw]
                if kw in pulse_info.keys():
                    wf_kwargs[kw] = pulse_info[kw]

            # Calculate the numerical waveform using the wf_func
            wf = wf_func(t=t, **wf_kwargs)

            # optionally adds some modulation
                if modulation and modulation_if == 0.0 and 'clock' in p.keys():
            if modulation and modulation_if == 0.0 and 'clock' in pulse_info:
                # apply modulation to the waveforms
                    wf = modulate_wave(t, wf, schedule.resources[p['clock']]['freq'])
                wf = modulate_wave(t, wf, schedule.resources[pulse_info['clock']]['freq'])

                if modulation and modulation_if > 0 and 'clock' in p.keys():
            if modulation and modulation_if > 0 and 'clock' in pulse_info:
                # apply modulation to the waveforms
                wf = modulate_wave(t, wf, modulation_if)

                port = p['port']
                # If port_list does not exist yet and using auto map, add it.
                if port not in port_map.keys() and auto_map:
                    port_map[port] = offset_idx
                    offset_idx += 1

                    # once all ports are used, don't add new ports anymore.
                    if offset_idx > nr_rows:
                        auto_map = False

                if port in port_map.keys():
            row: int = port_map[port] + 1
            # FIXME properly deal with complex waveforms.
            for i in range(2):
                showlegend = (i == 0)
                        label = op['name']
                label = operation['name']
                fig.add_trace(go.Scatter(x=t, y=wf.imag, mode='lines', name=label, legendgroup=pls_idx,
                                                 showlegend=showlegend,
                                                 line_color='lightgrey'),
                                      row=port_map[port]+1, col=1)
                                         showlegend=showlegend, line_color='lightgrey'),
                              row=row, col=1)
                fig.add_trace(go.Scatter(x=t, y=wf.real, mode='lines', name=label, legendgroup=pls_idx,
                                                 showlegend=showlegend,
                                                 line_color=colors[col_idx]),
                                      row=port_map[port]+1, col=1)

    for r in range(nr_rows):
        title = ''
        if r+1 == nr_rows:
            title = 'Time'
            fig.update_xaxes(row=r+1, col=1, tickformat=".2s",
                             hoverformat='.3s', ticksuffix='s', title=title,
                             rangeslider=dict(visible=True, thickness=0.05))

        # FIXME: units are hardcoded
        else:
            fig.update_xaxes(row=r+1, col=1, tickformat=".2s",
                             hoverformat='.3s', ticksuffix='s', title=title)
        try:
            fig.update_yaxes(row=r+1, col=1, tickformat=".2s", hoverformat='.3s',
                             ticksuffix='V', title=list(ch_map.keys())[r], range=[-1.1, 1.1])
        except Exception:
            logger.warning("{} not enough channels".format(r))
                                         showlegend=showlegend, line_color=colors[col_idx]),
                              row=row, col=1)

            fig.update_xaxes(row=row, col=1, tickformat=".2s", hoverformat='.3s', ticksuffix='s', showgrid=True)
            fig.update_yaxes(row=row, col=1, tickformat=".2s", hoverformat='.3s', ticksuffix='V', title=port,
                             range=[-1.1, 1.1])

    fig.update_xaxes(row=ports_length, col=1, title='Time',
                     tickformat=".4s",
                     rangeslider_visible=True)

    return fig