utils.py 5.93 KB
Newer Older
Luna's avatar
Luna committed
1 2 3
"""

Litecord
Copyright (C) 2018-2019  Luna Mendes

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""

20
import asyncio
21
import json
22
from typing import Any, Iterable, Optional, Sequence, List, Dict
Luna's avatar
Luna committed
23

24
from logbook import Logger
25
from quart.json import JSONEncoder
26
from quart import current_app as app
27 28 29 30

# logbook logger for this module; task_wrapper reports background task errors through it.
log = Logger(__name__)


Luna's avatar
Luna committed
31
async def async_map(function, iterable: Iterable) -> list:
    """Await ``function(element)`` for every element of ``iterable``.

    The calls are awaited one after another, in iteration order (no
    concurrency), and their results are collected into a list in that
    same order.
    """
    return [await function(element) for element in iterable]


async def task_wrapper(name: str, coro):
    """Await ``coro``, logging its errors instead of propagating them.

    - Cancellation (``asyncio.CancelledError``) is swallowed silently.
    - Any other ``Exception`` is logged with its traceback, tagged with
      ``name``, and swallowed.
    - ``BaseException`` subclasses that are not ``Exception`` (for example
      ``KeyboardInterrupt`` and ``SystemExit``) propagate to the caller
      rather than being logged as task errors.

    :param name: label included in the error log message.
    :param coro: the awaitable to run.
    """
    try:
        await coro
    except asyncio.CancelledError:
        # Being cancelled is treated as a normal way for the task to end.
        pass
    except Exception:
        # A bare ``except:`` would also trap KeyboardInterrupt/SystemExit
        # and hide interpreter shutdown requests behind a log line.
        log.exception('{} task error', name)


def dict_get(mapping, key, default):
    """Return ``mapping[key]``, falling back to ``default`` when it is
    missing or ``None``.

    Unlike ``mapping.get(key, default)``, an explicit ``None`` value also
    yields ``default``. Other falsy values such as ``0``, ``''`` or
    ``False`` are real values and are returned unchanged.
    """
    value = mapping.get(key)
    # `value or default` would also discard 0 / '' / False.
    return default if value is None else value


Luna's avatar
Luna committed
57
def index_by_func(function, indexable: Sequence[Any]) -> Optional[int]:
    """Return the index of the first item for which ``function(item)`` is truthy.

    Items are tested in order and the search stops at the first match.
    Returns ``None`` when nothing matches, including for an empty sequence.
    """
    matching_indexes = (
        position
        for position, candidate in enumerate(indexable)
        if function(candidate)
    )
    return next(matching_indexes, None)


def _u(val):
    """convert to unsigned."""
    return val % 0x100000000


Luna's avatar
Luna committed
72
def _mmh3_scramble(k1: int) -> int:
    """Apply MurmurHash3's per-block mixing step to a 32-bit word."""
    k1 = (k1 * 0xcc9e2d51) & 0xffffffff
    k1 = ((k1 << 15) | (k1 >> 17)) & 0xffffffff  # rotl32(k1, 15)
    return (k1 * 0x1b873593) & 0xffffffff


def mmh3(inp_str: str, seed: int = 0) -> int:
    """MurmurHash3 (x86, 32-bit) of ``inp_str``, as computed by Discord's JS.

    Port of
      https://github.com/garycourt/murmurhash-js/blob/master/murmurhash3_gc.js

    That implementation hashes one byte per UTF-16 code unit: it reads
    ``key.charCodeAt(i) & 0xff`` and uses ``key.length`` (a count of UTF-16
    code units). For ASCII/Latin-1 input this equals the standard
    MurmurHash3_x86_32 of those bytes.

    :param inp_str: the string to hash.
    :param seed: initial hash state; reduced modulo 2**32.
    :returns: the hash as an unsigned 32-bit integer.
    """
    mask = 0xffffffff

    # JS strings are UTF-16, so a character outside the BMP counts as two
    # code units (a surrogate pair). Encoding to UTF-16-LE and keeping the
    # even-indexed bytes gives exactly the low byte of every code unit.
    # 'surrogatepass' lets lone surrogates through, as JS would.
    key = inp_str.encode('utf-16-le', 'surrogatepass')[::2]
    length = len(key)
    block_end = length - (length & 3)

    h1 = seed & mask

    # Body: consume the input as little-endian 4-byte blocks.
    for i in range(0, block_end, 4):
        h1 ^= _mmh3_scramble(int.from_bytes(key[i:i + 4], 'little'))
        h1 = ((h1 << 13) | (h1 >> 19)) & mask  # rotl32(h1, 13)
        h1 = (h1 * 5 + 0xe6546b64) & mask

    # Tail: the remaining 1-3 bytes, combined little-endian. The JS source
    # uses a fall-through `switch`, so every remaining byte contributes,
    # not just the last one.
    tail = key[block_end:]
    if tail:
        h1 ^= _mmh3_scramble(int.from_bytes(tail, 'little'))

    # Finalization mix (fmix32).
    h1 ^= length
    h1 ^= h1 >> 16
    h1 = (h1 * 0x85ebca6b) & mask
    h1 ^= h1 >> 13
    h1 = (h1 * 0xc2b2ae35) & mask
    h1 ^= h1 >> 16

    return h1


140
class LitecordJSONEncoder(JSONEncoder):
    """JSON encoder that serializes Litecord objects via ``to_json``."""

    def default(self, value: Any):
        """Serialize ``value`` using its ``to_json`` attribute.

        ``to_json`` is read, not called. Presumably it is a property on the
        classes that define it; confirm against those classes. Values
        without it are passed to the parent encoder's ``default``.
        """
        try:
            serialized = value.to_json
        except AttributeError:
            return super().default(value)

        return serialized


151
async def pg_set_json(con):
    """Register JSON codecs on an asyncpg connection.

    Both the ``json`` and ``jsonb`` types in the ``pg_catalog`` schema are
    encoded with ``LitecordJSONEncoder`` and decoded with ``json.loads``.
    """
    def encode(data):
        return json.dumps(data, cls=LitecordJSONEncoder)

    for type_name in ('json', 'jsonb'):
        await con.set_type_codec(
            type_name,
            encoder=encode,
            decoder=json.loads,
            schema='pg_catalog',
        )


Luna's avatar
Luna committed
168
def yield_chunks(input_list: Sequence[Any], chunk_size: int):
    """Lazily yield consecutive slices of ``input_list`` of ``chunk_size`` items.

    The final slice is shorter when ``len(input_list)`` is not a multiple of
    ``chunk_size``. Each chunk is a slice, so it has the same type as
    ``input_list``. Adapted from https://stackoverflow.com/a/312464.
    """
    chunk_starts = range(0, len(input_list), chunk_size)
    yield from (input_list[start:start + chunk_size] for start in chunk_starts)

def to_update(j: dict, orig: dict, field: str) -> bool:
    """Check whether ``j[field]`` would actually change ``orig[field]``.

    Returns ``True`` only when ``field`` is present in ``j`` with a truthy
    value that differs from ``orig[field]``. Useful for icon checks.
    ``orig[field]`` is only read when ``j`` carries a truthy value.

    Always returns a real ``bool``. Previously a present-but-falsy value
    (``None``, ``''``) was returned as-is, despite the ``-> bool``
    annotation.
    """
    new_value = j.get(field)
    return bool(new_value) and new_value != orig[field]


async def search_result_from_list(rows: List) -> Dict[str, Any]:
    """Build the search response body from a list of query result rows.

    Each row must contain:
     - ``current_id``: a bigint, the id of the matching message;
     - ``total_results``: an int (?), the hit count, read from the first row;
     - ``before`` and ``after``: bigint[] of surrounding message ids.

    Each row becomes one list of messages: the ``before`` messages (in
    reverse of their stored order), then the hit itself marked with
    ``'hit': True``, then the ``after`` messages. An empty ``rows`` gives
    ``total_results`` 0 and no messages.
    """
    async def fetch(message_id):
        return await app.storage.get_message(message_id)

    total = rows[0]['total_results'] if rows else 0
    groups = []

    for row in rows:
        context_before = [await fetch(mid) for mid in reversed(row['before'])]
        context_after = [await fetch(mid) for mid in row['after']]

        hit = await fetch(row['current_id'])
        hit['hit'] = True

        groups.append([*context_before, hit, *context_after])

    return {
        'total_results': total,
        'messages': groups,
        'analytics_id': '',
    }