Commit 3c5c30db authored by Funke

publish code

parent 970d2260
================================================================================
Copyright (c) 2019 National Center for Tumor Diseases,
Division of Translational Surgical Oncology
All rights reserved.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
================================================================================
The file resnet.py was originally authored by Kensho Hara
(https://github.com/kenshohara/3D-ResNets-PyTorch) and distributed under the MIT
License:
MIT License
Copyright (c) 2017 Kensho Hara
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================================================
The file metrics.py was originally authored by Colin Lea
(https://github.com/colincsl/TemporalConvolutionalNetworks) and distributed under
the MIT License:
MIT License
Copyright (c) 2016 Colin Lea
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================================================
The following files were originally published by Xiong Yuanjun
(https://github.com/yjxiong/tsn-pytorch):
* models.py (parts)
* transforms.py
These files were originally distributed under the BSD 2-Clause License:
BSD 2-Clause License
Copyright (c) 2017, Multimedia Laboratory, The Chinese University of Hong Kong
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================================================
The file build_of.py is a modified version of the file build_of.py authored by
Xiong Yuanjun (https://github.com/yjxiong/temporal-segment-networks),
distributed under the BSD 2-Clause License:
Copyright (c) 2016, Multimedia Laboratory, The Chinese University of Hong Kong
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Using 3D Convolutional Neural Networks to Learn Spatiotemporal Features for Automatic Surgical Gesture Recognition in Video
PyTorch implementation of video-based surgical gesture recognition using 3D convolutional neural networks.
We propose to use a modified [3D ResNet-18](http://openaccess.thecvf.com/content_ICCV_2017_workshops/papers/w44/Hara_Learning_Spatio-Temporal_Features_ICCV_2017_paper.pdf) to predict dense gesture labels for RGB video input. Details can be found in our [paper](https://arxiv.org/abs/1907.11454).
This implementation is based on open source code published by Kensho Hara in the [3D-ResNets-PyTorch](https://github.com/kenshohara/3D-ResNets-PyTorch) repository.
## Code
### How to start
Simply clone this repository:
```bash
cd <the directory where the repo shall live>
git clone https://gitlab.com/nct_tso_public/surgical_gesture_recognition.git
```
In the following, we use `CODE_DIR` to refer to the absolute path to the code.
Check if you have all required Python packages installed. Our code depends on
> torch torchvision numpy scipy pillow opencv-python datetime sklearn numba matplotlib seaborn pandas
Experiments were run using Python 3.6 (Python 3.5 should also work) and [PyTorch 1.0.0](https://pytorch.org) with CUDA 9.2.
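If you are unsure whether everything is available, a quick import check such as the following can help (a minimal sketch, not part of the repository; note that `pillow` imports as `PIL`, `opencv-python` as `cv2`, and `datetime` is part of the standard library):
```python
# Minimal sanity check for the dependencies listed above.
import importlib

modules = ["torch", "torchvision", "numpy", "scipy", "PIL", "cv2",
           "datetime", "sklearn", "numba", "matplotlib", "seaborn", "pandas"]

for name in modules:
    try:
        mod = importlib.import_module(name)
        print("{:12s} OK (version: {})".format(name, getattr(mod, "__version__", "n/a")))
    except ImportError as err:
        print("{:12s} MISSING ({})".format(name, err))
```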
### Data preparation
Download the JIGSAWS dataset from [here](https://cirl.lcsr.jhu.edu/research/hmm/datasets/jigsaws_release/) and unzip it. You will obtain one folder per surgical task (`Suturing`, `Needle_Passing`, and `Knot_Tying`). We use `DATA_DIR` to refer to the absolute path to the *parent* of these folders.
To extract video frames and to pre-calculate optical flow, we used the code provided by [Limin Wang](http://wanglimin.github.io/), [Yuanjun Xiong](http://yjxiong.me/), and colleagues. You can do the same by executing the following steps:
- Download and build the *dense_flow* code:
Run `git clone --recursive http://github.com/yjxiong/dense_flow` and follow the [install](https://github.com/yjxiong/dense_flow) instructions.
We use `DF_BUILD_DIR` to refer to the absolute path to your *dense_flow* build folder, i.e., the folder containing the binary `extract_gpu` after successful installation.
Note that you will have to install LibZip and OpenCV to compile the code.
- We used OpenCV 3.4.
- OpenCV must be built with CUDA support (`-D WITH_CUDA=ON`) and with extra modules (`-D OPENCV_EXTRA_MODULES_PATH=/<path>/<to>/<your>/<opencv_contrib>/`).
- When building the *dense_flow* code, you can specify a custom location of your OpenCV library by running `OpenCV_DIR=/<path>/<to>/<your>/<opencv_dir>/ cmake ..` (instead of simply `cmake ..`).
- Run the script `extract_frames.sh`:
```bash
cd <CODE_DIR>
bash extract_frames.sh <DF_BUILD_DIR> <DATA_DIR>
```
This will extract frames at 5 fps from all *suturing* videos in the JIGSAWS dataset. The script skips videos ending with `capture1.avi` because we only consider the video of the *right* camera.
Optionally, you can specify the parameters `step_size`, `num_gpu`, and `jobs_per_gpu` as the 3rd, 4th, and 5th command line arguments. Here, `step_size` specifies the temporal resolution at which frames are extracted (namely `<original fps>/<step_size>` fps), `num_gpu >= 1` specifies how many GPUs are used, and `jobs_per_gpu` specifies how many videos are processed in parallel on each GPU. For example, `bash extract_frames.sh <DF_BUILD_DIR> <DATA_DIR> 6 2 8` will extract frames at 30 fps / `<step_size>` = 5 fps using two GPUs and eight workers per GPU. By default, we use `num_gpu = 1`, `jobs_per_gpu = 4`, and `step_size = 6` (see the short helper below for the relation between `step_size` and the resulting frame rate).
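The relation between `step_size` and the resulting frame rate is easy to compute yourself. A small sketch, assuming the 30 fps frame rate of the JIGSAWS videos mentioned above:
```python
# Pick the step_size argument of extract_frames.sh for a desired frame rate,
# assuming the videos are recorded at 30 fps.
ORIGINAL_FPS = 30

def step_size_for(target_fps):
    assert ORIGINAL_FPS % target_fps == 0, "target frame rate must divide 30"
    return ORIGINAL_FPS // target_fps

print(step_size_for(5))   # 6, the default step_size
print(step_size_for(10))  # 3, used for evaluation at 10 Hz (see below)
print(step_size_for(2))   # 15, used for evaluation at 2 Hz (see below)
```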
Finally, the data folder structure will look like this:
```
<DATA_DIR>
Suturing
video
Suturing_B001_capture1.avi
Suturing_B001_capture2.avi
...
transcriptions
Suturing_B001.txt
Suturing_B002.txt
...
(other JIGSAWS specific files and folders)
frames
Suturing_B001_capture2
flow_x_00001.jpg
flow_x_00002.jpg
...
flow_y_00001.jpg
...
img_00001.jpg
...
Suturing_B002_capture2
...
```
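After extraction, you may want to verify that the layout above is in place. The following is an optional sketch (not part of the repository; adjust `DATA_DIR` to your setup):
```python
# Optional check: count extracted RGB and flow frames per video, following the
# folder layout shown above.
import os

DATA_DIR = "/path/to/your/data"  # adjust to your setup
frames_dir = os.path.join(DATA_DIR, "Suturing", "frames")

for video_dir in sorted(os.listdir(frames_dir)):
    full_path = os.path.join(frames_dir, video_dir)
    if not os.path.isdir(full_path):
        continue
    files = os.listdir(full_path)
    n_img = sum(f.startswith("img_") for f in files)
    n_flow_x = sum(f.startswith("flow_x_") for f in files)
    n_flow_y = sum(f.startswith("flow_y_") for f in files)
    print("{}: {} RGB frames, {} + {} flow frames".format(video_dir, n_img, n_flow_x, n_flow_y))
```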
### Train a model
#### Quick start
- To obtain the weights of the [Kinetics](https://deepmind.com/research/open-source/kinetics)-pretrained *3D ResNet-18* model, download the file `resnet-18-kinetics.pth` that is made available [here](https://github.com/kenshohara/3D-ResNets-PyTorch#pre-trained-models) (you need to follow the first link).
- Create the video list files, which are required by our training and evaluation scripts, by running:
```bash
cd <CODE_DIR>
python3 create_video_files.py <DATA_DIR>
```
- Now, the following command will train a model for surgical gesture recognition on the JIGSAWS suturing task, starting from the Kinetics-pretrained *3D ResNet-18*.
```bash
python3 train.py --exp <EXP> --split <SPLIT> --pretrain_path "/<your>/<path>/<to>/resnet-18-kinetics.pth" --data_path "<DATA_DIR>/Suturing/frames" --transcriptions_dir "<DATA_DIR>/Suturing/transcriptions" --out <OUT_DIR>
```
The command line parameter `--split` specifies which LOUO cross-validation fold is left out from the training data.
Results, e.g., model files, will be written to `<OUT_DIR>/<EXP>_<current date>/LOUO/<SPLIT>/<current time>`.
Note that we require you to specify a name `EXP` for the experiment so that you can identify the trained models at a later time.
You can set defaults, e.g., for `--data_path`, `--transcriptions_dir`, and `--out`, in the file `train_opts.py`.
Run `python3 train.py -h` to get a complete list of all command line parameters that can be specified.
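Since one model is trained per LOUO fold, it can be convenient to loop over all splits. A minimal sketch (the split letters match the LOUO splits defined in `create_video_files.py`; the experiment name and paths are placeholders):
```python
# Sketch: train one model per LOUO fold by calling train.py repeatedly.
import subprocess

EXP = "resnet3d_kinetics"                      # placeholder experiment name
DATA_DIR = "/path/to/jigsaws"                  # placeholder paths
OUT_DIR = "/path/to/results"
PRETRAIN = "/path/to/resnet-18-kinetics.pth"

for split in ["B", "C", "D", "E", "F", "G", "H", "I"]:
    subprocess.run([
        "python3", "train.py",
        "--exp", EXP,
        "--split", split,
        "--pretrain_path", PRETRAIN,
        "--data_path", "{}/Suturing/frames".format(DATA_DIR),
        "--transcriptions_dir", "{}/Suturing/transcriptions".format(DATA_DIR),
        "--out", OUT_DIR,
    ], check=True)
```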
#### More experiments
You can repeat the other experiments described in our paper as follows:
- Train a baseline *2D ResNet-18* model:
```bash
python3 train.py --exp <2D_EXP> --split <SPLIT> --arch resnet18 --snippet_length 1 --data_path "<DATA_DIR>/Suturing/frames" --transcriptions_dir "<DATA_DIR>/Suturing/transcriptions" --out <OUT_DIR>
```
- Train a 3D CNN for surgical gesture recognition after bootstrapping weights from a trained *2D ResNet-18* model:
```bash
python3 train.py --exp <EXP> --split <SPLIT> --use_resnet_shortcut_type_B True --bootstrap_from_2D True --pretrain_path "<OUT_DIR>/<2D_EXP>_<date>" --data_path "<DATA_DIR>/Suturing/frames" --transcriptions_dir "<DATA_DIR>/Suturing/transcriptions" --out <OUT_DIR>
```
In this case, the script expects to find the 2D models trained during a previous experiment at `<OUT_DIR>/<2D_EXP>_<date>/LOUO/<SPLIT>/<some timestamp>/`.
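To see which timestamped result folders exist for a previous 2D experiment, a small helper like the following can be used (a hypothetical sketch that only assumes the output layout described above; `OUT_DIR` and the experiment name are placeholders):
```python
# Hypothetical helper: list the timestamped result folders created by train.py
# for a previous experiment, per LOUO split.
import os

OUT_DIR = "/path/to/results"          # placeholder
EXP_2D = "resnet2d_baseline_<date>"   # "<2D_EXP>_<date>", placeholder

louo_dir = os.path.join(OUT_DIR, EXP_2D, "LOUO")
for split in sorted(os.listdir(louo_dir)):
    timestamps = sorted(os.listdir(os.path.join(louo_dir, split)))
    print("split {}: {}".format(split, timestamps))
```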
### Evaluate trained models
After training a model for every cross-validation fold, you can evaluate the experiment.
To test the 3D CNN that was initialized with Kinetics-pretrained weights, you can run:
```bash
python3 test.py --exp <EXP>_<date> --data_path "<DATA_DIR>/Suturing/frames" --transcriptions_dir "<DATA_DIR>/Suturing/transcriptions" --model_dir <OUT_DIR>
```
Here, `date` is the timestamp (current date) generated for the experiment at training time. The script expects to find the trained models at `<OUT_DIR>/<EXP>_<date>/LOUO/<SPLIT>/<some timestamp>/model_<no>.pth.tar`. By default, `no` is set to 249, which is the number of the final models saved after 250 epochs of training. You can evaluate models saved at earlier points during training by setting the command line parameter `--model_no`.
The script computes surgical gesture estimates for every video in the dataset, using the model that has not seen that video at training time. The predictions are compared against the ground-truth labels to compute the evaluation metrics (accuracy, average F1 score, edit score, and segmental F1 score). Results are saved as a Python dictionary at `<OUT_DIR>/Eval/LOUO/5Hz/plain/<EXP>_<date>/<model_no>.pth.tar`.
If the dense gesture predictions are to be aggregated over time to obtain final gesture estimates (*sliding window*), just add `--sliding_window True` to the command. In this case, results will be saved at `<OUT_DIR>/Eval/LOUO/5Hz/window/<EXP>_<date>/<model_no>.pth.tar`.
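The results file can be inspected directly. A minimal sketch, assuming the file was written with `torch.save` (as the `.pth.tar` extension suggests); the path and model number are placeholders, and the exact keys are defined by `test.py`:
```python
# Sketch: load and print an evaluation results file written by test.py.
import torch

results_file = "/path/to/results/Eval/LOUO/5Hz/plain/<EXP>_<date>/249.pth.tar"  # placeholder
results = torch.load(results_file)

for key, value in results.items():
    print(key, ":", value)
```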
The other experiments can be evaluated analogously:
- Evaluate the *2D ResNet-18* baseline:
```bash
python3 test.py --exp <EXP>_<date> --arch resnet18 --snippet_length 1 --data_path "<DATA_DIR>/Suturing/frames" --transcriptions_dir "<DATA_DIR>/Suturing/transcriptions" --model_dir <OUT_DIR>
```
- Evaluate the 3D CNN that was initialized with weights bootstrapped from a 2D model:
```bash
python3 test.py --exp <EXP>_<date> --use_resnet_shortcut_type_B True --data_path "<DATA_DIR>/Suturing/frames" --transcriptions_dir "<DATA_DIR>/Suturing/transcriptions" --model_dir <OUT_DIR> [--sliding_window True]
```
Run `python3 test.py -h` to get a complete list of all command line parameters that can be specified.
#### Evaluation at 2 Hz or 10 Hz
You can opt to extract the snippets considered for evaluation at 2 Hz or 10 Hz (instead of 5 Hz) from the videos. To do this, you need to add `--video_sampling_step 15` (2 Hz) or `--video_sampling_step 3` (10 Hz) to the command. Note that this won't alter the temporal structure within one video snippet.
To extract video snippets at frequencies other than 5 Hz, the script requires access to the video frames extracted at the full temporal resolution of 30 fps. To extract frames at 30 fps, you can run:
```bash
bash extract_frames.sh <DF_BUILD_DIR> <DATA_DIR> 1 <num_gpu> <jobs_per_gpu> "frames_30Hz"
```
Here, the 6th command line argument ensures that the extracted frames are written to `<DATA_DIR>/Suturing/frames_30Hz` (instead of interfering with the previously created `frames` folder).
Now, you can set `--data_path "<DATA_DIR>/Suturing/frames_30Hz"` in the command, for example:
```bash
python3 test.py --exp <EXP>_<date> --data_path "<DATA_DIR>/Suturing/frames_30Hz" --transcriptions_dir "<DATA_DIR>/Suturing/transcriptions" --model_dir <OUT_DIR> --video_sampling_step 3 [--sliding_window True]
```
## How to cite
If you use parts of the code in your own research, please cite:
@InProceedings{funke2019,
author="Funke, Isabel and Bodenstedt, Sebastian and Oehme, Florian and von Bechtolsheim, Felix and Weitz, J{\"u}rgen and Speidel, Stefanie",
title="Using {3D} Convolutional Neural Networks to Learn Spatiotemporal Features for Automatic Surgical Gesture Recognition in Video",
booktitle="Medical Image Computing and Computer Assisted Intervention -- MICCAI 2019",
year="2019",
pages="467--475",
editor="Shen, Dinggang and Liu, Tianming and Peters, Terry M. and Staib, Lawrence H. and Essert, Caroline and Zhou, Sean and Yap, Pew-Thian and Khan, Ali",
series="Lecture Notes in Computer Science",
volume="11768",
publisher="Springer International Publishing",
address="Cham",
doi="10.1007/978-3-030-32254-0\_52"
}
This work was carried out at the National Center for Tumor Diseases (NCT) Dresden, [Department of Translational Surgical Oncology](https://www.nct-dresden.de/tso.html).
# Adapted from https://github.com/yjxiong/temporal-segment-networks/blob/master/tools/build_of.py
__author__ = 'yjxiong'

import os
import glob
import sys
from pipes import quote
from multiprocessing import Pool, current_process
import argparse

out_path = ''


def dump_frames(vid_path):
    import cv2
    video = cv2.VideoCapture(vid_path)
    vid_name = vid_path.split('/')[-1].split('.')[0]
    out_full_path = os.path.join(out_path, vid_name)
    fcount = int(video.get(cv2.CAP_PROP_FRAME_COUNT))  # OpenCV 3 constant (was cv2.cv.CV_CAP_PROP_FRAME_COUNT in OpenCV 2)
    try:
        os.mkdir(out_full_path)
    except OSError:
        pass
    file_list = []
    for i in range(fcount):
        ret, frame = video.read()
        assert ret
        cv2.imwrite('{}/{:06d}.jpg'.format(out_full_path, i), frame)
        access_path = '{}/{:06d}.jpg'.format(vid_name, i)
        file_list.append(access_path)
    print('{} done'.format(vid_name))
    sys.stdout.flush()
    return file_list


def run_optical_flow(vid_item, dev_id=0):
    vid_path = vid_item[0]
    vid_id = vid_item[1]
    vid_name = vid_path.split('/')[-1].split('.')[0]
    out_full_path = os.path.join(out_path, vid_name)
    try:
        os.mkdir(out_full_path)
    except OSError:
        pass
    current = current_process()
    dev_id = (int(current._identity[0]) - 1) % NUM_GPU
    image_path = '{}/img'.format(out_full_path)
    flow_x_path = '{}/flow_x'.format(out_full_path)
    flow_y_path = '{}/flow_y'.format(out_full_path)
    cmd = os.path.join(df_path, 'extract_gpu') + ' -f={} -x={} -y={} -i={} -b=20 -t=1 -d={} -s={} -o={} -w={} -h={}'.format(
        quote(vid_path), quote(flow_x_path), quote(flow_y_path), quote(image_path), dev_id, step, out_format, new_size[0], new_size[1])
    os.system(cmd)
    print('{} {} done'.format(vid_id, vid_name))
    sys.stdout.flush()
    return True


def run_warp_optical_flow(vid_item, dev_id=0):
    vid_path = vid_item[0]
    vid_id = vid_item[1]
    vid_name = vid_path.split('/')[-1].split('.')[0]
    out_full_path = os.path.join(out_path, vid_name)
    try:
        os.mkdir(out_full_path)
    except OSError:
        pass
    current = current_process()
    dev_id = (int(current._identity[0]) - 1) % NUM_GPU
    flow_x_path = '{}/flow_x'.format(out_full_path)
    flow_y_path = '{}/flow_y'.format(out_full_path)
    cmd = os.path.join(df_path, 'extract_warp_gpu') + ' -f={} -x={} -y={} -b=20 -t=1 -d={} -s=1 -o={}'.format(
        vid_path, flow_x_path, flow_y_path, dev_id, out_format)
    os.system(cmd)
    print('warp on {} {} done'.format(vid_id, vid_name))
    sys.stdout.flush()
    return True


def nonintersection(lst1, lst2):
    lst3 = [value for value in lst1 if ((value.split("/")[-1]).split(".")[0]) not in lst2]
    return lst3


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="extract optical flows")
    parser.add_argument("src_dir")
    parser.add_argument("out_dir")
    parser.add_argument("--num_worker", type=int, default=8)
    parser.add_argument("--flow_type", type=str, default='tvl1', choices=['tvl1', 'warp_tvl1'])
    parser.add_argument("--df_path", type=str, default='./lib/dense_flow/build/', help='path to the dense_flow toolbox')
    parser.add_argument("--out_format", type=str, default='dir', choices=['dir', 'zip'],
                        help='output format of the extracted frames')
    parser.add_argument("--ext", type=str, default='avi', choices=['avi', 'mp4'], help='video file extensions')
    parser.add_argument("--new_width", type=int, default=0, help='resize image width')
    parser.add_argument("--new_height", type=int, default=0, help='resize image height')
    parser.add_argument("--num_gpu", type=int, default=8, help='number of GPUs')
    parser.add_argument("--resume", type=str, default='no', choices=['yes', 'no'], help='resume optical flow extraction instead of overwriting')
    parser.add_argument("-s", "--step", type=int, default=1)
    parser.add_argument("--suffix", type=str, default=None, help='only consider videos ending with the specified suffix')

    args = parser.parse_args()

    out_path = args.out_dir
    src_path = args.src_dir
    num_worker = args.num_worker
    flow_type = args.flow_type
    df_path = args.df_path
    out_format = args.out_format
    ext = args.ext
    new_size = (args.new_width, args.new_height)
    NUM_GPU = args.num_gpu
    resume = args.resume
    step = args.step

    if not os.path.isdir(out_path):
        print("creating folder: " + out_path)
        os.makedirs(out_path)
    print("reading videos from folder: ", src_path)
    print("selected extension of videos:", ext)
    vid_list = glob.glob(src_path + '/*.' + ext)
    if args.suffix is not None:
        vid_list = list(filter(lambda x: x.endswith(args.suffix), vid_list))
    print("total number of videos found: ", len(vid_list))

    if resume == 'yes':
        com_vid_list = os.listdir(out_path)
        vid_list = nonintersection(vid_list, com_vid_list)
        print("resuming from video: ", vid_list[0])

    pool = Pool(num_worker)
    if flow_type == 'tvl1':
        pool.map(run_optical_flow, zip(vid_list, range(len(vid_list))))
    elif flow_type == 'warp_tvl1':
        pool.map(run_warp_optical_flow, zip(vid_list, range(len(vid_list))))
import argparse
import os

import cv2


def main(data_dir):
    LOSO_splits = ['1', '2', '3', '4', '5']
    LOUO_splits = ['B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']
    out_dir = "./Splits"

    for task in ["Suturing", "Needle_Passing", "Knot_Tying"]:
        meta_file = os.path.join(data_dir, task, "meta_file_{}.txt".format(task))
        _annotations = [x.strip().split('\t') for x in open(meta_file)]
        annotations = []
        for i in range(len(_annotations)):
            annotation = []
            for elem in _annotations[i]:
                if elem:
                    annotation.append(elem)
            if annotation:
                annotations.append(annotation)
        trials = [row[0].split('_')[-1] for row in annotations]

        video_frame_counts = {}
        for trial in trials:
            video_file = os.path.join(data_dir, task, "video", "{}_{}_capture2.avi".format(task, trial))
            cap = cv2.VideoCapture(video_file)
            video_frame_counts[trial] = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            cap.release()

        for splits in [LOSO_splits, LOUO_splits]:
            eval_scheme = "LOSO" if len(splits) == len(LOSO_splits) else "LOUO"
            for i in range(len(splits)):
                split = [trial for trial in trials if splits[i] in trial]
                if not os.path.exists(os.path.join(out_dir, task, eval_scheme)):
                    os.makedirs(os.path.join(out_dir, task, eval_scheme))
                split_file = open(os.path.join(out_dir, task, eval_scheme, "data_{}.csv".format(splits[i])), mode='w')
                for trial in sorted(split):
                    row = ["{}_{}".format(task, trial)]
                    row.append(video_frame_counts[trial])
                    row = [str(elem) for elem in row]
                    split_file.write(','.join(row) + os.linesep)
                split_file.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Create video list files.")
    parser.add_argument('data_dir', type=str, help="Path to data folder, which contains the extracted images "
                                                   "for each video. One subfolder per video.")
    args = parser.parse_args()

    main(args.data_dir)
# Copyright (C) 2019 National Center of Tumor Diseases (NCT) Dresden, Division of Translational Surgical Oncology

import torch
import torch.utils.data as data
import torchvision
from transforms import Stack, ToTorchFormatTensor
from PIL import Image
import os
import numpy as np
from numpy.random import randint
import math


class GestureRecord(object):
    def __init__(self, g_segments, snippet_length=16, min_overlap=1):
        self.start_frames = []
        self.accumulated_snippet_counts = [0]
        self.num_unique_snippets = 0
        _accumulated_snippet_count = 0
        for s in g_segments:
            if s[1] - s[0] + 1 >= min_overlap:  # at least one complete snippet in segment
                start = s[0] - snippet_length + min_overlap
                end = s[1] - snippet_length + 1
                self.start_frames.append(start)
                _accumulated_snippet_count += end - start + 1
                self.accumulated_snippet_counts.append(_accumulated_snippet_count)
                self.num_unique_snippets += (s[1] - start) // snippet_length

    def sample_idx(self):
        idx = randint(self.snippet_count())
        # transform to video-level frame no.
        i = 0
        while not (idx < self.accumulated_snippet_counts[i + 1]):
            i += 1
        idx = idx - self.accumulated_snippet_counts[i] + self.start_frames[i]
        return idx

    def snippet_count(self):
        return self.accumulated_snippet_counts[-1]


class GestureDataSet(data.Dataset):
    def __init__(self, root_path, list_of_list_files, transcriptions_dir, gesture_ids,
                 snippet_length=16, min_overlap=1, video_sampling_step=6,
                 modality='RGB', image_tmpl='img_{:05d}.jpg', video_suffix="_capture2",
                 return_3D_tensor=True, return_dense_labels=True,
                 transform=None, normalize=None, load_to_RAM=True):
        self.root_path = root_path
        self.list_of_list_files = list_of_list_files
        self.transcriptions_dir = transcriptions_dir
        self.gesture_ids = gesture_ids
        self.snippet_length = snippet_length
        self.min_overlap = min_overlap
        self.video_sampling_step = video_sampling_step
        self.modality = modality
        self.image_tmpl = image_tmpl
        self.video_suffix = video_suffix
        self.return_3D_tensor = return_3D_tensor
        self.return_dense_labels = return_dense_labels
        self.transform = transform
        self.normalize = normalize
        self.load_to_RAM = load_to_RAM

        self.gesture_dict = {}  # for each gesture, save which segments of which video belong to that gesture
        self.min_g_count = 0  # for each gesture, there are at least <min_g_count> non-overlapping snippets in the dataset
        self.gesture_sequence_per_video = {}
        self.image_data = {}

        self._parse_list_files(list_of_list_files)

    def _parse_list_files(self, list_of_list_files):
        for list_file in list_of_list_files:
            videos = [(x.strip().split(',')[0], x.strip().split(',')[1]) for x in open(list_file)]
            for video in videos:
                video_id = video[0]
                frame_count = int(video[1])
                gestures_file = os.path.join(self.transcriptions_dir, video_id + ".txt")
                gestures = [[int(x.strip().split(' ')[0]), int(x.strip().split(' ')[1]), x.strip().split(' ')[2]]
                            for x in open(gestures_file)]
                # [start_frame, end_frame, gesture_id]
                """
                for i in range(len(gestures)):
                    if i + 1 < len(gestures):
                        assert (gestures[i][1] == gestures[i + 1][0] - 1)
                    else:
                        assert (gestures[i][1] < frame_count)
                """
                # adjust indices to temporal downsampling (specified by "video_sampling_step")
                _frame_count = frame_count // self.video_sampling_step
                _last_rgb_frame = os.path.join(self.root_path, video_id + self.video_suffix,
                                               'img_{:05d}.jpg'.format(_frame_count))
                if not os.path.isfile(_last_rgb_frame):
                    _frame_count = _frame_count - 1