Machine learning to segment neutron images
Anders Kaestner, Beamline scientist - Neutron Imaging
Laboratory for Neutron Scattering and Imaging
Paul Scherrer Institut
If you want to run the notebook on your own computer, you'll need to perform the following steps:
git clone https://github.com/ImagingLectures/MLSegmentation4NI.git
conda env create -f environment.yml -n MLSeg4NI
Enter the environment
conda activate MLSeg4NI
Start jupyter and open the notebook lecture/ML4NeutronImageSegmentation.ipynb
Use the notebook
Leave the environment
conda deactivate
This lecture needs some modules to run. We import all of them here.
import matplotlib.pyplot as plt
import seaborn as sn
import numpy as np
import pandas as pd
import skimage.filters as flt
import skimage.io as io
import matplotlib as mpl
from sklearn.cluster import KMeans
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import confusion_matrix
from sklearn.datasets import make_blobs
from matplotlib.colors import ListedColormap
from matplotlib.patches import Ellipse
from lecturesupport import plotsupport as ps
import scipy.stats as stats
import astropy.io.fits as fits
import keras.metrics as metrics
import keras.losses as loss
import keras.optimizers as opt
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'png')
#plt.style.use('seaborn')
mpl.rcParams['figure.dpi'] = 300
import importlib
importlib.reload(ps);
Introduction to neutron imaging
Introduction to segmentation
Problematic segmentation tasks
A very abstract definition:
In most cases this is a two- or three-dimensional position (x, y, z coordinates) and a numeric value (intensity).
Images are great for qualitative analyses since our brains can quickly interpret them without large programming investments.
The transmitted radiation is described by Beer-Lambert's law which in its basic form looks like
$$I=I_0\cdot{}e^{-\int_L \mu(x)\,dx}$$

| Fundamental information | Additional dimensions | Derived information |
|---|---|---|
| 2D Radiography | Time series | q-values |
| 3D Tomography | Spectra | strain |
| | | Crystal orientation |
| Transmission through sample | X-ray attenuation | Neutron attenuation |
|---|---|---|
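To make Beer-Lambert's law above concrete, here is a minimal numerical sketch (the attenuation coefficient and thicknesses are made-up values, and the sample is assumed homogeneous so the path integral reduces to mu*L):

import numpy as np

I0 = 1000.0                    # incident intensity (arbitrary units)
mu = 0.25                      # attenuation coefficient [1/cm], assumed constant
L  = np.linspace(0, 10, 6)     # sample thicknesses [cm]

I = I0 * np.exp(-mu * L)       # Beer-Lambert for a homogeneous sample
for thickness, intensity in zip(L, I):
    print(f"L = {thickness:4.1f} cm -> I = {intensity:7.1f}")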
Start out with a simple image of a cross with added noise
$$ I(x,y) = f(x,y) $$

fig, ax = plt.subplots(1, 2, figsize=(12, 6))
nx = 5; ny = 5;
# Create the test image
xx, yy = np.meshgrid(np.arange(-nx, nx+1)/nx*2*np.pi, np.arange(-ny, ny+1)/ny*2*np.pi)
cross_im = 1.5*np.abs(np.cos(xx*yy))/(np.abs(xx*yy)+(3*np.pi/nx)) + np.random.uniform(-0.25, 0.25, size = xx.shape)
# Show it
im=ax[0].imshow(cross_im, cmap = 'hot'); ax[0].set_title("Image")
ax[1].hist(cross_im.ravel(),bins=10); ax[1].set_xlabel('Gray value'); ax[1].set_ylabel('Counts'); ax[1].set_title("Histogram");
Applying the threshold is a deceptively simple operation
$$ I(x,y) = \begin{cases} 1, & f(x,y)\geq 0.40 \\ 0, & f(x,y)<0.40 \end{cases}$$

threshold = 0.4; thresh_img = cross_im > threshold
fig,ax = plt.subplots(1,2,figsize=(12,6))
ax[0].imshow(cross_im, cmap = 'hot', extent = [xx.min(), xx.max(), yy.min(), yy.max()]); ax[0].set_title("Image")
ax[0].plot(xx[np.where(thresh_img)]*0.9, yy[np.where(thresh_img)]*0.9,
'ks', markerfacecolor = 'green', alpha = 0.5,label = 'Threshold', markersize = 22); ax[0].legend(fontsize=12);
ax[1].hist(cross_im.ravel(),bins=10); ax[1].axvline(x=threshold,color='r',label='Threshold'); ax[1].legend(fontsize=12);
ax[1].set_xlabel('Gray value'); ax[1].set_ylabel('Counts'); ax[1].set_title("Histogram");
The noise in neutron imaging mainly originates from the limited number of captured neutrons.
This noise is Poisson distributed and the signal to noise ratio is
$$SNR=\frac{E[x]}{s[x]}\sim\frac{N}{\sqrt{N}}=\sqrt{N}$$

*Woodland Encounter*, Bev Doolittle
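A quick simulation (synthetic counts, not lecture data) illustrating the square-root behaviour of the Poisson SNR above:

import numpy as np

rng = np.random.default_rng(2018)
for N in [10, 100, 1000, 10000]:
    counts = rng.poisson(lam=N, size=100000)  # simulated neutron counts per pixel
    snr = counts.mean() / counts.std()        # E[x]/s[x]
    print(f"N = {N:6d}: measured SNR = {snr:6.1f}, sqrt(N) = {np.sqrt(N):6.1f}")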
In neutron imaging you see all these image phenomena.
Different types of limited data:
Data augmentation is a method to modify your existing data to obtain variations of it.
Augmentation will be used to increase the training data in the root segmentation example at the end of this lecture.
Both augmented and simulated data should be combined with real data.
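A minimal sketch of geometric augmentation with plain numpy (flips and 90-degree rotations of an image/mask pair); the root segmentation example later may use a richer set of transformations:

import numpy as np

def augment_pair(img, mask):
    """Return simple geometric variations of an image/mask pair."""
    pairs = [(img, mask)]
    pairs.append((np.fliplr(img), np.fliplr(mask)))   # horizontal flip
    pairs.append((np.flipud(img), np.flipud(mask)))   # vertical flip
    for k in (1, 2, 3):                               # 90, 180 and 270 degree rotations
        pairs.append((np.rot90(img, k), np.rot90(mask, k)))
    return pairs

demo_img  = np.random.uniform(size=(64, 64))
demo_mask = 0.5 < demo_img
print(len(augment_pair(demo_img, demo_mask)), "variations of the pair")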
Transfer learning is a technique that uses a pre-trained network as a starting point and adapts it to a new task with limited data.
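A hedged Keras sketch of the idea (the tiny stand-in model below is not the lecture's U-Net): freeze the feature-extracting layers of a previously trained network and only train a new output head on the small new dataset.

from keras.models import Model
from keras.layers import Input, Conv2D

inp  = Input((None, None, 1))
feat = Conv2D(8, (3, 3), padding='same', activation='relu', name='features')(inp)
out  = Conv2D(1, (1, 1), padding='same', activation='sigmoid', name='old_head')(feat)
pretrained = Model(inp, out)            # stands in for a network trained on another task

pretrained.get_layer('features').trainable = False   # keep the learned features fixed
new_out = Conv2D(1, (1, 1), padding='same', activation='sigmoid', name='new_head')(
    pretrained.get_layer('features').output)
finetune_model = Model(pretrained.input, new_out)
finetune_model.compile(optimizer='adam', loss='binary_crossentropy')
# finetune_model.fit(new_images, new_masks, epochs=10)  # only the new head is updated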
test_pts = pd.DataFrame(make_blobs(n_samples=200, random_state=2018)[0], columns=['x', 'y'])
plt.plot(test_pts.x, test_pts.y, 'r.');
The user only has to provide the number of classes the algorithm shall find.
Note: The algorithm will find exactly the number of clusters you ask for; it doesn't care whether that makes sense!
N=3
fig, ax = plt.subplots(1,N,figsize=(18,4.5))
for i in range(N) :
km = KMeans(n_clusters=i+2, random_state=2018); n_grp = km.fit_predict(test_pts)
ax[i].scatter(test_pts.x, test_pts.y, c=n_grp)
ax[i].set_title('{0} groups'.format(i+2))
orig = fits.getdata('../data/spots/mixture12_00001.fits')[::4,::4]
fig,ax = plt.subplots(1,6,figsize=(18,5)); x,y = np.meshgrid(np.linspace(0,1,orig.shape[0]),np.linspace(0,1,orig.shape[1]))
ax[0].imshow(orig, vmin=0, vmax=4000), ax[0].set_title('Original')
ax[1].imshow(x), ax[1].set_title('x-coordinates')
ax[2].imshow(y), ax[2].set_title('y-coordinates')
ax[3].imshow(flt.gaussian(orig, sigma=5)), ax[3].set_title('Weighted neighborhood')
ax[4].imshow(flt.sobel_h(orig),vmin=0, vmax=0.001),ax[4].set_title('Horizontal edges')
ax[5].imshow(flt.sobel_v(orig),vmin=0, vmax=0.001),ax[5].set_title('Vertical edges');
tof = np.load('../data/tofdata.npy')
wtof = tof.mean(axis=2)
plt.imshow(wtof,cmap='gray');
plt.title('Average intensity all time bins');
fig, ax= plt.subplots(1,2,figsize=(12,5))
ax[0].imshow(wtof,cmap='gray'); ax[0].set_title('Average intensity all time bins');
ax[0].plot(57,3,'ro'), ax[0].plot(15,30,'bo'), ax[0].plot(79,90,'go'); ax[0].plot(100,120,'co');
ax[1].plot(tof[30,15,:],'b', label='Sample'); ax[1].plot(tof[3,57,:],'r', label='Background'); ax[1].plot(tof[90,79,:],'g', label='Spacer'); ax[1].legend();ax[1].plot(tof[120,100,:],'c', label='Sample 2');
tofr=tof.reshape([tof.shape[0]*tof.shape[1],tof.shape[2]])
print("Input ToF dimensions",tof.shape)
print("Reshaped ToF data",tofr.shape)
Input ToF dimensions (128, 128, 661)
Reshaped ToF data (16384, 661)
km = KMeans(n_clusters=4, random_state=2018) # Random state is an initialization parameter for the random number generator
c = km.fit_predict(tofr).reshape(tof.shape[:2]) # Label image
kc = km.cluster_centers_.transpose() # cluster centroid spectra
Results from the first try
fig,axes = plt.subplots(1,3,figsize=(18,5)); axes=axes.ravel()
axes[0].imshow(wtof,cmap='viridis'); axes[0].set_title('Average image')
p=axes[1].plot(kc); axes[1].set_title('Cluster centroid spectra'); axes[1].set_aspect(tof.shape[2], adjustable='box')
cmap=ps.buildCMap(p) # Create a color map with the same colors as the plot
im=axes[2].imshow(c,cmap=cmap); plt.colorbar(im);
axes[2].set_title('Cluster map');
plt.tight_layout()
What happens when we increase the number of clusters to ten?
km = KMeans(n_clusters=10, random_state=2018)
c = km.fit_predict(tofr).reshape(tof.shape[:2]) # Label image
kc = km.cluster_centers_.transpose() # cluster centroid spectra
Results of k-means with ten clusters
fig,axes = plt.subplots(1,3,figsize=(18,5)); axes=axes.ravel()
axes[0].imshow(wtof,cmap='gray'); axes[0].set_title('Average image')
p=axes[1].plot(kc); axes[1].set_title('Cluster centroid spectra'); axes[1].set_aspect(tof.shape[2], adjustable='box')
cmap=ps.buildCMap(p) # Create a color map with the same colors as the plot
im=axes[2].imshow(c,cmap=cmap); plt.colorbar(im);
axes[2].set_title('Cluster map');
plt.tight_layout()
fig,axes = plt.subplots(1,1,figsize=(14,5));
plt.plot(kc); axes.set_title('Cluster centroid spectra');
axes.add_patch(Ellipse((0,0.62), width=30,height=0.55,fill=False,color='r')) #,axes.set_aspect(tof.shape[2], adjustable='box')
axes.add_patch(Ellipse((0,0.24), width=30,height=0.15,fill=False,color='cornflowerblue')),axes.set_aspect(tof.shape[2], adjustable='box');
del km, c, kc, tofr, tof
blob_data, blob_labels = make_blobs(n_samples=100, random_state=2018)
test_pts = pd.DataFrame(blob_data, columns=['x', 'y'])
test_pts['group_id'] = blob_labels
plt.scatter(test_pts.x, test_pts.y, c=test_pts.group_id, cmap='viridis');
orig= fits.getdata('../data/spots/mixture12_00001.fits')
annotated=io.imread('../data/spots/mixture12_00001.png'); mask=(annotated[:,:,1]==0)
r=600; c=600; w=256
ps.magnifyRegion(orig,[r,c,r+w,c+w],[15,7],vmin=400,vmax=4000,title='Neutron radiography')
Parameters
selem=np.ones([3,3])
forig=orig.astype('float32')
mimg = flt.median(forig,selem=selem)
d = np.abs(forig-mimg)
fig,ax=plt.subplots(1,2,figsize=(12,5))
h,x,y,u=ax[0].hist2d(forig.ravel(),d.ravel(), bins=100);
ax[0].set_xlabel('Input image - $f$'),ax[0].set_ylabel('$|f-med_{3x3}(f)|$'),ax[0].set_title('Bivariate histogram');
ax[1].imshow(np.log(h[:,::-1]+1).transpose(),vmin=0,vmax=10,extent=[x.min(),x.max(),y.min(),y.max()])
ax[1].set_xlabel('Input image - $f$'),ax[1].set_ylabel('$|f-med_{3x3}(f)|$'),ax[1].set_title('Log bivariate histogram');
def spotCleaner(img, threshold=0.95, selem=np.ones([3,3])):
    fimg = img.astype('float32')
    mimg = flt.median(fimg, selem=selem)       # median-filtered reference image
    timg = threshold < np.abs(fimg-mimg)       # detection mask: pixels deviating strongly from the median
    cleaned = mimg * timg + fimg * (1-timg)    # replace detected outliers by the median value
    return (cleaned, timg)
baseclean,timg = spotCleaner(orig,threshold=1000)
ps.magnifyRegion(baseclean,[r,c,r+w,c+w],[12,3],vmin=400,vmax=4000,title='Cleaned image')
ps.magnifyRegion(timg,[r,c,r+w,c+w],[12,3],vmin=0,vmax=1,title='Detection image')
Training data
trainorig = forig[:,:1000].ravel()
traind = d[:,:1000].ravel()
trainmask = mask[:,:1000].ravel()
train_pts = pd.DataFrame({'orig': trainorig, 'd': traind, 'mask':trainmask})
Test data
testorig = forig[:,1000:].ravel()
testd = d[:,1000:].ravel()
testmask = mask[:,1000:].ravel()
test_pts = pd.DataFrame({'orig': testorig, 'd': testd, 'mask':testmask})
k_class = KNeighborsClassifier(1)
k_class.fit(train_pts[['orig', 'd']], train_pts['mask'])
KNeighborsClassifier(n_neighbors=1)
Inspect decision space
xx, yy = np.meshgrid(np.linspace(test_pts.orig.min(), test_pts.orig.max(), 100),
np.linspace(test_pts.d.min(), test_pts.d.max(), 100),indexing='ij');
grid_pts = pd.DataFrame(dict(x=xx.ravel(), y=yy.ravel()))
grid_pts['predicted_id'] = k_class.predict(grid_pts[['x', 'y']])
plt.scatter(grid_pts.x, grid_pts.y, c=grid_pts.predicted_id, cmap='gray'); plt.title('Testing Points'); plt.axis('square');
pred = k_class.predict(test_pts[['orig', 'd']])
pimg = pred.reshape(d[:,1000:].shape)
fig,ax = plt.subplots(1,3,figsize=(15,6))
ax[0].imshow(forig[:,1000:],vmin=0,vmax=4000), ax[0].set_title('Original image')
ax[1].imshow(pimg), ax[1].set_title('Predicted spot')
ax[2].imshow(mask[:,1000:]),ax[2].set_title('Annotated spots');
ps.showHitMap(mask[:,1000:],timg[:,1000:])
plt.savefig('spotbaseline.png',dpi=300)
ps.showHitMap(mask[:,1000:], pimg)
plt.savefig('spotknn.png',dpi=300)
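The hit maps give a visual comparison; the confusion matrix imported above can summarize the same result numerically. A small sketch using the mask and pimg variables from the cells above:

from sklearn.metrics import confusion_matrix

tn, fp, fn, tp = confusion_matrix(mask[:, 1000:].ravel(),
                                  pimg.ravel().astype(bool)).ravel()
print(f"Precision = {tp/(tp+fp):.3f}")
print(f"Recall    = {tp/(tp+fn):.3f}")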
Note: There are other spot detection methods that perform better than this baseline.
del k_class
| Classification | Segmentation |
|---|---|
| pixels to classes | pixels to pixels |
Segmentation is mostly based on variations of the U-Net architecture.
| Scales in traditional image processing | U-Net architecture |
|---|---|
We have two choices:
We will use the spotty image as training data for this example
There is only one image!
fig,ax=plt.subplots(1,2,figsize=(12,5))
ax[0].imshow(forig,vmin=0,vmax=4000,cmap='gray'); ax[0].set_title('Original');
ax[1].imshow(mask,cmap='gray'); ax[1].set_title('Mask');
Any analysis system must be verified to demonstrate its performance and to further optimize it.
For this we need to split our data into three categories:
wpos = [1100,600]; ww = 512
train_img, valid_img, forigc = forig[128:256, 500:1300], forig[500:1000, 300:1500], forig[wpos[0]:(wpos[0]+ww),wpos[1]:(wpos[1]+ww)]
train_mask, valid_mask, maskc = mask[128:256, 500:1300], mask[500:1000, 300:1500], mask[wpos[0]:(wpos[0]+ww),wpos[1]:(wpos[1]+ww)]
fig, ax = plt.subplots(1, 4, figsize=(15, 6), dpi=300); ax=ax.ravel()
ax[0].imshow(train_img, cmap='bone',vmin=0,vmax=4000);ax[0].set_title('Train Image')
ax[1].imshow(train_mask, cmap='bone'); ax[1].set_title('Train Mask')
ax[2].imshow(valid_img, cmap='bone',vmin=0,vmax=4000); ax[2].set_title('Validation Image')
ax[3].imshow(valid_mask, cmap='bone');ax[3].set_title('Validation Mask');
| Training | Validation | Test |
|---|---|---|
| 70% | 15% | 15% |
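For reference, a minimal sketch of a random 70/15/15 split over sample indices; in this lecture the split is instead done by cutting fixed regions out of the single annotated image (code above).

import numpy as np

rng = np.random.default_rng(2018)
idx = rng.permutation(100)                    # indices of 100 hypothetical samples
n_train = int(0.70 * len(idx))
n_valid = int(0.15 * len(idx))

train_idx = idx[:n_train]
valid_idx = idx[n_train:n_train + n_valid]
test_idx  = idx[n_train + n_valid:]
print(len(train_idx), len(valid_idx), len(test_idx))  # -> 70 15 15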
We need:
def buildSpotUNet(base_depth=48):
    in_img = Input((None, None, 1), name='Image_Input')
    # Encoder: convolution blocks, each followed by max-pooling downsampling
    lay_1 = Conv2D(base_depth, kernel_size=(3, 3), padding='same', activation='relu')(in_img)
    lay_2 = Conv2D(base_depth, kernel_size=(3, 3), padding='same', activation='relu')(lay_1)
    lay_3 = MaxPooling2D(pool_size=(2, 2))(lay_2)
    lay_4 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same', activation='relu')(lay_3)
    lay_5 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same', activation='relu')(lay_4)
    lay_6 = MaxPooling2D(pool_size=(2, 2))(lay_5)
    # Bottleneck at the coarsest scale
    lay_7 = Conv2D(base_depth*4, kernel_size=(3, 3), padding='same', activation='relu')(lay_6)
    lay_8 = Conv2D(base_depth*4, kernel_size=(3, 3), padding='same', activation='relu')(lay_7)
    # Decoder: upsampling and concatenation with the matching encoder level (skip connections)
    lay_9 = UpSampling2D((2, 2))(lay_8)
    lay_10 = concatenate([lay_5, lay_9])
    lay_11 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same', activation='relu')(lay_10)
    lay_12 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same', activation='relu')(lay_11)
    lay_13 = UpSampling2D((2, 2))(lay_12)
    lay_14 = concatenate([lay_2, lay_13])
    lay_15 = Conv2D(base_depth, kernel_size=(3, 3), padding='same', activation='relu')(lay_14)
    lay_16 = Conv2D(base_depth, kernel_size=(3, 3), padding='same', activation='relu')(lay_15)
    # 1x1 convolution producing the single-channel segmentation map
    lay_17 = Conv2D(1, kernel_size=(1, 1), padding='same', activation='relu')(lay_16)
    t_unet = Model(inputs=[in_img], outputs=[lay_17], name='SpotUNET')
    return t_unet
Model summary
t_unet = buildSpotUNet(base_depth=24)
t_unet.summary()
Model: "SpotUNET"
Total params: 265,105
Trainable params: 265,105
Non-trainable params: 0
def prep_img(x, n=1):
    # Normalize with the training image statistics after adding batch and channel axes
    return (prep_mask(x, n=n) - train_img.mean()) / train_img.std()

def prep_mask(x, n=1):
    # Stack n copies of the image along a new batch axis and add a channel axis: (n, H, W, 1)
    return np.stack([np.expand_dims(x, -1)]*n, 0)
unet_pred = t_unet.predict(prep_img(forigc))[0, :, :, 0]
fig, m_axs = plt.subplots(2, 3, figsize=(15, 6), dpi=150)
for c_ax in m_axs.ravel():
c_ax.axis('off')
((ax1, _, ax2), (ax3, ax4, ax5)) = m_axs
ax1.imshow(train_img, cmap='bone',vmin=0,vmax=4000); ax1.set_title('Train Image')
ax2.imshow(train_mask, cmap='viridis'); ax2.set_title('Train Mask')
ax3.imshow(forigc, cmap='bone',vmin=0, vmax=4000); ax3.set_title('Test Image')
ax4.imshow(unet_pred, cmap='viridis', vmin=0, vmax=0.1); ax4.set_title('Predicted Segmentation')
ax5.imshow(maskc, cmap='viridis'); ax5.set_title('Ground Truth');
The untrained model doesn't perform very well, but you can clearly see that the image structures appear in the prediction. It is worth noting that the spots are already amplified. This is what we want to improve during training.
mlist = [
metrics.TruePositives(name='tp'), metrics.FalsePositives(name='fp'),
metrics.TrueNegatives(name='tn'), metrics.FalseNegatives(name='fn'),
metrics.BinaryAccuracy(name='accuracy'), metrics.Precision(name='precision'),
metrics.Recall(name='recall'), metrics.AUC(name='auc'),
metrics.MeanAbsoluteError(name='mae')]
t_unet.compile(
loss=loss.BinaryCrossentropy(), # we use the binary cross-entropy to optimize
optimizer=opt.Adam(lr=1e-3), # we use ADAM to optimize
metrics=mlist # we keep track of the metrics in mlist
)
This is a very bad way to train a model; the goal here is only to make you aware of these techniques and to give you a feeling for how they can work on complex problems.
Nsamples = 3
Nepochs = 20
loss_history = t_unet.fit(prep_img(train_img, n=Nsamples),
prep_mask(train_mask, n=Nsamples),
validation_data=(prep_img(valid_img),
prep_mask(valid_mask)),
epochs=Nepochs,
verbose = 2)
Train on 3 samples, validate on 1 sample (training log abridged to the first and last epoch):
Epoch 1/20 - loss: 0.0885 - precision: 0.0000 - recall: 0.0000 - auc: 0.4037 - val_loss: 0.1122 - val_precision: 0.6667 - val_recall: 0.0015 - val_auc: 0.5409
Epoch 20/20 - loss: 0.0403 - precision: 0.8148 - recall: 0.0259 - auc: 0.8855 - val_loss: 0.0545 - val_precision: 0.8138 - val_recall: 0.0182 - val_auc: 0.8903
titleDict = {'tp': "True Positives",'fp': "False Positives",'tn': "True Negatives",'fn': "False Negatives", 'accuracy':"BinaryAccuracy",'precision': "Precision",'recall':"Recall",'auc': "Area under Curve", 'mae': "Mean absolute error"}
fig,ax = plt.subplots(2,5, figsize=(20,8), dpi=300)
ax =ax.ravel()
for idx,key in enumerate(titleDict.keys()):
ax[idx].plot(loss_history.epoch, loss_history.history[key], color='coral', label='Training')
ax[idx].plot(loss_history.epoch, loss_history.history['val_'+key], color='cornflowerblue', label='Validation')
ax[idx].set_title(titleDict[key]);
ax[9].axis('off');
axLine, axLabel = ax[0].get_legend_handles_labels() # Take the labels and plot line information from the first panel
lines =[]; labels = []; lines.extend(axLine); labels.extend(axLabel);fig.legend(lines, labels, bbox_to_anchor=(0.7, 0.3), loc='upper left');
unet_train_pred = t_unet.predict(prep_img(train_img[:,wpos[1]:(wpos[1]+ww)]))[0, :, :, 0]
fig, m_axs = plt.subplots(1, 3, figsize=(18, 4), dpi=150); m_axs= m_axs.ravel();
for c_ax in m_axs: c_ax.axis('off')
m_axs[0].imshow(train_img[:,wpos[1]:(wpos[1]+ww)], cmap='bone', vmin=0, vmax=4000), m_axs[0].set_title('Train Image')
m_axs[1].imshow(unet_train_pred, cmap='viridis', vmin=0, vmax=0.2), m_axs[1].set_title('Predicted Training')
m_axs[2].imshow(train_mask[:,wpos[1]:(wpos[1]+ww)], cmap='viridis'), m_axs[2].set_title('Train Mask');
unet_pred = t_unet.predict(prep_img(forigc))[0, :, :, 0]
fig, m_axs = plt.subplots(1, 3, figsize=(18, 4), dpi=150); m_axs = m_axs.ravel() ;
for c_ax in m_axs: c_ax.axis('off')
m_axs[0].imshow(forigc, cmap='bone', vmin=0, vmax=4000); m_axs[0].set_title('Full Image')
f1=m_axs[1].imshow(unet_pred, cmap='viridis', vmin=0, vmax=0.1); m_axs[1].set_title('Predicted Segmentation'); fig.colorbar(f1,ax=m_axs[1]);
m_axs[2].imshow(maskc,cmap='viridis'); m_axs[2].set_title('Ground Truth');
fig, ax = plt.subplots(1,2, figsize=(12,4))
ax0=ax[0].imshow(unet_pred, vmin=0, vmax=0.1); ax[0].set_title('Predicted segmentation'); fig.colorbar(ax0,ax=ax[0])
ax[1].imshow(0.05<unet_pred), ax[1].set_title('Final segmentation');
gt = maskc
pr = 0.05<unet_pred
ps.showHitCases(gt,pr,cmap='gray')
ps.showHitMap(gt,pr)
plt.savefig('spotunet.png')
Baseline
k-NN
U-Net
Improvements
Today: much of this mark-up is done manually!
Acknowledgement: This work was done by Gian Guido Parenza as a master project.
| Radiography | Tomography |
|---|---|
| Provided by A. Carminati et al. | Provided by M. Menon et al. |
| Radiograph of a rhizobox | Current segmentation |
|---|---|
Problems: Unwanted elements are marked as roots
| Compare different loss functions | Details of branching loss |
|---|---|
Our options
Medical image processing is the saviour!
| U-Net trained with roots only | U-Net trained with transfer learning |
|---|---|
Transfer learning, after some modification
| Original tomography data and ground truth | Segmentations |
|---|---|
This project has shown that
The models still need more training to cover wider variations in the data.
We have demonstrated how some machine learning techniques can be used on neutron images: