# Generated by pandoc-plot 1.9.1
import matplotlib.pyplot as plt
import numpy as np

# Seed NumPy's global (legacy) random generator. This appears to be part of the
# pandoc-plot preamble; the plotting script below draws no random numbers itself.
np.random.seed(seed=2019)

import numpy as np
import matplotlib.pyplot as plt
from skued import diffread
from pathlib import Path
from math import floor
from skimage.registration import phase_cross_correlation
from skimage.registration._masked_phase_cross_correlation import cross_correlate_masked


def center_of_mass_masked(im, mask):
    """Return the intensity-weighted centre of mass of ``im`` restricted to ``mask``.

    Only pixels where ``mask`` is truthy contribute; every other pixel gets
    zero weight, whatever value (including NaN/inf) it holds in ``im``.

    Parameters
    ----------
    im : ndarray
        Image whose values are used directly as the averaging weights.
    mask : array_like
        Boolean-like array broadcastable to ``im.shape`` (normally the same
        shape); truthy entries mark the pixels to include.

    Returns
    -------
    (int, int)
        Row and column of the centre of mass, rounded to the nearest pixel.

    Raises
    ------
    ZeroDivisionError
        If the masked weights sum to zero (raised by ``np.average``).
    """
    rr, cc = np.indices(im.shape)
    # np.where instead of ``im * mask``: multiplying NaN/inf by 0 still yields
    # NaN, which would poison the whole weighted average.
    weights = np.where(mask, im, 0)

    r = np.average(rr, weights=weights)
    c = np.average(cc, weights=weights)
    # Round to the nearest pixel; a bare int() truncates and biases the centre
    # towards the origin by up to one pixel.
    return int(round(float(r))), int(round(float(c)))


def correlate(arr1, arr2, m1, m2):
    """Masked cross-correlation of ``arr1`` with ``arr2``.

    ``m1`` and ``m2`` are the masks for ``arr1`` and ``arr2`` respectively.
    Delegates to scikit-image's ``cross_correlate_masked`` with ``mode='same'``
    (per the scikit-image docs, the output then has the size of ``arr1``).
    """
    # NOTE(review): cross_correlate_masked is imported from a private
    # (underscore-prefixed) scikit-image module and may move between releases.
    output_mode = 'same'
    correlation = cross_correlate_masked(
        arr1=arr1,
        arr2=arr2,
        m1=m1,
        m2=m2,
        mode=output_mode,
    )
    return correlation


# Load the diffraction pattern and its mask; pixels where the mask is True are
# the ones that take part in the centring and in the correlations below.
im = diffread(Path("images") / "autocenter" / "graphite.tif")
mask = diffread(Path("images") / "autocenter" / "graphite_mask.tif").astype(bool)

# Crop the largest square window around the masked centre of mass that still
# fits inside the image. r and c are already ints, so no floor() is needed.
r, c = center_of_mass_masked(im, mask)
side_length = min(r, im.shape[0] - r, c, im.shape[1] - c)
if side_length <= 0:
    # Without this check the slices below would be empty and the failure would
    # only surface later, deep inside the correlation routines.
    raise ValueError(
        f"centre of mass ({r}, {c}) lies on the border of an image of shape "
        f"{im.shape}; cannot crop a window around it"
    )
rs = slice(r - side_length, r + side_length)
cs = slice(c - side_length, c + side_length)
im = im[rs, cs]
mask = mask[rs, cs]
# Rotate the cropped image and mask by 180 degrees.
im_r = im[::-1, ::-1]
mask_r = mask[::-1, ::-1]

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(6, 6))
# Masked cross-correlation between the crop and its rotation.
xcorr = np.abs(correlate(im, im_r, mask, mask_r))

# Top row: the cropped image and its 180-degree rotation.
for ax, image in zip([ax1, ax2], [im, im_r]):
    ax.imshow(image, vmin=0, vmax=200, cmap="inferno")

# Bottom left: the full cross-correlation, with dashed lines through its centre.
ax3.imshow(xcorr, vmin=xcorr.mean(), cmap="inferno")
ax3.axhline(y=xcorr.shape[0] / 2, linestyle="--", color="w", linewidth=1)
ax3.axvline(x=xcorr.shape[1] / 2, linestyle="--", color="w", linewidth=1)

# Bottom right: the central half of the cross-correlation along each axis,
# labelled "2x", with the same centre lines.
xcorr_zoomed = xcorr[
    xcorr.shape[0] // 4 : 3 * xcorr.shape[0] // 4,
    xcorr.shape[1] // 4 : 3 * xcorr.shape[1] // 4,
]
ax4.imshow(xcorr_zoomed, vmin=xcorr.mean(), cmap="inferno")
ax4.axhline(y=xcorr_zoomed.shape[0] / 2, linestyle="--", color="w", linewidth=1)
ax4.axvline(x=xcorr_zoomed.shape[1] / 2, linestyle="--", color="w", linewidth=1)
ax4.text(
    x=0.05, y=0.95, s="2x", transform=ax4.transAxes, ha="left", va="top", color="w"
)

result = phase_cross_correlation(
    reference_image=im,
    moving_image=im_r,
    reference_mask=mask,
    moving_mask=mask_r,
)
# Newer scikit-image releases return (shift, error, phasediff). Older releases
# (before 0.22, as far as we know -- confirm against the installed version)
# return only the shift array when masks are given. Unpacking that bare array
# with ``shift, *_ = ...`` would bind a scalar and fail at ``shift[1]``, so
# accept both forms.
shift = result[0] if isinstance(result, tuple) else result

# Draw the estimated shift as an arrow from the centre of the zoomed panel.
# Cropping does not rescale pixels, so the shift is drawn in pixel units.
ax4.arrow(
    x=xcorr_zoomed.shape[1] / 2,
    y=xcorr_zoomed.shape[0] / 2,
    dx=shift[1],
    dy=shift[0],
    color="w",
    length_includes_head=True,
    head_width=10,
)

for ax in (ax1, ax2, ax3, ax4):
    ax.axis("off")
plt.tight_layout()
# Click here to see how this plot was generated.