Commit 87db2c45 authored by Benoit Martin's avatar Benoit Martin

make auto kernel active

parent e045226f
Pipeline #81125020 passed with stages
in 21 minutes and 48 seconds
from enum import Enum
from enum import Enum, auto
from ...core import default
from import AbstractKernel
......@@ -9,8 +9,8 @@ class Type(Enum):
from import TorchKernel
from import KeopsKernel
AUTO = auto()
NO_KERNEL = auto()
TORCH = TorchKernel
KEOPS = KeopsKernel
......@@ -50,9 +50,12 @@ def factory(kernel_type, cuda_type=None, gpu_mode=None, *args, **kwargs):
if not isinstance(kernel_type, Type):
raise TypeError('kernel_type must be an instance of KernelType Enum')
if kernel_type in [Type.UNDEFINED, Type.NO_KERNEL]:
if kernel_type in [Type.NO_KERNEL]:
return None
if kernel_type in [Type.AUTO]:
kernel_type = kernel_selector()
res = None
hash = AbstractKernel.hash(kernel_type, cuda_type, gpu_mode, *args, **kwargs)
if hash not in instance_map:
......@@ -26,13 +26,13 @@ class KernelFactoryTest(unittest.TestCase):
def test_no_kernel_type_from_string(self):
for k in ['no_kernel', 'no-kernel', 'no kernel', 'undefined', 'UNDEFINED']:
for k in ['no_kernel', 'no-kernel', 'no kernel']:
logging.debug("testing kernel= %s" % k)
instance = dfca.kernels.factory(k, kernel_width=1.)
def test_non_cuda_kernel_factory_from_string(self):
for k in ['torch', 'TORCH', 'keops', 'KEOPS']:
for k in ['torch', 'TORCH', 'keops', 'KEOPS', 'auto', 'AUTO']:
logging.debug("testing kernel= %s" % k)
instance = dfca.kernels.factory(k, kernel_width=1.)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment