Verified Commit 69e0aa93 authored by Matthew Burket's avatar Matthew Burket
Browse files

Merge branch '71-no-ip-as-mx' into 'master'

Resolve "Check if MX recrods are DNS"

Closes #71

See merge request dnstats/dnstats!50
parents c25df66a 7bb53df8
Loading
Loading
Loading
Loading
+85 −0
Original line number Diff line number Diff line
import enum

from dnstats.dnsvalidate.util import MaxValue, validate_numbers, is_an_ip
from dnstats.dnsutils import validate_domain

import publicsuffix2


class MXErrors(enum.Enum):
    NO_MX_RECORDS = 0
    BLANK_MX_RECORD = 1
    TOO_MANY_PARTS = 2
    TOO_FEW_PARTS = 3
    PREFERENCE_OUT_OF_RANGE = 4
    INVALID_PREFERENCE = 5
    INVALID_EXCHANGE = 6
    EXCHANGE_IS_AN_IP = 7
    NOT_PUBLIC_DOMAIN = 8
    POSSIBLE_BAD_EXCHANGE = 9


class MxRecord:
    def __init__(self, preference: int, exchange: str):
        self.preference = preference
        self.exchange = exchange


class Mx:
    def __init__(self, mx_records: list):
        self.mx_records = mx_records
        self._validate()

    def _validate(self) -> dict:
        result = dict()
        result['errors'] = []
        result['records'] = []
        if not self.mx_records or len(self.mx_records) == 0:
            result['errors'].append(MXErrors.NO_MX_RECORDS)
            self.errors = result['errors']
            return result

        for record in self.mx_records:
            if not record:
                result['errors'].append(MXErrors.BLANK_MX_RECORD)
                continue
            parts = record.split(' ')

            if len(parts) > 2:
                result['errors'].append(MXErrors.TOO_MANY_PARTS)
                continue

            if len(parts) < 2:
                result['errors'].append(MXErrors.TOO_FEW_PARTS)
                continue

            preference = parts[0]
            exchange = parts[1]

            # Check preference to be an unsigned 16 bit int, RFC 974 (Page 2)
            preference, preference_errors = validate_numbers(preference, MXErrors.INVALID_PREFERENCE,
                                                                  MXErrors.PREFERENCE_OUT_OF_RANGE, MaxValue.USIXTEEN)
            result['errors'].extend(preference_errors)

            if preference_errors:
                continue

            if is_an_ip(exchange):
                result['errors'].append(MXErrors.EXCHANGE_IS_AN_IP)
                continue

            if not validate_domain(exchange):
                result['errors'].append(MXErrors.INVALID_EXCHANGE)
                continue

            if not publicsuffix2.get_tld(exchange) in publicsuffix2.PublicSuffixList().tlds:
                result['errors'].append(MXErrors.NOT_PUBLIC_DOMAIN)
                continue

            if not exchange.endswith('.') and exchange.endswith(publicsuffix2.get_tld(exchange)):
                result['errors'].append(MXErrors.POSSIBLE_BAD_EXCHANGE)

            result['records'].append(MxRecord(preference, exchange))
        self.errors = result['errors']
        self.valid_mx_records = result['records']
        return result
+11 −18
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@ import enum


from dnstats.dnsutils import validate_domain
from dnsvalidate.util import validate_numbers, MaxValue


class SoaErrors(enum.Enum):
@@ -83,33 +84,25 @@ class Soa:
            errors.append(SoaErrors.INVALID_RNAME)
        result['rname'] = rname

        result['serial'], serial_errors = validate_numbers(serial, SoaErrors.INVALID_SERIAL, SoaErrors.SERIAL_NOT_IN_RANGE)
        result['serial'], serial_errors = validate_numbers(serial, SoaErrors.INVALID_SERIAL,
                                                           SoaErrors.SERIAL_NOT_IN_RANGE, MaxValue.UTHIRTY_TW0)
        errors.extend(serial_errors)

        result['refresh'], refresh_errors = validate_numbers(refresh, SoaErrors.INVALID_REFRESH, SoaErrors.REFRESH_NOT_IN_RANGE)
        result['refresh'], refresh_errors = validate_numbers(refresh, SoaErrors.INVALID_REFRESH,
                                                             SoaErrors.REFRESH_NOT_IN_RANGE, MaxValue.UTHIRTY_TW0)
        errors.extend(refresh_errors)

        result['retry'], retry_errors = validate_numbers(retry, SoaErrors.INVALID_RETRY, SoaErrors.RETRY_NOT_IN_RANGE)
        result['retry'], retry_errors = validate_numbers(retry, SoaErrors.INVALID_RETRY,
                                                         SoaErrors.RETRY_NOT_IN_RANGE, MaxValue.UTHIRTY_TW0)
        errors.extend(retry_errors)

        result['expire'], expire_errors = validate_numbers(expire, SoaErrors.INVALID_EXPIRE, SoaErrors.EXPIRE_NOT_IN_RANGE)
        result['expire'], expire_errors = validate_numbers(expire, SoaErrors.INVALID_EXPIRE,
                                                           SoaErrors.EXPIRE_NOT_IN_RANGE, MaxValue.UTHIRTY_TW0)
        errors.extend(expire_errors)

        result['minimum'], minimum_errors = validate_numbers(minimum, SoaErrors.INVALID_MINIMUM, SoaErrors.MINIMUM_NOT_IN_RANGE)
        result['minimum'], minimum_errors = validate_numbers(minimum, SoaErrors.INVALID_MINIMUM,
                                                             SoaErrors.MINIMUM_NOT_IN_RANGE, MaxValue.UTHIRTY_TW0)
        errors.extend(minimum_errors)

        result['errors'] = errors
        return result


def validate_numbers(value: str, invalid_error: SoaErrors, out_of_range_error: SoaErrors) -> (int, list):
    errors = list()
    try:
        value_int = int(value)
    except ValueError:
        errors.append(invalid_error)
        return -1, errors

    if value_int > 4294967295 or value_int < 0:
        errors.append(out_of_range_error)
    return value_int, errors
+118 −0
Original line number Diff line number Diff line
import unittest

from dnstats.dnsvalidate.mx import Mx, MXErrors


class TestMx(unittest.TestCase):
    def test_no_mx(self):
        records = list()
        mx = Mx(records)
        expected_errors = [MXErrors.NO_MX_RECORDS]
        self.assertEqual(expected_errors, mx.errors)

    def test_one_valid_one_not(self):
        records = ['10 mail.dnstats.io.', 'taco dnstats.io.']
        mx = Mx(records)
        expected_errors = [MXErrors.INVALID_PREFERENCE]
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(1, len(mx.valid_mx_records))

    def test_one_not_one_valid(self):
        records = ['taco dnstats.io.', '10 mail.dnstats.io.']
        mx = Mx(records)
        expected_errors = [MXErrors.INVALID_PREFERENCE]
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(1, len(mx.valid_mx_records))

    def test_one_valid(self):
        records = ['10 mail.dnstats.io.']
        mx = Mx(records)
        expected_errors = list()
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(1, len(mx.valid_mx_records))

    def test_two_valid(self):
        records = ['10 mail.dnstats.io.', '20 mail2.dnstats.io.']
        mx = Mx(records)
        expected_errors = list()
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(2, len(mx.valid_mx_records))

    def test_one_invalid_preference(self):
        records = ['taco mail.dnstats.io.']
        mx = Mx(records)
        expected_errors = [MXErrors.INVALID_PREFERENCE]
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(0, len(mx.valid_mx_records))

    def test_one_invalid_exchange(self):
        records = ['10 123']
        mx = Mx(records)
        expected_errors = [MXErrors.INVALID_EXCHANGE]
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(0, len(mx.valid_mx_records))

    def test_two_invalid_exchange(self):
        records = ['10 123', '20 ']
        mx = Mx(records)
        expected_errors = [MXErrors.INVALID_EXCHANGE, MXErrors.INVALID_EXCHANGE]
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(0, len(mx.valid_mx_records))

    def test_just_preference(self):
        records = ['10']
        mx = Mx(records)
        expected_errors = [MXErrors.TOO_FEW_PARTS]
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(0, len(mx.valid_mx_records))

    def test_too_many_parts(self):
        records = ['10 10 10']
        mx = Mx(records)
        expected_errors = [MXErrors.TOO_MANY_PARTS]
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(0, len(mx.valid_mx_records))

    def test_exchange_ip4(self):
        records = ['10 172.104.25.239']
        mx = Mx(records)
        expected_errors = [MXErrors.EXCHANGE_IS_AN_IP]
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(0, len(mx.valid_mx_records))

    def test_exchange_ip4_int(self):
        records = ['10 10.1.1.1']
        mx = Mx(records)
        expected_errors = [MXErrors.EXCHANGE_IS_AN_IP]
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(0, len(mx.valid_mx_records))

    def test_exchange_ip6_int(self):
        records = ['10 2600:3c03::f03c:92ff:feb0:7de']
        mx = Mx(records)
        expected_errors = [MXErrors.EXCHANGE_IS_AN_IP]
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(0, len(mx.valid_mx_records))

    def test_two_invalid_preference(self):
        records = ['taco mail.dnstats.io.', 'taco mail.dnstats.io.']
        mx = Mx(records)
        expected_errors = [MXErrors.INVALID_PREFERENCE, MXErrors.INVALID_PREFERENCE]
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(0, len(mx.valid_mx_records))

    def test_one_invalid_tld(self):
        records = ['10 mail.dnstats.lan.']
        mx = Mx(records)
        expected_errors = [MXErrors.NOT_PUBLIC_DOMAIN]
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(0, len(mx.valid_mx_records))

    def test_one_no_end_dot(self):
        records = ['10 mail.dnstats.io']
        mx = Mx(records)
        expected_errors = [MXErrors.POSSIBLE_BAD_EXCHANGE]
        self.assertEqual(expected_errors, mx.errors)
        self.assertEqual(1, len(mx.valid_mx_records))

+37 −0
Original line number Diff line number Diff line
import enum

import ipaddress


class MaxValue(enum.Enum):
    USIXTEEN = 65535
    UTHIRTY_TW0 = 4294967295


def validate_numbers(value: str, invalid_error: enum.Enum, out_of_range_error: enum.Enum, max_value: MaxValue) -> (int, list):
    """
    Validate the value of unsigned ints
    :param value: Bhe int to check
    :param invalid_error: The error to return if the number is in valid
    :param out_of_range_error: The error to return if the value is out of range
    :param max_value: A member of :class:MaxValue that is max value you want to check
    :return:
    """
    errors = list()
    try:
        value_int = int(value)
    except ValueError:
        errors.append(invalid_error)
        return -1, errors

    if value_int > max_value.value or value_int < 0:
        errors.append(out_of_range_error)
    return value_int, errors


def is_an_ip(value: str) -> bool:
    try:
        ipaddress.ip_address(value)
    except ValueError:
        return False
    return True
 No newline at end of file
+29 −9
Original line number Diff line number Diff line
@@ -70,23 +70,43 @@ def full_raise(grade: Grade) -> Grade:
        return Grade.A_PLUS


def update_count_dict(d: dict, key: str):
    if key in d:
        d[key] += 1
def update_count_dict(dictt: dict, key: str) -> None:
    """
    Given a dict and key add one to the value. Set it to 1 if the key is not found.
    :param dictt: dictt to search
    :param key: key to search for
    :return: None
    """
    if key in dictt:
        dictt[key] += 1
    else:
        d[key] = 1
        dictt[key] = 1


def get_grade(d: dict, k: str, default: int) -> int:
    value = d.get(k)
def get_grade(dictt: dict, key: str, default: int) -> int:
    """
    Given key and dict get the value from the dict. If the
    :param dictt: dict to look up in
    :param key: the key to search for
    :param default: the value to return if key is not found
    :return: value from dictt from key, or default
    """
    value = dictt.get(key)
    if not value:
        return default
    else:
        return value


def not_in_penalty(d: dict, k: str, pen: int):
    if d.__contains__(k):
def not_in_penalty(dictt: dict, key: str, penalty: int):
    """
    If the
    :param dictt: dict to search
    :param key: key to search for 
    :param penalty: the value to return if the value is not found
    :return: 0 if k is in dictt, otherwise 
    """
    if dictt.__contains__(key):
        return 0
    else:
        return pen
        return penalty
Loading