Skip to content

Commit

Permalink
Fix DICOM issue with JPEG2K decoder (#955)
Browse files Browse the repository at this point in the history
This PR tries to address the issue raised in 948 where
DICOM with JPEG2K is not stable. The reason is that
JPEG2k's decoder takes a different approach with codec registration and cleanup.
And it is not safe to cleanup for jpeg2k until program exit.

This PR fixes 948.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
  • Loading branch information
yongtang authored May 9, 2020
1 parent 824251a commit e6dee86
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 24 deletions.
72 changes: 54 additions & 18 deletions tensorflow_io/core/kernels/image_dicom_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// clang-format off
#include "dcmtk/config/osconfig.h"

#include <dcmtk/dcmdata/dcfilefo.h>
#include "dcmtk/dcmdata/dcfilefo.h"

#include "dcmtk/dcmdata/dcdict.h"
#include "dcmtk/dcmdata/dcistrmb.h"
#include "dcmtk/dcmdata/dctk.h"
Expand All @@ -26,18 +28,18 @@ limitations under the License.

#include "dcmtk/dcmimgle/diutils.h"

#include "dcmtk/dcmimage/dipipng.h" /* for dcmimage PNG plugin */
#include "dcmtk/dcmimage/dipitiff.h" /* for dcmimage TIFF plugin */
#include "dcmtk/dcmjpeg/dipijpeg.h" /* for dcmimage JPEG plugin */
#include "dcmtk/dcmimage/dipipng.h" // for dcmimage PNG plugin
#include "dcmtk/dcmimage/dipitiff.h" // for dcmimage TIFF plugin
#include "dcmtk/dcmjpeg/dipijpeg.h" // for dcmimage JPEG plugin

#include "dcmtk/dcmimage/diregist.h"
#include "dcmtk/dcmimgle/dcmimage.h"

#include "dcmtk/dcmdata/dcrledrg.h" /* for DcmRLEDecoderRegistration */
#include "dcmtk/dcmjpeg/djdecode.h" /* for dcmjpeg decoders */
#include "dcmtk/dcmjpls/djdecode.h" /* for dcmjpls decoders */
#include "dcmtk/dcmdata/dcrledrg.h" // for DcmRLEDecoderRegistration
#include "dcmtk/dcmjpeg/djdecode.h" // for dcmjpeg decoders
#include "dcmtk/dcmjpls/djdecode.h" // for dcmjpls decoders

#include "fmjpeg2k/djdecode.h" / for fmjpeg2koj decoders * /
#include "fmjpeg2k/djdecode.h" // for fmjpeg2koj decoders

#include <cstdint>
#include <exception>
Expand All @@ -46,13 +48,55 @@ limitations under the License.
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/types.h"

// clang-format on

typedef uint64_t
Uint64; // Uint64 not present in tensorflow::custom-op docker image dcmtk

namespace tensorflow {
namespace io {
namespace {

// FMJPEG2K is not safe to cleanup, so use DecoderRegistration
// to provide protection and only cleanup during program exit.
class DecoderRegistration {
public:
static void registerCodecs() { instance().registration(); }
static void cleanup() {}

private:
explicit DecoderRegistration() : initialized_(false) {}
~DecoderRegistration() {
mutex_lock l(mu_);
if (initialized_) {
DcmRLEDecoderRegistration::cleanup(); // deregister RLE codecs
DJDecoderRegistration::cleanup(); // deregister JPEG codecs
DJLSDecoderRegistration::cleanup(); // deregister JPEG-LS codecs
FMJPEG2KDecoderRegistration::cleanup(); // deregister fmjpeg2koj
initialized_ = false;
}
}

void registration() {
mutex_lock l(mu_);
if (!initialized_) {
DcmRLEDecoderRegistration::registerCodecs(); // register RLE codecs
DJDecoderRegistration::registerCodecs(); // register JPEG codecs
DJLSDecoderRegistration::registerCodecs(); // register JPEG-LS codecs
FMJPEG2KDecoderRegistration::registerCodecs(); // register fmjpeg2koj
initialized_ = true;
}
}
static DecoderRegistration &instance() {
static DecoderRegistration decoder_registration;
return decoder_registration;
}

private:
mutex mu_;
bool initialized_ TF_GUARDED_BY(mu_);
};

template <typename dtype>
class DecodeDICOMImageOp : public OpKernel {
public:
Expand All @@ -67,18 +111,10 @@ class DecodeDICOMImageOp : public OpKernel {
// Get the color_dim
OP_REQUIRES_OK(context, context->GetAttr("color_dim", &color_dim_));

DcmRLEDecoderRegistration::registerCodecs(); // register RLE codecs
DJDecoderRegistration::registerCodecs(); // register JPEG codecs
DJLSDecoderRegistration::registerCodecs(); // register JPEG-LS codecs
FMJPEG2KDecoderRegistration::registerCodecs(); // register fmjpeg2koj
DecoderRegistration::registerCodecs();
}

~DecodeDICOMImageOp() {
DcmRLEDecoderRegistration::cleanup(); // deregister RLE codecs
DJDecoderRegistration::cleanup(); // deregister JPEG codecs
DJLSDecoderRegistration::cleanup(); // deregister JPEG-LS codecs
FMJPEG2KDecoderRegistration::cleanup(); // deregister fmjpeg2koj
}
~DecodeDICOMImageOp() { DecoderRegistration::cleanup(); }

void Compute(OpKernelContext *context) override {
// Grab the input file content tensor
Expand Down
57 changes: 51 additions & 6 deletions tests/test_dicom_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


import os
import numpy as np
import pytest

import tensorflow as tf
Expand All @@ -35,8 +36,7 @@


def test_dicom_input():
"""test_dicom_input
"""
"""test_dicom_input"""
_ = tfio.image.decode_dicom_data
_ = tfio.image.decode_dicom_image
_ = tfio.image.dicom_tags
Expand Down Expand Up @@ -70,8 +70,7 @@ def test_dicom_input():
],
)
def test_decode_dicom_image(fname, exp_shape):
"""test_decode_dicom_image
"""
"""test_decode_dicom_image"""

dcm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "test_dicom", fname
Expand Down Expand Up @@ -116,8 +115,7 @@ def test_decode_dicom_image(fname, exp_shape):
],
)
def test_decode_dicom_data(fname, tag, exp_value):
"""test_decode_dicom_data
"""
"""test_decode_dicom_data"""

dcm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "test_dicom", fname
Expand Down Expand Up @@ -145,5 +143,52 @@ def test_dicom_image_shape():
dataset = dataset.map(lambda e: tf.image.resize(e, (224, 224)))


def test_dicom_image_concurrency():
"""test_decode_dicom_image_currency"""

@tf.function
def preprocess(dcm_content):
tags = tfio.image.decode_dicom_data(
dcm_content, tags=[tfio.image.dicom_tags.PatientsName]
)
tf.print(tags)
image = tfio.image.decode_dicom_image(dcm_content, dtype=tf.float32)
return image

dcm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"test_dicom",
"TOSHIBA_J2K_OpenJPEGv2Regression.dcm",
)

dataset = (
tf.data.Dataset.from_tensor_slices([dcm_path])
.repeat()
.map(tf.io.read_file)
.map(preprocess, num_parallel_calls=8)
.take(200)
)
for i, item in enumerate(dataset):
print(tf.shape(item), i)
assert np.array_equal(tf.shape(item), [1, 512, 512, 1])

dcm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"test_dicom",
"US-PAL-8-10x-echo.dcm",
)

dataset = (
tf.data.Dataset.from_tensor_slices([dcm_path])
.repeat()
.map(tf.io.read_file)
.map(preprocess, num_parallel_calls=8)
.take(200)
)
for i, item in enumerate(dataset):
print(tf.shape(item), i)
assert np.array_equal(tf.shape(item), [10, 430, 600, 3])


if __name__ == "__main__":
test.main()

0 comments on commit e6dee86

Please sign in to comment.