Commit 1de1c525 authored by David Hendriks's avatar David Hendriks
Browse files

updated M&S sampling

parent f86d62a3
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -598,9 +598,6 @@ q extrapolation (below 0.15) method
                longname="log10(Orbital_Period)",
                probdist=1.0,
                condition='(self.population_options["multiplicity"] >= 2)',
                branchpoint=1
                if max_multiplicity > 1
                else 0,  # Signal here to put a branchpoint if we have a max multiplicity higher than 1.
                gridtype="centred",
                dphasevol="({} * dlog10per)".format(LOG_LN_CONVERTER),
                valuerange=[
+101 −105
Original line number Diff line number Diff line
@@ -14,6 +14,8 @@ _count = 0 # used for file symlinking (for testing only)
_numba = False  # activate experimental numba code?


###########################
# Some context managers
class indentation_context_manager:
    def __init__(self, population_object, indentation_delta=1):
        self.population_object = population_object
@@ -26,6 +28,18 @@ class indentation_context_manager:
        self.population_object._increment_indent_depth(-self.indentation_delta)


class boxed_context_manager:
    def __init__(self, population_object, width_box=40):
        self.population_object = population_object
        self.width_box = width_box

    def __enter__(self):
        self.population_object._add_code("#" * self.width_box + "\n")

    def __exit__(self, exc, value, exc_traceback):
        self.population_object._add_code("#" * self.width_box + "\n")


class grid_sampling:
    """
    Extension to the population grid object that contains functionality to handle the metadata that will be put in the ensemble
@@ -798,7 +812,6 @@ class grid_sampling:
        ):
            sampling_variable = sampling_variable_el[1]

            self._increment_indent_depth(+1)
            self._add_code(
                "#" * 40 + "\n",
                "# Code below is for finalising the handling of this iteration of the parameter {name}\n\n".format(
@@ -824,12 +837,13 @@ class grid_sampling:
                )
            )

            self._increment_indent_depth(-2)

            if _numba and sampling_variable["dry_parallel"]:
                self._add_code("__parallel_func(phasevol,_total_starcount)\n")
                self._increment_indent_depth(-1)

            # Decrement level
            self._increment_indent_depth(-1)

    def _grid_sampling_write_grid_generator_function_end(self, dry_run):
        """
        This function is responsible for wrapping up the grid code generator. It wraps up the function, writes statements
@@ -876,6 +890,7 @@ class grid_sampling:
            ),
        )

        # Write code
        with self.open(gridcode_filename, "w", encoding="utf-8") as file:
            file.write(self.code_string)

@@ -911,43 +926,22 @@ class grid_sampling:
        Function to handle the branch point
        """

        if loopnr > 0:
            if sampling_variable["branchpoint"]:
                with indentation_context_manager(population_object=self):

                    # self._add_code(
                    #     # Add comment
                    #     "# Condition for branchpoint at {}".format(
                    #         reverse_sorted_sampling_variables[loopnr + 1][1]["name"]
                    #     )
                    #     + "\n",
                    #     "if multiplicity=={}:".format(sampling_variable["branchpoint"])
                    #     + "\n",
                    # )

        # Check if there is a branchpoint and that this is not the deepest loop
        if (sampling_variable["branchpoint"]) and (loopnr > 0):
            ###########################
            # Handle branch point
            if sampling_variable["branchcode"]:
                        self._add_code("#" * 40 + "\n")
                with boxed_context_manager(population_object=self):
                    self._add_code("# Branch code\n")

                    self._add_code(
                        "if {branchcode}:\n".format(
                            branchcode=sampling_variable["branchcode"]
                        )
                    )

                    # if sampling_variable["branchpoint"]:
                    #     self._add_code(
                    #         "# Code below will get evaluated for every system at this level of multiplicity (last one of that being {name})\n".format(
                    #             name=sampling_variable["name"]
                    #         )
                    #     )
                    # else:
                    #     self._add_code(
                    #         "# Code below will get evaluated for every generated system\n"
                    #     )

                    ###########################
                    # Indent and write the grid system call
                    with indentation_context_manager(population_object=self):
                        ###########################
                        # Handle system call
                        self._write_gridcode_system_call(
@@ -956,8 +950,10 @@ class grid_sampling:
                        )

                        self._add_code("\n")
                    self._add_code("#" * 40 + "\n")
            else:
                raise ValueError("Handling branch point but no branchcode provided")

            #
            self._add_code("\n")

    def _write_gridcode_system_call(self, sampling_variable, dry_run):
@@ -967,8 +963,12 @@ class grid_sampling:
        Then if the run is a dry run we implement the dry_run_hook or pass depending on the settings. If it is not a dry run we yield the system dict
        """

        self._add_code("#" * 40 + "\n")
        self._add_code("# grid sampling system call section\n\n")
        with boxed_context_manager(population_object=self):
            self._add_code(
                "# grid sampling system call section ({})\n\n".format(
                    sampling_variable["name"]
                )
            )

            ##################
            # Write the code that handles the probability calculation
@@ -1039,10 +1039,6 @@ class grid_sampling:
                    # or pass
                    self._add_code("pass\n", indent=1)

        #########
        # Wrap up
        self._add_code("#" * 40 + "\n")

    def _load_grid_function(self):
        """
        Function that loads the grid code from file