In [None]:
import os
from typing import Optional

import numpy as np
import requests
from matplotlib import pyplot as plt
from PIL import Image
from skimage.measure import block_reduce


# Fetch the SRTM 1-arc-second tile once; later runs reuse the local copy.
DEM_URL = 'http://sendimage.whu.edu.cn/res/DEM_share/SRTM1/S60,E150/s34_e150_1arc_v3.tif'
DEM_PATH = 's34_e150_1arc_v3.tif'

if not os.path.exists(DEM_PATH):
    response = requests.get(DEM_URL, timeout=60)
    # Fail loudly instead of saving an HTTP error page under a .tif name.
    response.raise_for_status()
    # The file is only opened after the whole body has been received, so a
    # failed request never leaves a truncated cache file behind.
    with open(DEM_PATH, 'wb') as f:
        f.write(response.content)



In [None]:
# Load the DEM tile into a NumPy array; the context manager closes the file
# handle once the pixel data has been copied out.
with Image.open('s34_e150_1arc_v3.tif') as tile:
    im = np.asarray(tile)

fig, ax = plt.subplots()
shown = ax.imshow(im)
fig.colorbar(shown, ax=ax, label='Elevation (m)')
ax.set_title('s34_e150_1arc_v3.tif')
plt.show()


In [None]:
"""UnionFind.py

Union-find data structure. Based on Josiah Carlson's code,
http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/215912
with significant additional changes by D. Eppstein.
"""


class UnionFind:

    """Union-find (disjoint-set) structure with caller-supplied priorities.

    Each UnionFind instance X maintains a family of disjoint sets of
    hashable objects, supporting the following operations:

    - X.add(item, weight) registers ``item`` as a new singleton set with the
      given weight. Adding an item that is already present does nothing
      (its original weight is kept).

    - X[item] returns a name for the set containing the given item. Each set
      is named by a member of maximal weight; as long as the set remains
      unchanged it keeps the same name. Looking up an item that was never
      added raises KeyError.

    - X.union(item1, item2, ...) merges the sets containing each item into a
      single set, named by the heaviest of their current names.

    Unlike a textbook union-by-size structure, weights are never accumulated
    on union: they act as fixed priorities, so the root of every set is
    always one of its highest-weight members.
    """

    def __init__(self):
        """Create a new empty union-find structure."""
        self.weights = {}
        self.parents = {}

    def add(self, object, weight):
        """Add ``object`` as a singleton set with priority ``weight``.

        Does nothing if ``object`` is already present.
        """
        if object not in self.parents:
            self.parents[object] = object
            self.weights[object] = weight

    def __contains__(self, object):
        """Return True if ``object`` has been added to this structure."""
        return object in self.parents

    def __getitem__(self, object):
        """Find and return the name (root) of the set containing the object.

        Raises:
            KeyError: if ``object`` was never added.
        """
        if object not in self.parents:
            raise KeyError(object)

        # Walk parent links up to the root (the element that is its own parent).
        path = [object]
        root = self.parents[object]
        while root != path[-1]:
            path.append(root)
            root = self.parents[root]

        # Path compression: point every visited element straight at the root.
        for ancestor in path:
            self.parents[ancestor] = root
        return root

    def __iter__(self):
        """Iterate through every item ever added to this structure."""
        return iter(self.parents)

    def union(self, *objects):
        """Find the sets containing the objects and merge them all.

        The merged set is named by the heaviest of the current roots; among
        equally heavy roots the first one given wins. Calling with no
        arguments is a no-op.

        Raises:
            KeyError: if any object was never added.
        """
        roots = [self[x] for x in objects]
        if not roots:
            return
        # key= compares weights only, so the items themselves need not be orderable.
        heaviest = max(roots, key=self.weights.__getitem__)
        for r in roots:
            if r != heaviest:
                self.parents[r] = heaviest


In [None]:
from dataclasses import dataclass
from tqdm import tqdm


@dataclass
class Peak:
    """One peak reported by ``Mountaineer.survey_peaks``.

    Levels and persistence are in the units of the surveyed image.
    """
    # (row, col) of the peak's summit; None until the survey assigns it.
    birth_point: Optional[tuple]
    # Image value at the summit.
    birth_level: float
    # Summit value minus the value where this peak merges into a higher one
    # (for the global maximum, which never merges, its own value).
    persistence: float
    # (row, col) of the merge pixel; None for the global maximum.
    death_point: Optional[tuple]


class Mountaineer:
    """
    Finds the peaks in the image.

    Uses 0-dimensional persistent homology on the image's superlevel sets:
    pixels are visited from highest to lowest, each pixel joins the
    8-connected components of its already-visited neighbours, and when
    several components meet at a pixel, all but the one with the highest
    summit "die" there. A peak's persistence is its summit value minus the
    value at the pixel where it dies.

    Uses the UnionFind data structure with the negative visit index as
    weight, so every component is rooted at its earliest-visited (highest)
    pixel.
    Adapted from:  https://git.sthu.org/?p=persistence.git;a=summary
    """
    def __init__(self, image):
        """Store the 2-D ``image`` to survey."""
        self.image = image
        self.h, self.w = image.shape

        self.uf = UnionFind()

    def get_comp_birth(self, p):
        """Return the image value at the root of the component containing ``p``."""
        pp = self.uf[p]
        y, x = pp[0], pp[1]

        return self.image[y, x]

    def _neighbours(self, y, x):
        """Yield (row, col) of each in-bounds pixel of the 3x3 window at (y, x), itself included."""
        for ny in range(max(y - 1, 0), min(y + 2, self.h)):
            for nx in range(max(x - 1, 0), min(x + 2, self.w)):
                yield (ny, nx)

    def survey_peaks(self):
        """Return the peaks of the image as ``Peak`` objects, most persistent first.

        Each peak's ``birth_point`` is the (row, col) of its summit. Pixels
        that are -inf or NaN are skipped. Every call starts from a fresh
        UnionFind, so the method may be called repeatedly.

        NOTE: apart from the global maximum, a peak is only recorded when it
        merges into a higher component, so a region completely cut off by
        skipped pixels does not contribute its own maximum.
        """
        # Fresh structure: reusing the one from a previous call would mix
        # stale components into this survey.
        self.uf = UnionFind()

        # Pixel coordinates ordered by value, highest first. np.argsort puts
        # NaN last, i.e. first after reversal -- those are skipped below.
        order = np.argsort(self.image.ravel())[::-1]
        rows, cols = np.unravel_index(order, self.image.shape)

        groups0 = {}
        summit_seen = False

        for i, (y, x) in enumerate(tqdm(zip(rows, cols), total=len(order))):
            p = (y, x)
            val = self.image[y, x]
            if val == -np.inf or np.isnan(val):
                continue

            # Roots of already-visited neighbouring components, highest summit first.
            roots = {self.uf[q] for q in self._neighbours(y, x) if q in self.uf}
            nc = sorted(((self.get_comp_birth(q), q) for q in roots), reverse=True)

            if not summit_seen:
                # First valid pixel is the global maximum; it never dies, so
                # its persistence is taken to be its own value.
                groups0[p] = Peak(birth_point=None, birth_level=val, persistence=val, death_point=None)
                summit_seen = True

            # Earlier-visited pixels get larger weights and so win unions.
            self.uf.add(p, -i)

            if nc:
                oldp = nc[0][1]
                self.uf.union(oldp, p)
                # Every other neighbouring component dies here and merges into oldp's.
                for bl, q in nc[1:]:
                    if self.uf[q] not in groups0:
                        groups0[self.uf[q]] = Peak(birth_point=None, birth_level=bl, persistence=bl - val, death_point=p)

                    self.uf.union(oldp, q)

        # Each dict key is the summit of its peak.
        groups = []
        for summit, peak in groups0.items():
            peak.birth_point = summit
            groups.append(peak)
        groups.sort(key=lambda pk: pk.persistence, reverse=True)

        return groups



In [None]:
# Survey a 4x4 block-mean downsampled copy of the DEM (presumably to keep the
# pixel-by-pixel Python sweep fast enough).
dem_small = block_reduce(im, (4, 4), np.mean)
surveyor = Mountaineer(dem_small)
peaks = surveyor.survey_peaks()


In [None]:
# NOTE: this rebinds `im` to the 4x4 block-mean image, which the next cell
# relies on. Running this cell a second time reduces the image again; use
# Restart & Run All instead.
im = block_reduce(im, (4, 4), np.mean)

# Image and markers share one Axes. A separate fig.add_subplot(111) after
# plt.imshow would create a new, opaque Axes on top of the image.
fig, ax = plt.subplots(figsize=(5, 10), dpi=120)
ax.imshow(im)

# `peaks` is sorted by persistence, so the first 21 entries (indices 0-20)
# are the most persistent; of those, only peaks above 20 units are marked.
for i, peak in enumerate(peaks[:21]):
    if peak.persistence <= 20.0:
        continue
    y, x = peak.birth_point
    ax.plot([x], [y], '.', c='b')
    ax.text(x, y + 0.25, str(i + 1), color='red')

# Pixel centres sit on integer coordinates, so the image spans -0.5 .. n-0.5;
# the y-axis runs downward like the image rows.
ax.set_xlim(-0.5, im.shape[1] - 0.5)
ax.set_ylim(im.shape[0] - 0.5, -0.5)
ax.set_title('Most persistent peaks (label = persistence rank)')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm

# Minimum persistence for a peak to appear in the bar chart.
MIN_BAR_PERSISTENCE = 150

img = im


def get(p):
    """Return the value of ``img`` at (row, col) ``p``.

    ``p`` is None for the global maximum's death point; the image minimum is
    returned then, so that peak's bar runs down to the lowest value.
    """
    if p is None:
        return img.min()
    return img[p[0], p[1]]


# Persistence rank (1 = most persistent) of every peak, keyed by summit.
by_persistence = sorted(peaks, key=lambda pk: pk.persistence, reverse=True)
rank = {pk.birth_point: idx + 1 for idx, pk in enumerate(by_persistence)}

# The colour scale spans the persistence of *all* peaks, not only the plotted ones.
all_persistences = np.asarray([pk.persistence for pk in peaks])
norm = plt.Normalize(vmin=all_persistences.min(), vmax=all_persistences.max())

# Plotted peaks, lowest summit first: barh draws index 0 at the bottom, so the
# highest-born peak ends up at the top of the chart.
pts = sorted(
    (pk for pk in peaks if pk.persistence > MIN_BAR_PERSISTENCE),
    key=lambda pk: get(pk.birth_point),
)

begin = np.asarray([get(pk.birth_point) for pk in pts])
end = np.asarray([get(pk.death_point) for pk in pts])
persistences = np.asarray([pk.persistence for pk in pts])
colors = cm.viridis(norm(persistences))
labels = [str(rank[pk.birth_point]) for pk in pts]

fig, ax = plt.subplots(figsize=(10, 15))
ax.grid(linestyle=':')

# Each bar spans birth -> death value; with the x-axis inverted it reads
# left-to-right from the summit down to where the peak merges.
ax.barh(range(len(begin)), end - begin, left=begin, color=colors)
ax.invert_xaxis()
ax.margins(y=0.01)

ax.set_xlabel('Elevation (m)', fontsize=16)
ax.set_ylabel('Birth and Death of Individual Peaks\n(label refers to persistence rank)', fontsize=16)
ax.set_title(f"Bar Chart of {len(pts)}-most Persistent Peaks", fontsize=16)
ax.set_yticks(range(len(begin)))
ax.set_yticklabels(labels)

# The colorbar must use the same persistence -> colour mapping as the bars.
sm = cm.ScalarMappable(norm=norm, cmap=cm.viridis)
sm.set_array([])
cbar = fig.colorbar(sm, ax=ax, aspect=70)
cbar.ax.set_ylabel("Persistence (m)", fontsize=16)

fig.tight_layout()
fig.savefig('barchart.png', dpi=120)
plt.show()
