Commit 1922a045 authored by Benoit Martin's avatar Benoit Martin

fix pga test

parent 379c21b9
......@@ -9,7 +9,7 @@ import time
import torch
import numpy as np
from core import default
from core import default, GpuMode
from core.estimators.gradient_ascent import GradientAscent
from core.estimators.mcmc_saem import McmcSaem
from core.estimators.scipy_optimize import ScipyOptimize
......@@ -555,8 +555,9 @@ class Deformetrica:
if estimator_options is not None:
if 'gpu_mode' not in estimator_options:
estimator_options['gpu_mode'] = default.gpu_mode
# else:
# default.update_use_cuda(estimator_options['use_cuda'])
if estimator_options['gpu_mode'] is GpuMode.FULL and not torch.cuda.is_available():
logger.warning("GPU computation is not available, falling-back to CPU.")
estimator_options['gpu_mode'] = GpuMode.NONE
if 'state_file' not in estimator_options:
estimator_options['state_file'] = default.state_file
......
......@@ -167,23 +167,20 @@ def get_best_device(gpu_mode=GpuMode.AUTO):
"""
assert gpu_mode is not None
assert isinstance(gpu_mode, GpuMode)
use_cuda = False
if gpu_mode in [GpuMode.AUTO]:
# TODO this should be more clever
use_cuda = True if torch.cuda.is_available() else False
use_cuda = torch.cuda.is_available()
elif gpu_mode in [GpuMode.FULL]:
use_cuda = True
if not torch.cuda.is_available():
# logger.warning("GPU computation is not available, falling-back to CPU.")
use_cuda = False
assert isinstance(use_cuda, bool)
device_id = 0 if use_cuda and torch.cuda.is_available() else -1
device_id = 0 if use_cuda else -1
device = 'cuda:' + str(device_id) if use_cuda and torch.cuda.is_available() else 'cpu'
if use_cuda and torch.cuda.is_available() and mp.current_process().name != 'MainProcess':
if use_cuda and mp.current_process().name != 'MainProcess':
'''
PoolWorker-1 will use cuda:0
PoolWorker-2 will use cuda:1
......
......@@ -718,6 +718,51 @@ class API(unittest.TestCase):
def test_compute_parallel_transport_mesh_3d_alien(self):
self.__test_all(self._test_compute_parallel_transport_mesh_3d_alien)
#
# PGA
#
def _test_estimate_principal_geodesic_analysis_digit(self, dtype, gpu_mode):
BASE_DIR = functional_tests_data_dir + '/principal_geodesic_analysis/digits/'
template_specifications = {
'img': {'deformable_object_type': 'Image',
'filename': BASE_DIR + 'data/digit_2_mean.png',
'noise_std': 0.1,
'noise_variance_prior_normalized_dof': 10,
'noise_variance_prior_scale_std': 1}}
dataset_specifications = {
'dataset_filenames': [
[{'img': BASE_DIR + 'data/digit_2_sample_1.png'}],
[{'img': BASE_DIR + 'data/digit_2_sample_2.png'}],
[{'img': BASE_DIR + 'data/digit_2_sample_3.png'}],
[{'img': BASE_DIR + 'data/digit_2_sample_4.png'}],
[{'img': BASE_DIR + 'data/digit_2_sample_5.png'}],
[{'img': BASE_DIR + 'data/digit_2_sample_6.png'}],
[{'img': BASE_DIR + 'data/digit_2_sample_7.png'}],
[{'img': BASE_DIR + 'data/digit_2_sample_8.png'}],
[{'img': BASE_DIR + 'data/digit_2_sample_9.png'}],
[{'img': BASE_DIR + 'data/digit_2_sample_10.png'}],
[{'img': BASE_DIR + 'data/digit_2_sample_11.png'}],
[{'img': BASE_DIR + 'data/digit_2_sample_12.png'}],
[{'img': BASE_DIR + 'data/digit_2_sample_13.png'}],
[{'img': BASE_DIR + 'data/digit_2_sample_14.png'}],
[{'img': BASE_DIR + 'data/digit_2_sample_15.png'}]
],
'subject_ids': ['target']
}
self.deformetrica.estimate_principal_geodesic_analysis(
template_specifications,
dataset_specifications=dataset_specifications,
estimator_options={'optimization_method_type': 'ScipyLBFGS', 'max_iterations': 2},
model_options={'deformation_kernel_type': 'keops', 'deformation_kernel_width': 3,
'latent_space_dimension': 2,
'dtype': dtype, 'gpu_mode': gpu_mode},
)
@unittest.skip # TODO
def test_estimate_principal_geodesic_analysis_digit(self):
self.__test_all(self._test_estimate_principal_geodesic_analysis_digit)
#
# Shooting
#
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment