Commit 181a4c3d authored by Benoit Martin

working gpu_mode + better api unit tests

parent 3f711336
Pipeline #57210787 failed with stages
in 22 minutes and 6 seconds
import os
import torch
from core import GpuMode
from support import utilities
......
import numpy as np
import torch
import logging
......@@ -87,8 +86,8 @@ class MultiObjectAttachment:
Compute the current distance between source and target, assuming points are the new points of the source
We assume here that the target never moves.
"""
c1, n1, c2, n2 = MultiObjectAttachment.__get_source_and_target_centers_and_normals(points, source, target)
device, _ = utilities.get_best_device(kernel.gpu_mode)
c1, n1, c2, n2 = MultiObjectAttachment.__get_source_and_target_centers_and_normals(points, source, target, device=device)
def current_scalar_product(points_1, points_2, normals_1, normals_2):
assert points_1.device == points_2.device == normals_1.device == normals_2.device, 'tensors must be on the same device'
......@@ -105,8 +104,8 @@ class MultiObjectAttachment:
Compute the point cloud distance between source and target, assuming points are the new points of the source
We assume here that the target never moves.
"""
c1, n1, c2, n2 = MultiObjectAttachment.__get_source_and_target_centers_and_normals(points, source, target)
device, _ = utilities.get_best_device(kernel.gpu_mode)
c1, n1, c2, n2 = MultiObjectAttachment.__get_source_and_target_centers_and_normals(points, source, target, device=device)
def point_cloud_scalar_product(points_1, points_2, normals_1, normals_2):
return torch.dot(normals_1.view(-1),
......@@ -125,8 +124,8 @@ class MultiObjectAttachment:
source and target are SurfaceMesh objects
points are source points (torch tensor)
"""
c1, n1, c2, n2 = MultiObjectAttachment.__get_source_and_target_centers_and_normals(points, source, target)
device, _ = utilities.get_best_device(kernel.gpu_mode)
c1, n1, c2, n2 = MultiObjectAttachment.__get_source_and_target_centers_and_normals(points, source, target, device=device)
# alpha = non-unit (unnormalized) normals
areaa = torch.norm(n1, 2, 1)
......@@ -175,16 +174,19 @@ class MultiObjectAttachment:
####################################################################################################################
@staticmethod
def __get_source_and_target_centers_and_normals(points, source, target):
def __get_source_and_target_centers_and_normals(points, source, target, device=None):
if device is None:
device = points.device
dtype = str(points.dtype)
c1, n1 = source.get_centers_and_normals(points,
tensor_scalar_type=utilities.get_torch_scalar_type(dtype=dtype),
tensor_integer_type=utilities.get_torch_integer_type(dtype=dtype),
device=points.device)
device=device)
c2, n2 = target.get_centers_and_normals(tensor_scalar_type=utilities.get_torch_scalar_type(dtype=dtype),
tensor_integer_type=utilities.get_torch_integer_type(dtype=dtype),
device=points.device)
device=device)
assert c1.device == n1.device == c2.device == n2.device, 'all tensors must be on the same device, c1.device=' + str(c1.device) \
+ ', n1.device=' + str(n1.device)\
......
import time
import warnings
from copy import deepcopy
import support.kernels as kernel_factory
......@@ -40,7 +39,7 @@ class Exponential:
self.kernel = kernel
if shoot_kernel_type is not None:
self.shoot_kernel = kernel_factory.factory(shoot_kernel_type, kernel_width=kernel.kernel_width)
self.shoot_kernel = kernel_factory.factory(shoot_kernel_type, gpu_mode=kernel.gpu_mode, kernel_width=kernel.kernel_width)
else:
self.shoot_kernel = self.kernel
......
import logging
import resource
import time
import warnings
import torch
from core import default
from core.model_tools.deformations.exponential import Exponential
from in_out.array_readers_and_writers import *
import torch.multiprocessing as mp
from support import utilities
import logging
logger = logging.getLogger(__name__)
......@@ -190,9 +189,12 @@ class Geodesic:
# assert times[j-1] <= time
# assert times[j] >= time
weight_left = torch.Tensor([(times[j] - time) / (times[j] - times[j - 1])]).type(self.momenta_t0.type())
weight_right = torch.Tensor([(time - times[j - 1]) / (times[j] - times[j - 1])]).type(self.momenta_t0.type())
template_t = self.get_template_points_trajectory()
device, _ = utilities.get_best_device(self.backward_exponential.kernel.gpu_mode)
weight_left = utilities.move_data([(times[j] - time) / (times[j] - times[j - 1])], device=device, dtype=self.momenta_t0.dtype)
weight_right = utilities.move_data([(time - times[j - 1]) / (times[j] - times[j - 1])], device=device, dtype=self.momenta_t0.dtype)
template_t = {key: [utilities.move_data(v, device=device) for v in value] for key, value in self.get_template_points_trajectory().items()}
deformed_points = {key: weight_left * value[j - 1] + weight_right * value[j]
for key, value in template_t.items()}
......@@ -213,6 +215,8 @@ class Geodesic:
if self.shoot_is_modified or self.flow_is_modified:
device, _ = utilities.get_best_device(self.backward_exponential.kernel.gpu_mode)
# Backward exponential -------------------------------------------------------------------------------------
length = self.t0 - self.tmin
self.backward_exponential.number_of_time_points = \
......@@ -223,6 +227,7 @@ class Geodesic:
if self.flow_is_modified:
self.backward_exponential.set_initial_template_points(self.template_points_t0)
if self.backward_exponential.number_of_time_points > 1:
self.backward_exponential.move_data_to_(device=device)
self.backward_exponential.update()
# Forward exponential --------------------------------------------------------------------------------------
......@@ -235,6 +240,7 @@ class Geodesic:
if self.flow_is_modified:
self.forward_exponential.set_initial_template_points(self.template_points_t0)
if self.forward_exponential.number_of_time_points > 1:
self.forward_exponential.move_data_to_(device=device)
self.forward_exponential.update()
self.shoot_is_modified = False
......
import torch
import support.kernels as kernel_factory
from core import default
from core.model_tools.deformations.exponential import Exponential
from core.model_tools.deformations.geodesic import Geodesic
......@@ -219,6 +218,9 @@ class SpatiotemporalReferenceFrame:
Update the geodesic, and compute the parallel transport of each column of the modulation matrix along
this geodesic, ignoring the tangential components.
"""
device = self.geodesic.control_points_t0.device
# Update the geodesic.
self.geodesic.update()
......@@ -229,11 +231,11 @@ class SpatiotemporalReferenceFrame:
if self.transport_is_modified:
# Projects the modulation_matrix_t0 attribute columns.
self._update_projected_modulation_matrix_t0()
self._update_projected_modulation_matrix_t0(device=device)
# Initializes the projected_modulation_matrix_t attribute size.
self.projected_modulation_matrix_t = \
[torch.zeros(self.projected_modulation_matrix_t0.size()).type(self.modulation_matrix_t0.type())
[torch.zeros(self.projected_modulation_matrix_t0.size(), dtype=self.modulation_matrix_t0.dtype, device=device)
for _ in range(len(self.control_points_t))]
# Transport each column, ignoring the tangential components.
......@@ -253,7 +255,7 @@ class SpatiotemporalReferenceFrame:
# Initializes the extended projected_modulation_matrix_t variable.
projected_modulation_matrix_t_extended = [
torch.zeros(self.projected_modulation_matrix_t0.size()).type(self.modulation_matrix_t0.type())
torch.zeros(self.projected_modulation_matrix_t0.size(), dtype=self.modulation_matrix_t0.dtype, device=device)
for _ in range(len(self.control_points_t))]
# Transport each column, ignoring the tangential components.
......@@ -282,9 +284,10 @@ class SpatiotemporalReferenceFrame:
### Auxiliary methods:
####################################################################################################################
def _update_projected_modulation_matrix_t0(self):
self.projected_modulation_matrix_t0 = \
torch.zeros(self.modulation_matrix_t0.size()).type(self.modulation_matrix_t0.type())
def _update_projected_modulation_matrix_t0(self, device='cpu'):
self.projected_modulation_matrix_t0 = torch.zeros(self.modulation_matrix_t0.size(),
dtype=self.modulation_matrix_t0.dtype,
device=device)
norm_squared = self.geodesic.backward_exponential.scalar_product(
self.geodesic.control_points_t0, self.geodesic.momenta_t0, self.geodesic.momenta_t0)
......
......@@ -28,6 +28,7 @@ class AffineAtlas(AbstractStatisticalModel):
tensor_integer_type=default.tensor_integer_type,
dense_mode=default.dense_mode,
number_of_processes=default.number_of_processes,
gpu_mode=default.gpu_mode,
# dataset,
freeze_translation_vectors=default.freeze_translation_vectors,
......@@ -35,7 +36,7 @@ class AffineAtlas(AbstractStatisticalModel):
freeze_scaling_ratios=default.freeze_scaling_ratios,
**kwargs):
AbstractStatisticalModel.__init__(self, name='AffineAtlas')
AbstractStatisticalModel.__init__(self, name='AffineAtlas', gpu_mode=gpu_mode)
self.dimension = dimension
self.tensor_scalar_type = tensor_scalar_type
......@@ -46,7 +47,7 @@ class AffineAtlas(AbstractStatisticalModel):
# self.dataset = dataset
object_list, self.objects_name, self.objects_name_extension, self.objects_noise_variance, self.multi_object_attachment \
= create_template_metadata(template_specifications, self.dimension)
= create_template_metadata(template_specifications, self.dimension, gpu_mode=gpu_mode)
self.template = DeformableMultiObject(object_list)
......
import logging
import math
import torch
......@@ -18,6 +17,7 @@ from support.probability_distributions.multi_scalar_inverse_wishart_distribution
MultiScalarInverseWishartDistribution
from support.probability_distributions.normal_distribution import NormalDistribution
import logging
logger = logging.getLogger(__name__)
......@@ -40,7 +40,6 @@ class BayesianAtlas(AbstractStatisticalModel):
deformation_kernel_type=default.deformation_kernel_type,
deformation_kernel_width=default.deformation_kernel_width,
deformation_kernel_device=default.deformation_kernel_device,
shoot_kernel_type=default.shoot_kernel_type,
number_of_time_points=default.number_of_time_points,
......@@ -54,9 +53,11 @@ class BayesianAtlas(AbstractStatisticalModel):
freeze_control_points=default.freeze_control_points,
initial_cp_spacing=default.initial_cp_spacing,
gpu_mode=default.gpu_mode,
**kwargs):
AbstractStatisticalModel.__init__(self, name='BayesianAtlas')
AbstractStatisticalModel.__init__(self, name='BayesianAtlas', gpu_mode=gpu_mode)
# Global-like attributes.
self.dimension = dimension
......@@ -82,7 +83,7 @@ class BayesianAtlas(AbstractStatisticalModel):
# Deformation.
self.exponential = Exponential(
dense_mode=dense_mode,
kernel=kernel_factory.factory(deformation_kernel_type, kernel_width=deformation_kernel_width),
kernel=kernel_factory.factory(deformation_kernel_type, gpu_mode=gpu_mode, kernel_width=deformation_kernel_width),
shoot_kernel_type=shoot_kernel_type,
number_of_time_points=number_of_time_points,
use_rk2_for_shoot=use_rk2_for_shoot, use_rk2_for_flow=use_rk2_for_flow)
......@@ -90,7 +91,7 @@ class BayesianAtlas(AbstractStatisticalModel):
# Template.
(object_list, self.objects_name, self.objects_name_extension,
objects_noise_variance, self.multi_object_attachment) = create_template_metadata(
template_specifications, self.dimension)
template_specifications, self.dimension, gpu_mode=gpu_mode)
self.template = DeformableMultiObject(object_list)
# self.template.update()
......@@ -102,7 +103,7 @@ class BayesianAtlas(AbstractStatisticalModel):
self.use_sobolev_gradient = use_sobolev_gradient
self.smoothing_kernel_width = smoothing_kernel_width
if self.use_sobolev_gradient:
self.sobolev_kernel = kernel_factory.factory(deformation_kernel_type, kernel_width=smoothing_kernel_width)
self.sobolev_kernel = kernel_factory.factory(deformation_kernel_type, gpu_mode=gpu_mode, kernel_width=smoothing_kernel_width)
# Template data.
self.fixed_effects['template_data'] = self.template.get_data()
......@@ -373,7 +374,7 @@ class BayesianAtlas(AbstractStatisticalModel):
"""
t_list, t_name, t_name_extension, t_noise_variance, t_multi_object_attachment = \
create_template_metadata(template_specifications)
create_template_metadata(template_specifications, gpu_mode=self.gpu_mode)
self.template.object_list = t_list
self.objects_name = t_name
......@@ -463,6 +464,7 @@ class BayesianAtlas(AbstractStatisticalModel):
"""
Core part of the ComputeLogLikelihood methods. Fully torch.
"""
device, _ = utilities.get_best_device(self.exponential.kernel.gpu_mode)
# Initialize: cross-sectional dataset --------------------------------------------------------------------------
targets = dataset.deformable_objects
......@@ -476,6 +478,7 @@ class BayesianAtlas(AbstractStatisticalModel):
for i, target in enumerate(targets):
self.exponential.set_initial_momenta(momenta[i])
self.exponential.move_data_to_(device=device)
self.exponential.update()
deformed_points = self.exponential.get_template_points()
deformed_data = self.template.get_deformed_data(deformed_points, template_data)
......@@ -553,6 +556,7 @@ class BayesianAtlas(AbstractStatisticalModel):
self._write_model_parameters(individual_RER, output_dir)
def _write_model_predictions(self, dataset, individual_RER, output_dir, compute_residuals=True):
device, _ = utilities.get_best_device(self.exponential.kernel.gpu_mode)
# Initialize.
template_data, template_points, control_points = self._fixed_effects_to_torch_tensors(False)
......@@ -565,6 +569,7 @@ class BayesianAtlas(AbstractStatisticalModel):
residuals = [] # List of torch 1D tensors. Individuals, objects.
for i, subject_id in enumerate(dataset.subject_ids):
self.exponential.set_initial_momenta(momenta[i])
self.exponential.move_data_to_(device=device)
self.exponential.update()
deformed_points = self.exponential.get_template_points()
......
......@@ -139,6 +139,7 @@ class DeterministicAtlas(AbstractStatisticalModel):
self.exponential = Exponential(
dense_mode=dense_mode,
kernel=kernel_factory.factory(deformation_kernel_type,
gpu_mode=gpu_mode,
kernel_width=deformation_kernel_width),
shoot_kernel_type=shoot_kernel_type,
number_of_time_points=number_of_time_points,
......@@ -158,6 +159,7 @@ class DeterministicAtlas(AbstractStatisticalModel):
self.smoothing_kernel_width = smoothing_kernel_width
if self.use_sobolev_gradient:
self.sobolev_kernel = kernel_factory.factory(deformation_kernel_type,
gpu_mode=gpu_mode,
kernel_width=smoothing_kernel_width)
# Template data.
......@@ -347,6 +349,7 @@ class DeterministicAtlas(AbstractStatisticalModel):
exponential.set_initial_template_points(template_points)
exponential.set_initial_control_points(control_points)
exponential.set_initial_momenta(momenta)
exponential.move_data_to_(device=device)
exponential.update()
# Compute attachment and regularity.
......@@ -501,9 +504,10 @@ class DeterministicAtlas(AbstractStatisticalModel):
self._write_model_parameters(output_dir)
def _write_model_predictions(self, dataset, individual_RER, output_dir, compute_residuals=True):
device, _ = utilities.get_best_device(self.gpu_mode)
# Initialize.
template_data, template_points, control_points, momenta = self._fixed_effects_to_torch_tensors(False)
template_data, template_points, control_points, momenta = self._fixed_effects_to_torch_tensors(False, device=device)
# Deform, write reconstructions and compute residuals.
self.exponential.set_initial_template_points(template_points)
......
......@@ -51,9 +51,11 @@ class GeodesicRegression(AbstractStatisticalModel):
initial_momenta=default.initial_momenta,
gpu_mode=default.gpu_mode,
**kwargs):
AbstractStatisticalModel.__init__(self, name='GeodesicRegression')
AbstractStatisticalModel.__init__(self, name='GeodesicRegression', gpu_mode=gpu_mode)
# Global-like attributes.
self.dimension = dimension
......@@ -73,7 +75,7 @@ class GeodesicRegression(AbstractStatisticalModel):
# Deformation.
self.geodesic = Geodesic(
dense_mode=dense_mode,
kernel=kernel_factory.factory(deformation_kernel_type, kernel_width=deformation_kernel_width),
kernel=kernel_factory.factory(deformation_kernel_type, gpu_mode=gpu_mode, kernel_width=deformation_kernel_width),
shoot_kernel_type=shoot_kernel_type,
t0=t0, concentration_of_time_points=concentration_of_time_points,
use_rk2_for_shoot=use_rk2_for_shoot, use_rk2_for_flow=use_rk2_for_flow)
......@@ -81,7 +83,7 @@ class GeodesicRegression(AbstractStatisticalModel):
# Template.
(object_list, self.objects_name, self.objects_name_extension,
self.objects_noise_variance, self.multi_object_attachment) = create_template_metadata(
template_specifications, self.dimension)
template_specifications, self.dimension, gpu_mode=gpu_mode)
self.template = DeformableMultiObject(object_list)
# self.template.update()
......@@ -91,7 +93,7 @@ class GeodesicRegression(AbstractStatisticalModel):
self.use_sobolev_gradient = use_sobolev_gradient
self.smoothing_kernel_width = smoothing_kernel_width
if self.use_sobolev_gradient:
self.sobolev_kernel = kernel_factory.factory(deformation_kernel_type, kernel_width=smoothing_kernel_width)
self.sobolev_kernel = kernel_factory.factory(deformation_kernel_type, gpu_mode=gpu_mode, kernel_width=smoothing_kernel_width)
# Template data.
self.fixed_effects['template_data'] = self.template.get_data()
......
This diff is collapsed.
......@@ -64,9 +64,12 @@ class PrincipalGeodesicAnalysis(AbstractStatisticalModel):
initial_principal_directions=default.initial_principal_directions,
freeze_principal_directions=default.freeze_principal_directions,
freeze_noise_variance=default.freeze_noise_variance,
gpu_mode=default.gpu_mode,
**kwargs):
AbstractStatisticalModel.__init__(self, name='PrincipalGeodesicAnalysis')
AbstractStatisticalModel.__init__(self, name='PrincipalGeodesicAnalysis', gpu_mode=gpu_mode)
self.dimension = dimension
self.tensor_scalar_type = tensor_scalar_type
......@@ -102,7 +105,7 @@ class PrincipalGeodesicAnalysis(AbstractStatisticalModel):
(object_list, self.objects_name, self.objects_name_extension,
objects_noise_variance, self.multi_object_attachment) = create_template_metadata(
template_specifications, self.dimension)
template_specifications, self.dimension, gpu_mode=gpu_mode)
self.template = DeformableMultiObject(object_list)
# self.template.update()
......@@ -112,7 +115,7 @@ class PrincipalGeodesicAnalysis(AbstractStatisticalModel):
self.objects_noise_dimension = compute_noise_dimension(self.template, self.multi_object_attachment,
self.dimension)
self.exponential = Exponential(dense_mode=dense_mode,
kernel=kernel_factory.factory(deformation_kernel_type, kernel_width=deformation_kernel_width),
kernel=kernel_factory.factory(deformation_kernel_type, gpu_mode=gpu_mode, kernel_width=deformation_kernel_width),
shoot_kernel_type=shoot_kernel_type,
number_of_time_points=number_of_time_points,
use_rk2_for_shoot=use_rk2_for_shoot, use_rk2_for_flow=use_rk2_for_flow)
......@@ -120,7 +123,7 @@ class PrincipalGeodesicAnalysis(AbstractStatisticalModel):
self.use_sobolev_gradient = use_sobolev_gradient
self.smoothing_kernel_width = smoothing_kernel_width
if self.use_sobolev_gradient:
self.sobolev_kernel = kernel_factory.factory(deformation_kernel_type, kernel_width=smoothing_kernel_width)
self.sobolev_kernel = kernel_factory.factory(deformation_kernel_type, gpu_mode=gpu_mode, kernel_width=smoothing_kernel_width)
# Template data
self.set_template_data(self.template.get_data())
......
......@@ -155,10 +155,12 @@ def read_and_create_image_dataset(dataset_filenames, visit_ages, subject_ids, te
return longitudinal_dataset
def create_template_metadata(template_specifications, dimension=None):
def create_template_metadata(template_specifications, dimension=None, gpu_mode=None):
"""
Creates a longitudinal dataset object from xml parameters.
"""
if gpu_mode is None:
gpu_mode = default.gpu_mode
objects_list = []
objects_name = []
......@@ -193,6 +195,7 @@ def create_template_metadata(template_specifications, dimension=None):
if object_norm in ['current', 'pointcloud', 'varifold']:
objects_norm_kernels.append(kernel_factory.factory(
object['kernel_type'],
gpu_mode=gpu_mode,
kernel_width=object['kernel_width']))
else:
objects_norm_kernels.append(kernel_factory.factory(kernel_factory.Type.NO_KERNEL))
......
......@@ -79,6 +79,7 @@ def compute_parallel_transport(template_specifications,
The following code block needs to be done on cpu due to the high memory usage of the matrix inversion.
TODO: maybe use Keops Inv ?
"""
velocity = utilities.move_data(velocity, dtype=tensor_scalar_type, device='cpu') # TODO: could this be done on gpu ?
kernel_matrix = utilities.move_data(kernel_matrix, dtype=tensor_scalar_type, device='cpu') # TODO: could this be done on gpu ?
cholesky_kernel_matrix = torch.potrf(kernel_matrix)
......@@ -98,7 +99,7 @@ def compute_parallel_transport(template_specifications,
Second half of the code.
"""
objects_list, objects_name, objects_name_extension, _, _ = create_template_metadata(template_specifications, dimension)
objects_list, objects_name, objects_name_extension, _, _ = create_template_metadata(template_specifications, dimension, gpu_mode=gpu_mode)
template = DeformableMultiObject(objects_list)
template_points = template.get_points()
......
......@@ -29,6 +29,7 @@ def compute_shooting(template_specifications,
number_of_time_points=default.number_of_time_points,
use_rk2_for_shoot=default.use_rk2_for_shoot,
use_rk2_for_flow=default.use_rk2_for_flow,
gpu_mode=default.gpu_mode,
output_dir=default.output_dir, **kwargs
):
logger.info('[ compute_shooting function ]')
......@@ -37,11 +38,11 @@ def compute_shooting(template_specifications,
Create the template object
"""
deformation_kernel = kernel_factory.factory(deformation_kernel_type, kernel_width=deformation_kernel_width)
deformation_kernel = kernel_factory.factory(deformation_kernel_type, gpu_mode=gpu_mode, kernel_width=deformation_kernel_width)
(object_list, t_name, t_name_extension,
t_noise_variance, multi_object_attachment) = create_template_metadata(
template_specifications, dimension)
template_specifications, dimension, gpu_mode=gpu_mode)
template = DeformableMultiObject(object_list)
......
......@@ -17,10 +17,13 @@ class Type(Enum):
instance_map = dict()
def factory(kernel_type, gpu_mode=default.gpu_mode, *args, **kwargs):
def factory(kernel_type, cuda_type=None, gpu_mode=default.gpu_mode, *args, **kwargs):
"""Return an instance of a kernel corresponding to the requested kernel_type"""
assert isinstance(gpu_mode, GpuMode)
if cuda_type is None:
cuda_type = default.dtype
# turn enum string to enum object
if isinstance(kernel_type, str):
try:
......@@ -37,7 +40,7 @@ def factory(kernel_type, gpu_mode=default.gpu_mode, *args, **kwargs):
return None
res = None
hash = AbstractKernel.hash(kernel_type, gpu_mode, *args, **kwargs)
hash = AbstractKernel.hash(kernel_type, cuda_type, gpu_mode, *args, **kwargs)
if hash not in instance_map:
res = kernel_type.value(gpu_mode, *args, **kwargs) # instantiate
instance_map[hash] = res
......
......@@ -21,12 +21,9 @@ class AbstractKernel(ABC):
and self.gpu_mode == other.gpu_mode \
and self.kernel_width == other.kernel_width
def __hash__(self):
return AbstractKernel.hash(self.kernel_type, self.gpu_mode, self.kernel_width)
@staticmethod
def hash(kernel_type, gpu_mode, kernel_width, *args, **kwargs):
return hash((kernel_type, gpu_mode, kernel_width))
def hash(kernel_type, cuda_type, gpu_mode, *args, **kwargs):
return hash((kernel_type, cuda_type, gpu_mode, frozenset(args), frozenset(kwargs.items())))
@abstractmethod
def convolve(self, x, y, p, mode=None):
......
......@@ -63,9 +63,6 @@ class KeopsKernel(AbstractKernel):
def __eq__(self, other):
return AbstractKernel.__eq__(self, other) and self.cuda_type == other.cuda_type
def __hash__(self):
return hash((self.kernel_type, self.gpu_mode, self.kernel_width, self.cuda_type))
def convolve(self, x, y, p, mode='gaussian'):
if mode == 'gaussian':
assert isinstance(x, torch.Tensor), 'x variable must be a torch Tensor'
......@@ -78,7 +75,7 @@ class KeopsKernel(AbstractKernel):
+ ', y.device=' + str(y.device) + ', p.device=' + str(p.device)
d = x.size(1)
gamma = self.gamma.to(x.device)
gamma = self.gamma.to(x.device, dtype=x.dtype)
device_id = x.device.index if x.device.index is not None else -1
res = self.gaussian_convolve[d - 2](gamma, x.contiguous(), y.contiguous(), p.contiguous(), device_id=device_id)
......@@ -95,7 +92,7 @@ class KeopsKernel(AbstractKernel):
+ ', y.device=' + str(y.device) + ', p.device=' + str(p.device)
d = x.size(1)
gamma = self.gamma.to(x.device)
gamma = self.gamma.to(x.device, dtype=x.dtype)
device_id = x.device.index if x.device.index is not None else -1
res = self.point_cloud_convolve[d - 2](gamma, x.contiguous(), y.contiguous(), p.contiguous(), device_id=device_id)
......@@ -119,7 +116,7 @@ class KeopsKernel(AbstractKernel):
x, nx = x
y, ny = y
d = x.size(1)
gamma = self.gamma.to(x.device)
gamma = self.gamma.to(x.device, dtype=x.dtype)
device_id = x.device.index if x.device.index is not None else -1
res = self.varifold_convolve[d - 2](gamma, x.contiguous(), y.contiguous(), nx.contiguous(), ny.contiguous(), p.contiguous(), device_id=device_id)
......@@ -144,7 +141,7 @@ class KeopsKernel(AbstractKernel):
assert px.device == x.device == y.device == py.device, 'tensors must be on the same device'
d = x.size(1)
gamma = self.gamma.to(x.device)
gamma = self.gamma.to(x.device, dtype=x.dtype)
device_id = x.device.index if x.device.index is not None else -1
res = (-2 * gamma * self.gaussian_convolve_gradient_x[d - 2](gamma, x, y, px, py, device_id=device_id))
......
......@@ -2,7 +2,7 @@ import torch
import torch.multiprocessing as mp
import numpy as np
from core import GpuMode, get_best_gpu_mode
from core import GpuMode
from core.observations.deformable_objects.deformable_multi_object import DeformableMultiObject
import logging
......@@ -169,8 +169,15 @@ def get_best_device(gpu_mode=GpuMode.AUTO):
assert isinstance(gpu_mode, GpuMode)
use_cuda = False
if gpu_mode in [GpuMode.FULL, GpuMode.AUTO]:
if gpu_mode in [GpuMode.AUTO]:
# TODO this should be more clever
use_cuda = True if torch.cuda.is_available() else False
elif gpu_mode in [GpuMode.FULL]:
use_cuda = True
if not torch.cuda.is_available():
# logger.warning("GPU computation is not available, falling-back to CPU.")
use_cuda = False
assert isinstance(use_cuda, bool)
device_id = 0 if use_cuda and torch.cuda.is_available() else -1
......
This diff is collapsed.
......@@ -236,6 +236,37 @@ class KeopsKernelTest(KernelTestBase):
self._assert_same_kernels(kernel_instance, deserialized_kernel)
@unittest.skipIf(not torch.cuda.is_available(), 'cuda is not available')
def test_gpu_mode(self):
    """
    Convolve self.x, self.y and self.p with a KeOps kernel for every GpuMode and for both float precisions.

    For each combination the result must live on a cuda device when gpu_mode is GpuMode.FULL, and, once
    moved back to cpu, must match self.expected_convolve_res. GpuMode.AUTO is currently skipped.
    """
    for gpu_mode in GpuMode:
        for cuda_type in ['float32', 'float64']:
            print('gpu_mode: ' + str(gpu_mode) + ', cuda_type: ' + cuda_type)

            if gpu_mode is GpuMode.AUTO:
                continue    # TODO

            # Each combination is reported separately, so one failing pair does not hide the others.
            with self.subTest(gpu_mode=gpu_mode, cuda_type=cuda_type):
                # NOTE(review): assumes default.dtype holds a value that update_dtype accepts - confirm.
                previous_dtype = default.dtype
                try:
                    kernel_instance = kernel_factory.factory(kernel_factory.Type.KEOPS, gpu_mode=gpu_mode,
                                                             kernel_width=1., cuda_type=cuda_type)

                    x, y, p = self.x, self.y, self.p
                    if cuda_type == 'float32':
                        default.update_dtype('float32')
                        x, y, p = self.x.float(), self.y.float(), self.p.float()

                    res = kernel_instance.convolve(x, y, p)
                    if gpu_mode is GpuMode.FULL:
                        self.assertEqual('cuda', res.device.type)
                        res = res.cpu()

                    self.assertEqual('cpu', res.device.type)
                    self._assert_tensor_close(res, self.expected_convolve_res, precision=1e-7)
                finally:
                    # Undo the global dtype switch so it cannot leak into the next combination or into
                    # the tests that run after this one.
                    default.update_dtype(previous_dtype)
class KeopsVersusCuda(unittest.TestCase):
def setUp(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment