Commit bb458015 authored by Benoit Martin's avatar Benoit Martin

fix unit tests

parent 181a4c3d
Pipeline #57234926 passed with stages
in 14 minutes and 4 seconds
......@@ -72,10 +72,14 @@ class Exponential:
# self.cholesky_matrices = {}
def move_data_to_(self, device):
self.initial_control_points = utilities.move_data(self.initial_control_points, device)
self.initial_momenta = utilities.move_data(self.initial_momenta, device)
self.initial_template_points = {key: utilities.move_data(value, device) for key, value in
self.initial_template_points.items()}
if self.initial_control_points is not None:
self.initial_control_points = utilities.move_data(self.initial_control_points, device)
if self.initial_momenta is not None:
self.initial_momenta = utilities.move_data(self.initial_momenta, device)
if self.initial_template_points is not None:
self.initial_template_points = {key: utilities.move_data(value, device) for key, value in
self.initial_template_points.items()}
def light_copy(self):
light_copy = Exponential(self.dense_mode,
......
from enum import Enum
from core import default, GpuMode
from core import default
from support.kernels.abstract_kernel import AbstractKernel
......@@ -17,12 +17,12 @@ class Type(Enum):
instance_map = dict()
def factory(kernel_type, cuda_type=None, gpu_mode=default.gpu_mode, *args, **kwargs):
def factory(kernel_type, cuda_type=None, gpu_mode=None, *args, **kwargs):
"""Return an instance of a kernel corresponding to the requested kernel_type"""
assert isinstance(gpu_mode, GpuMode)
if cuda_type is None:
cuda_type = default.dtype
if gpu_mode is None:
gpu_mode = default.gpu_mode
# turn enum string to enum object
if isinstance(kernel_type, str):
......@@ -42,7 +42,7 @@ def factory(kernel_type, cuda_type=None, gpu_mode=default.gpu_mode, *args, **kwa
res = None
hash = AbstractKernel.hash(kernel_type, cuda_type, gpu_mode, *args, **kwargs)
if hash not in instance_map:
res = kernel_type.value(gpu_mode, *args, **kwargs) # instantiate
res = kernel_type.value(gpu_mode=gpu_mode, cuda_type=cuda_type, *args, **kwargs) # instantiate
instance_map[hash] = res
else:
res = instance_map[hash]
......
......@@ -25,6 +25,9 @@ class AbstractKernel(ABC):
def hash(kernel_type, cuda_type, gpu_mode, *args, **kwargs):
return hash((kernel_type, cuda_type, gpu_mode, frozenset(args), frozenset(kwargs.items())))
def __hash__(self, **kwargs):
return AbstractKernel.hash(self.kernel_type, None, self.gpu_mode, **kwargs)
@abstractmethod
def convolve(self, x, y, p, mode=None):
raise NotImplementedError
......
......@@ -242,9 +242,9 @@ class KeopsKernelTest(KernelTestBase):
for gpu_mode, cuda_type in [(gpu_mode, cuda_type)
for gpu_mode in [gpu_mode for gpu_mode in GpuMode]
for cuda_type in ['float32', 'float64']]:
print('gpu_mode: ' + str(gpu_mode) + ', cuda_type: ' + cuda_type)
if gpu_mode is GpuMode.AUTO:
continue # TODO
print('gpu_mode: ' + str(gpu_mode) + ', cuda_type: ' + cuda_type)
kernel_instance = kernel_factory.factory(kernel_factory.Type.KEOPS, gpu_mode=gpu_mode, kernel_width=1., cuda_type=cuda_type)
......
......@@ -3,7 +3,7 @@ import unittest
import torch
from core import default
from core import default, GpuMode
from core.model_tools.deformations.geodesic import Geodesic
import support.kernels as kernel_factory
from torch.autograd import Variable
......@@ -37,7 +37,7 @@ class ParallelTransportTests(unittest.TestCase):
factor = 5
geodesic = Geodesic(
dense_mode=False,
kernel=kernel_factory.factory('torch', kernel_width=0.01),
kernel=kernel_factory.factory('keops', kernel_width=0.01, gpu_mode=GpuMode.NONE),
t0=0.,
use_rk2_for_shoot=True,
concentration_of_time_points=10 * factor
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment