Commit 8fb69911 authored by David Hendriks's avatar David Hendriks
Browse files

updating the types of the functions

parent af2bd4e8
Loading
Loading
Loading
Loading
+13 −13
Original line number Diff line number Diff line
@@ -4,21 +4,21 @@ Module containing the predefined distribution functions
The user can use any of these distribution functions to
generate probability distributions for sampling populations

There are distributions for the following properties:
There are distributions for the following parameters:
    - mass
    - period
    - mass ratio
    - binary fraction

TODO: make some things globally present? rob does this in his module..i guess it saves calculations but not sure if im gonna do that now
TODO: make global constants stuff
TODO: add eccentricity distribution: thermal
TODO: Add SFH distributions depending on redshift
TODO: Add metallicity distributions depending on redshift
TODO: Add initial rotational velocity distributions
Tasks:
    - TODO: make some things globally present? rob does this in his module..i guess it saves calculations but not sure if im gonna do that now
    - TODO: make global constants stuff
    - TODO: add eccentricity distribution: thermal
    - TODO: Add SFH distributions depending on redshift
    - TODO: Add metallicity distributions depending on redshift
    - TODO: Add initial rotational velocity distributions
"""


import math
import numpy as np
from typing import Optional, Union
+188 −50
Original line number Diff line number Diff line
@@ -3,6 +3,9 @@ Module containing most of the utility functions for the binarycpython package

Functions here are mostly functions used in other classes/functions, or
useful functions for the user

Tasks:
    - TODO: change all prints to verbose_prints
"""

import json
@@ -11,7 +14,7 @@ import tempfile
import copy
import inspect
import ast

from typing import Union, Any
from collections import defaultdict

import h5py
@@ -25,12 +28,17 @@ from binarycpython import _binary_c_bindings
########################################################


def verbose_print(message, verbosity, minimal_verbosity):
def verbose_print(message: str, verbosity: int, minimal_verbosity: int) -> None:
    """
    Function that decides whether to print a message based on the current verbosity
    and its minimum verbosity

    if verbosity is equal or higher than the minimum, then we print
    
    Args:
        message: message to print
        verbosity: current verbosity level
        minimal_verbosity: threshold verbosity above which to print
    """

    if verbosity >= minimal_verbosity:
@@ -43,6 +51,7 @@ def remove_file(file: str, verbosity: int=0) -> None:

    Args:
        file: full filepath to the file that will be removed.
        verbosity: current verbosity level (Optional)

    Returns:
        the path of a subdirectory called binary_c_python in the TMP of the filesystem 
@@ -81,13 +90,18 @@ def temp_dir()-> str:
    return path


def create_hdf5(data_dir, name):
def create_hdf5(data_dir: str, name: str) -> None:
    """
    Function to create an hdf5 file from the contents of a directory:
     - settings file is selected by checking on files ending on settings
     - data files are selected by checking on files ending with .dat

    TODO: fix missing settingsfiles
    
    Args:
        data_dir: directory containing the data files and settings file
        name: name of hdf5file.

    """

    # Make HDF5:
@@ -150,9 +164,15 @@ def create_hdf5(data_dir, name):
########################################################


def return_binary_c_version_info(parsed=False):
def return_binary_c_version_info(parsed: bool=False) -> Union[str, dict]:
    """
    Function that returns the version information of binary_c
    Function that returns the version information of binary_c. This function calls the function _binary_c_bindings.return_version_info()

    Args:
        parsed: Boolean flag whether to parse the version_info output of binary_c. default = False

    Returns:
        Either the raw string of binary_c or a parsed version of this in the form of a nested dictionary
    """

    version_info = _binary_c_bindings.return_version_info().strip()
@@ -163,9 +183,15 @@ def return_binary_c_version_info(parsed=False):
    return version_info


def parse_binary_c_version_info(version_info_string):
def parse_binary_c_version_info(version_info_string: str) -> dict:
    """
    Function that parses the binary_c version info. Length function with a lot of branches
    Function that parses the binary_c version info. Long function with a lot of branches

    Args:
        version_info_string: raw output of version_info call to binary_c

    Returns:
        Parsed version of the version info, which is a dictionary containing the keys: 'isotopes' for isotope info, 'argpairs' for argument pair info (TODO: explain), 'ensembles' for ensemble settings/info, 'macros' for macros, 'elements' for atomic element info, 'DTlimit' for (TODO: explain), 'nucleosynthesis_sources' for nucleosynthesis sources, and 'miscellaneous' for all those that were not caught by the previous groups. 'git_branch', 'git_build', 'revision' and 'email' are also keys, but its clear what those contain.    
    """

    version_info_dict = {}
@@ -375,27 +401,44 @@ def parse_binary_c_version_info(version_info_string):
########################################################


def output_lines(output):
def output_lines(output: str) -> str:
    """
    Function that outputs the lines that were recieved from the binary_c run.
    Function that outputs the lines that were recieved from the binary_c run, but now as an iterator. 

    Args:
        output: raw binary_c output

    Returns:
        Iterator over the lines of the binary_c output        
    """


    return output.splitlines()


def parse_output(output, selected_header):
def example_parse_output(output: str, selected_header: str) -> dict:
    """
    Function that parses output of binary_c:
    Function that parses output of binary_c. This version serves as an example and is quite detailed. Custom functions can be easier:

    This function works in two cases:
    if the caught line contains output like 'example_header time=12.32 mass=0.94 ..'
    or if the line contains output like 'example_header 12.32 0.94'
    Please dont the two cases.

    You can give a 'selected_header' to catch any line that starts with that.
    Then the values will be put into a dictionary.

    TODO: Think about exporting to numpy array or pandas instead of a defaultdict
    Tasks:
        - TODO: Think about exporting to numpy array or pandas instead of a defaultdict
        - TODO: rethink whether this function is necessary at all
        - TODO: check this function again

    TODO: rethink whether this function is necessary at all
    Args:
        output: binary_c output string
        selected_header: string header of the output (the start of the line that you want to process)

    Returns:
        dictionary containing parameters as keys and lists for the values 
    """

    value_dicts = []
@@ -453,13 +496,17 @@ def parse_output(output, selected_header):
########################################################


def get_defaults(filter_values=False):
def get_defaults(filter_values: bool=False) -> dict:
    """
    Function that calls the binaryc get args function and cast it into a dictionary.

    All the values are strings

    Args:
        filter_values: whether to filter out NULL and Function defaults.

    Returns:
        dictionary containing the parameter name as key and the parameter default as value
    """

    default_output = _binary_c_bindings.return_arglines()
@@ -476,17 +523,28 @@ def get_defaults(filter_values=False):
    return default_dict


def get_arg_keys():
def get_arg_keys() -> list:
    """
    Function that return the list of possible keys to give in the arg string
    Function that return the list of possible keys to give in the arg string. This function calls get_defaults()

    Returns:
        list of all the parameters that binary_c accepts (and has default values for, since we call get_defaults())
    """

    return get_defaults().keys()


def filter_arg_dict(arg_dict):
def filter_arg_dict(arg_dict: dict) -> dict:
    """
    Function to filter out keys that contain values included in ['NULL', 'Function', '']
    
    This function is called by get_defaults()

    Args:
        arg_dict: dictionary containing the argument + default keypairs of binary_c

    Returns:
        filtered dictionary (pairs with NULL and Function values are removed)
    """

    old_dict = arg_dict.copy()
@@ -500,23 +558,33 @@ def filter_arg_dict(arg_dict):
    return new_dict


def create_arg_string(arg_dict, sort=False, filter_values=False):
def create_arg_string(arg_dict: dict, sort: bool=False, filter_values: bool=False) -> str:
    """
    Function that creates the arg string for binary_c.
    Function that creates the arg string for binary_c. Takes a dictionary containing the arguments and writes them to a string
    This string is missing the 'binary_c ' at the start. 

    Args:
        arg_dict: dictionary 
        sort: (optional, default = False) Boolean whether to sort the order of the keys.
        filter_values: (optional, default = False) filters the input dict on keys that have NULL or `function` as value.
        
    Options:
        sort: sort the order of the keys.
        filter_values: filters the input dict on keys that have NULL or `function` as value.
    Returns:
        The string built up by combining all the key + value's. 
    """

    arg_string = ""

    # Whether to filter the arguments
    if filter_values:
        arg_dict = filter_values(arg_dict)

    # 
    keys = sorted(arg_dict.keys()) if sort else arg_dict.keys()

    # 
    for key in keys:
        arg_string += "{key} {value} ".format(key=key, value=arg_dict[key])

    arg_string = arg_string.strip()
    return arg_string

@@ -526,9 +594,11 @@ def create_arg_string(arg_dict, sort=False, filter_values=False):
########################################################


def get_help(param_name="", print_help=True, fail_silently=False):
def get_help(param_name: str="", print_help: bool=True, fail_silently: bool=False) -> Union[dict, None]:
    """
    Function that returns the help info for a given parameter.
    Function that returns the help info for a given parameter, by interfacing with binary_c

    Will check whether it is a valid parameter.

    Binary_c will output things in the following order;
    - Did you mean?
@@ -538,10 +608,16 @@ def get_help(param_name="", print_help=True, fail_silently=False):

    This function reads out that structure and catches the different components of this output

    Will print a dict
    Tasks:
        - TODO: consider not returning None, but return empty dict

    return_dict: wether to return the help info dictionary
    Args: 
        param_name: name of the parameter that you want info from. Will get checked whether its a valid parameter name
        print_help: (optional, default = True) whether to print out the help information
        fail_silently: (optional, default = False) Whether to print the errors raised if the parameter isn't valid

    Returns:
        Dictionary containing the help info. This dictionary contains 'parameter_name', 'parameter_value_input_type', 'description', optionally 'macros' 
    """

    available_arg_keys = get_arg_keys()
@@ -625,13 +701,15 @@ def get_help(param_name="", print_help=True, fail_silently=False):
        return None


def get_help_all(print_help=True):
def get_help_all(print_help: bool=True) -> dict:
    """
    Function that reads out the output of the help_all api call to binary_c
    Function that reads out the output of the return_help_all api call to binary_c. This return_help_all binary_c returns all the information for the parameters, their descriptions and other properties. The output is categorized in sections.

    print_help: bool, prints all the parameters and their descriptions.
    Args:
        print_help: (optional, default = Tru) prints all the parameters and their descriptions.

    return_dict:  returns a dictionary
    Returns:
        returns a dictionary containing dictionaries per section. These dictionaries contain the parameters and descriptions etc for all the parameters in that section
    """

    # Call function
@@ -736,10 +814,17 @@ def get_help_all(print_help=True):
    return help_all_dict


def get_help_super(print_help=False, fail_silently=True):
def get_help_super(print_help: bool=False, fail_silently: bool=True) -> dict:
    """
    Function that first runs get_help_all, and then per argument also run
    the help function to get as much information as possible.
    
    Args: 
        print_help: (optional, default = False) Whether to print the information
        fail_silently: (optional, default = True) Whether to fail silently or to print the errors

    Returns:
        dictionary containing all dictionaries per section, which then contain as much info as possible per parameter.
    """

    # Get help_all information
@@ -796,12 +881,16 @@ def get_help_super(print_help=False, fail_silently=True):
    return help_all_super_dict


def write_binary_c_parameter_descriptions_to_rst_file(output_file):
def write_binary_c_parameter_descriptions_to_rst_file(output_file: str) -> None:
    """
    Function that calls the binary_c api to get the help text/descriptions for all the paramateres available in that build.
    Function that calls the get_help_super() to get the help text/descriptions for all the parameters available in that build.
    Writes the results to a .rst file that can be included in the docs. 

    TODO: add the specific version to this document
    Tasks:
        - TODO: add the specific version git branch, git build, git commit, and binary_c version to this document
    
    Args:
        output_file: name of the output .rst faile containing the ReStructuredText formatted output of all the binary_c parameters.
    """

    # Get the whole arguments dictionary
@@ -842,9 +931,21 @@ def write_binary_c_parameter_descriptions_to_rst_file(output_file):
# logfile functions
########################################################

def load_logfile(logfile):

def load_logfile(logfile: str) -> None:
    """
    Experimental function that parses the generated logfile of binary_c.

    This function is not finished and shouldn't be used yet.

    Tasks:
        - TODO: 
    
    Args:
        - logfile: filename of the logfile you want to parse

    Returns:

    """

    with open(logfile, "r") as file:
@@ -887,42 +988,61 @@ def load_logfile(logfile):
########################################################


def inspect_dict(dict_1, indent=0, print_structure=True):
def inspect_dict(input_dict: dict, indent: int=0, print_structure: bool=True) -> dict:
    """
    Function to inspect a dict.
    Function to (recursively) inspect a (nested) dictionary.
    The object that is returned is a dictionary containing the key of the input_dict, but as value it will return the type of what the value would be in the input_dict

    Works recursively if there is a nested dict.
    In this way we inspect the structure of these dictionaries, rather than the exact contents.

    Args:
        input_dict: dictionary you want to inspect
        print_structure: (optional, default = True) 
        indent: (optional, default = 0) indent of the first output
    
    Prints out keys and their value types
    Returns:
        Dictionary that has the same structure as the input_dict, but as values it has the type(input_dict[key]) (except if the value is a dict)   
    """

    structure_dict = {}

    for key, value in dict_1.items():
    # 
    for key, value in input_dict.items():
        structure_dict[key] = type(value)

        if print_structure:
            print("\t" * indent, key, type(value))

        if isinstance(value, dict):
            structure_dict[key] = inspect_dict(
                value, indent=indent + 1, print_structure=print_structure
            )

    return structure_dict


def merge_dicts(dict_1, dict_2):
def merge_dicts(dict_1: dict, dict_2: dict) -> dict:
    """
    Function to merge two dictionaries.
    Function to merge two dictionaries in a custom way. 

    Behaviour:

    When dict keys are only present in one of either: 
        - we just add the content to the new dict

    When dict keys are present in both, we decide based on the value types how to combine them:
        - dictionaries will be merged by calling recursively calling this function again
        - numbers will be added
        - (opt) lists will be appended
        - In the case that the instances do not match: for now I will raise an error

    - In the case that the instances do now match: for now I will raise an error
    Args:
        dict_1: first dictionary
        dict_2: second dictionary

    Returns:
        Merged dictionary

    When dict keys are only present in one of either, we just add the content to the new dict
    """

    # Set up new dict
@@ -941,21 +1061,26 @@ def merge_dicts(dict_1, dict_2):

    # Add the unique keys to the new dict
    for key in unique_to_dict_1:
        # If these items are ints or floats, then just put them in
        if isinstance(dict_1[key], (float, int)):
            new_dict[key] = dict_1[key]
        # Else, to be safe we should deepcopy them
        else:
            copy_dict = copy.deepcopy(dict_1[key])
            new_dict[key] = copy_dict

    for key in unique_to_dict_2:
        # If these items are ints or floats, then just put them in
        if isinstance(dict_2[key], (float, int)):
            new_dict[key] = dict_2[key]
        # Else, to be safe we should deepcopy them
        else:
            copy_dict = copy.deepcopy(dict_2[key])
            new_dict[key] = copy_dict

    # Go over the common keys:
    for key in overlapping_keys:

        # See whether the types are actually the same
        if not type(dict_1[key]) is type(dict_2[key]):
            print(
@@ -995,6 +1120,7 @@ def merge_dicts(dict_1, dict_2):
                        type(dict_1[key]), type(dict_2[key])
                    )
                )

    #
    return new_dict

@@ -1035,7 +1161,6 @@ class binarycDecoder(json.JSONDecoder):

class BinaryCEncoder(json.JSONEncoder):
    def default(self, o):
        print("inarycoij")
        try:
            str_repr = str(o)
        except TypeError:
@@ -1046,12 +1171,18 @@ class BinaryCEncoder(json.JSONEncoder):
        return JSONEncoder.default(self, o)


def binaryc_json_serializer(obj):
def binaryc_json_serializer(obj: Any) -> Any:
    """
    Custom serializer for binary_c to use when functions are present in the dictionary
    that we want to export.

    Function objects will be turned into str representations of themselves
    
    Args: 
        obj: obj being process 

    Returns: 
        Either string representation of object if the object is a function, or the object itself
    """

    if inspect.isfunction(obj):
@@ -1066,6 +1197,13 @@ def handle_ensemble_string_to_json(raw_output):
    creates a working JSON dictionary out of it.

    Having this wrapper makes it easy to

    Args:
        raw_output: raw output of the ensemble dump by binary_c

    Returns:
        json.loads(raw_output, cls=binarycDecoder)

    """

    # return json.loads(json.dumps(ast.literal_eval(raw_output)), cls=binarycDecoder)
+73 −24
Original line number Diff line number Diff line
"""
Collection of useful functions.

Part of this is copied/inspired by Robs
Rob's binary_stars module

Has functions to convert period to separation and vice versa.
calc_period_from_sep($m1,$m2,$sep) calculate the period given the separation.
calc_sep_from_period($m1,$m2,per) does the inverse.
M1,M2,separation are in solar units, period in days.
rzams($m,$z) gives you the ZAMS radius of a star
ZAMS_collision($m1,$m2,$e,$sep,$z) returns 1 if stars collide on the ZAMS

# TODO: check whether these are correct
Part of this is copied/inspired by Rob's binary_stars module

Functions:
    - calc_period_from_sep(m1, m2, sep) calculate the period given the separation.
    - calc_sep_from_period(m1, m2, per) does the inverse.
    - rzams(m, z) gives you the ZAMS radius of a star
    - ZAMS_collision(m1, m2, e, sep, z) returns 1 if stars collide on the ZAMS
    - roche_lobe(q): returns roche lobe radius in units of separation
    - ragb(m, z): radius at first thermal pulse

Tasks:
    - TODO: check whether these functions are correct
"""

import math
from typing import Union

AURSUN = 2.150445198804013386961742071435e02
YEARDY = 3.651995478818308811241877265275e02


def calc_period_from_sep(M1, M2, sep):
def calc_period_from_sep(M1: Union[int, float], M2: Union[int, float], sep: Union[int, float]) -> Union[int, float]:
    """
    calculate period from separation
    args : M1 (Msol), M2 (Msol), separation (Rsun)

    returns the period (days)
    Args:
        M1: Primary mass in solar mass
        M2: Secondary mass in solar mass
        sep: Separation in solar radii
    
    Returns: 
        period in years
    """

    return YEARDY * (sep / AURSUN) * math.sqrt(sep / (AURSUN * (M1 + M2)))


def calc_sep_from_period(M1, M2, period):
def calc_sep_from_period(M1: Union[int, float], M2: Union[int, float], period: Union[int, float]) -> Union[int, float]:
    """
    inverse of the above function
    args : M1 (Msol), M2 (Msol), period (days)
    Calculate separation from period.

    TODO: check whether this is still correct

    returns the separation (Rsun)
    Args:
        M1: Primary mass in solar mass
        M2: Secondary mass in solar mass
        period: Period of binary in days
    
    Returns:
        Separation in solar radii
    """

    return AURSUN * (period * period * (M1 + M2) / (YEARDY * YEARDY)) ** (1.0 / 3.0)


def roche_lobe(q):
def roche_lobe(q: Union[int, float]) -> Union[int, float]:
    """
    A function to evaluate R_L/a(q), Eggleton 1983.
    
    # TODO: check the definition of the mass ratio
    # TODO: check whether the logs are correct

    Args:
        q: mass ratio of the binary (secondary/primary)

    Returns:
        Roche lobe radius in units of the separation
    """

    p = q ** (1.0 / 3.0)
    return 0.49 * p * p / (0.6 * p * p + math.log(1.0 + p))


def ragb(m, z):
def ragb(m: Union[int, float], z: Union[int, float]) -> Union[int, float]:
    """
    Function to calculate radius of a star at first thermal pulse as a function of mass (z=0.02)
    Function to calculate radius of a star in units of solar radii at first thermal pulse as a function of mass (Z=0.02 only, but also good for Z=0.0001)
    
    Args:
        m: mass of star in units of solar mass
        z: metallicity of star

    Returns: 
        radius at first thermal pulse in units of solar radii
    """
    # Z=0.02 only, but also good for Z=0.001

    return m * 40.0 + 20.0
    # in Rsun


def zams_collission(m1, m2, sep, e, z):
def zams_collission(m1: Union[int, float], m2: Union[int, float], sep: Union[int, float], e: Union[int, float], z: Union[int, float]) -> Union[int, float]:
    """
    given m1,m2, separation and eccentricity (and metallicity)
    determine if two stars collide on the ZAMS

    Args:
        m1: Primary mass in solar mass
        m2: Secondary mass in solar mass
        sep: separation in solar radii
        e: eccentricity
        z: metallicity

    Returns:
        integer boolean whether the binary stars will collide at pericenter
    """

    # calculate periastron distance
@@ -80,6 +120,15 @@ def zams_collission(m1, m2, sep, e, z):
def rzams(m, z):
    """
    Function to determine the radius of a ZAMS star as a function of m and z:
    
    Based on the fits of Tout et al., 1996, MNRAS, 281, 257 

    Args:
        m: mass of star in solar mass
        z: metallicity

    Returns:
        radius of star at ZAMS, in solar radii
    """

    lzs = math.log10(z / 0.02)