from __future__ import annotations
from collections import namedtuple
from typing import Literal
import ee
import ipywidgets as widgets
import pandas as pd
import plotly.graph_objects as go
from sankee import sampling, themes, utils
# Bundle of the column-like values handed to Plotly when a Sankey trace is built.
# Fields, in order (as filled in by `SankeyPlot._generate_plot_parameters`):
#   node_labels  - the "year" (image label) of every node
#   link_labels  - the hover text of every link
#   node_palette - the color of every node
#   link_palette - the color of every link (the color of its source class)
#   label        - the text displayed on every node
#   source       - the node ID each link starts at
#   target       - the node ID each link ends at
#   value        - the number of samples carried by each link
SankeyParameters = namedtuple(
    "SankeyParameters",
    "node_labels link_labels node_palette link_palette label source target value",
)
def sankify(
    image_list: list[ee.Image],
    band: str,
    labels: dict[int, str],
    palette: dict[int, str],
    region: None | ee.Geometry = None,
    label_list: None | list[str] = None,
    max_classes: None | int = None,
    n: int = 500,
    title: None | str = None,
    scale: None | int = None,
    seed: int = 0,
    label_type: None | Literal["class", "percent", "count"] = "class",
    theme: str | themes.Theme = "default",
) -> SankeyPlot:
    """
    Generate an interactive Sankey plot showing land cover change over time from a series of images.

    Parameters
    ----------
    image_list : List[ee.Image]
        An ordered list of images representing a time series of classified data. Each image will be
        sampled to generate the Sankey plot. Any length of list is allowed, but lists with more than
        3 or 4 images may produce unusable plots.
    band : str
        The name of the band in all images of image_list that contains classified data.
    labels : dict
        The labels associated with each value of all images in image_list. Any values not defined
        in the labels will be dropped from the sampled data.
    palette : dict
        The colors associated with each value of all images in image_list.
    region : ee.Geometry, default None
        A region to generate samples within. The region must overlap all images. If none is
        provided, the geometry of the first image will be used. For this to work, images must be
        bounded.
    label_list : List[str], default None
        An ordered list of labels corresponding to the images. The list must be the same length as
        image_list. If none is provided, sequential numeric labels will be automatically assigned
        starting at 0. Labels are displayed on-hover on the Sankey nodes.
    max_classes : int, default None
        If a value is provided, small classes will be removed until max_classes remain. Class size
        is calculated based on total times sampled in the time series.
    n : int, default 500
        The number of sample points to randomly generate for characterizing all images. More samples
        will provide more representative data but will take longer to process.
    title : str, default None
        An optional title that will be displayed above the Sankey plot.
    scale : int, default None
        The scale in image units to perform sampling at. If none is provided, GEE will attempt to
        use the image's nominal scale, which may cause errors depending on the image projection.
    seed : int, default 0
        The seed value used to generate repeatable results during random sampling.
    label_type : str, default "class"
        The type of label to display for each link, one of "class", "percent", "count", or None.
        Selecting "class" will use the class label, "percent" will use the proportion of sampled
        pixels in each class, and "count" will use the number of sampled pixels in each class.
        None will disable link labels.
    theme : str or Theme, default "default"
        The theme to apply to the Sankey diagram. Can be the name of a built-in theme (e.g. "d3") or
        a custom `sankee.Theme` object.

    Returns
    -------
    SankeyPlot
        An interactive Sankey plot widget.

    Raises
    ------
    ValueError
        If `image_list` is empty, if the number of labels in `label_list` does not match the number
        of images, or if `label_list` contains duplicates (labels are compared as strings).
    """
    # Fail with a clear message instead of an IndexError from `image_list[0]` below.
    if not image_list:
        raise ValueError("The `image_list` must contain at least one image.")

    # Labels are always handled as strings; the default labels are "0", "1", ...
    if label_list is None:
        label_list = range(len(image_list))
    label_list = [str(label) for label in label_list]

    # Validate the arguments before anything else is built from them.
    if len(label_list) != len(image_list):
        raise ValueError("The number of labels must match the number of images.")
    if len(set(label_list)) != len(label_list):
        raise ValueError("All labels in the `label_list` must be unique.")

    if region is None:
        region = image_list[0].geometry()

    data, samples = sampling.generate_sample_data(
        image_list=image_list,
        image_labels=label_list,
        band=band,
        scale=scale,
        include=list(labels.keys()),
        max_classes=max_classes,
        region=region,
        n=n,
        seed=seed,
    )

    return SankeyPlot(
        data=data,
        labels=labels,
        palette=palette,
        title=title,
        samples=samples,
        label_type=label_type,
        theme=theme,
    )
class SankeyPlot(widgets.DOMWidget):
    """
    Interactive Sankey diagram of class transitions between consecutive columns of sampled data.

    The object holds a Plotly ``FigureWidget`` (``self.plot``) and a GUI (``self.gui``) that stacks
    the plot above a row of per-class toggle buttons. Toggling a class off adds its value to
    ``self.hide`` and redraws the diagram without any sample row that contains that class.
    """

    def __init__(
        self,
        *,
        data: pd.DataFrame,
        labels: dict[int, str],
        palette: dict[int, str],
        title: None | str,
        samples: ee.FeatureCollection,
        label_type: None | Literal["class", "percent", "count"],
        theme: str | themes.Theme,
    ):
        """
        Build the plot and its GUI from sampled data.

        Parameters
        ----------
        data : pd.DataFrame
            Sampled class values. Each column is treated as one time step (its name is used as the
            "year" of the nodes) and consecutive columns are paired into transitions.
        labels : dict
            The label of each class value.
        palette : dict
            The color of each class value. Every key of `labels` must also be a key of `palette`.
        title : str or None
            Optional title displayed above the plot.
        samples : ee.FeatureCollection
            The sampled features; stored on the object as `self.samples`.
        label_type : str or None
            One of "class", "percent", "count", or None (no node labels).
        theme : str or Theme
            A theme object, or a name passed to `themes.load_theme`.
        """
        # NOTE(review): the base widget initializer is not invoked here. Presumably this is
        # intentional because display is delegated to `self.gui` through the `_view_name` and
        # `_model_id` properties below -- confirm before adding a `super().__init__()` call.
        self.data, self.labels, self.palette = self._merge_duplicate_classes(data, labels, palette)
        self.title = title
        self.samples = samples
        self.label_type = label_type
        self.theme = theme if isinstance(theme, themes.Theme) else themes.load_theme(theme)
        # Class values that are currently toggled off in the GUI.
        self.hide = []
        # Populated by `self._generate_sankey` every time the plot is (re)built.
        self.df = None
        self.plot = self._generate_figurewidget()
        self.gui = self._generate_gui()

    def _merge_duplicate_classes(self, data, labels, palette):
        """
        Combine classes with duplicated labels and colors into a single class.

        This allows classes that are distinct in the sampled image to be aggregated at the plotting
        stage, which is more efficient.

        Colors are looked up by class value, so `labels` and `palette` do not need to share the same
        key order and `palette` may contain extra classes.

        Returns
        -------
        tuple
            The data with merged class values replaced by the first value of their label-color
            pair, followed by the de-duplicated labels and palette keyed by that first value.

        Raises
        ------
        ValueError
            If a class in `labels` has no color in `palette`.
        """
        missing = [key for key in labels if key not in palette]
        if missing:
            raise ValueError(f"The `palette` is missing colors for the classes {missing}.")

        # A mapping of (color, label) to the first class value associated with that pair.
        first_value: dict[tuple[str, str], int] = {}
        remap: dict[int, int] = {}

        # If a label-color pair is repeated with different values, remap the values to the first
        # occurrence of that label-color pair.
        for key, label in labels.items():
            pair = (palette[key], label)
            if pair in first_value:
                remap[key] = first_value[pair]
            else:
                first_value[pair] = key

        # Grab the distinct color and label with their associated value
        palette = {value: color for (color, _), value in first_value.items()}
        labels = {value: label for (_, label), value in first_value.items()}

        # Apply the value remapping to merge classes
        data = data.replace(remap)

        return data, labels, palette

    def _get_sorted_classes(self) -> pd.Series:
        """
        Return all unique class values, sorted by the total number of observations.

        The size of a class is the mean of its `total` over the rows where it is the source plus the
        sum of `changed` over the rows where it is the target. Largest classes come first.
        """
        start_count = (
            self.df.loc[:, ["source", "total"]]
            .groupby("source")
            .mean()
            .reset_index()[["source", "total"]]
            .rename(columns={"source": "class", "total": "count"})
        )
        end_count = (
            self.df.loc[:, ["target", "changed"]]
            .groupby("target")
            .sum()
            .reset_index()[["target", "changed"]]
            .rename(columns={"target": "class", "changed": "count"})
        )

        total_count = pd.concat([start_count, end_count]).groupby("class").sum().reset_index()
        ranked = total_count.sort_values(by="count", ascending=False)

        return ranked["class"].reset_index(drop=True)

    def _get_active_classes(self):
        """
        Return all unique active, visible class values after filtering.

        The result is the array produced by `Series.unique` (not a `pd.Series`), so `x in result`
        tests membership among the values.
        """
        return self.df[["source", "target"]].melt()["value"].unique()

    def _generate_plot_parameters(self) -> SankeyParameters:
        """
        Generate Sankey plot parameters from a formatted, cleaned dataframe.

        Raises
        ------
        ValueError
            If `self.label_type` is not "class", "percent", "count", or a falsy value.
        """
        df = self.df.copy()

        # One node per unique (year, class) combination, whether it is a source or a target
        source_nodes = df[["source", "source_year"]].rename(
            columns={"source": "class", "source_year": "year"}
        )
        target_nodes = df[["target", "target_year"]].rename(
            columns={"target": "class", "target_year": "year"}
        )
        all_classes = pd.concat([source_nodes, target_nodes])
        all_classes = all_classes.drop_duplicates().reset_index(drop=True)
        all_classes["color"] = all_classes["class"].apply(lambda k: self.palette[k]).tolist()
        # Rows are unique and `sort=False`, so the IDs follow the row order of `all_classes`. The
        # links below refer to nodes by these IDs.
        all_classes["id"] = all_classes.groupby(["year", "class"], sort=False).ngroup()

        # Join the sequential class-year IDs to the dataframe
        for side in ("source", "target"):
            df[f"{side}_id"] = pd.merge(
                left=df,
                right=all_classes,
                how="left",
                left_on=[f"{side}_year", side],
                right_on=["year", "class"],
            )["id"]

        # Calculate the proportion of each class in each year. This uses the full sampled data, not
        # only the rows that remain after hiding classes.
        melted = self.data.melt(var_name="year")
        melted = melted.groupby(["year", "value"]).size().reset_index(name="count")
        melted["proportion_of_total"] = (
            melted.groupby("year")["count"]
            .transform(lambda x: x / x.sum())
            .apply(lambda x: f"{x:.0%}")
        )

        all_classes = all_classes.merge(
            melted, left_on=["year", "class"], right_on=["year", "value"]
        )

        if self.label_type == "class":
            all_classes["label"] = all_classes["class"].apply(lambda k: self.labels[k])
        elif self.label_type == "percent":
            all_classes["label"] = all_classes["proportion_of_total"]
        elif self.label_type == "count":
            all_classes["label"] = all_classes["count"]
        elif not self.label_type:
            all_classes["label"] = ""
        else:
            raise ValueError(
                "Invalid label_type. Choose from 'class', 'percent', 'count', or None."
            )

        return SankeyParameters(
            node_labels=all_classes.year,
            link_labels=df.link_label,
            node_palette=all_classes.color,
            link_palette=df.source_color,
            label=all_classes.label,
            source=df.source_id,
            target=df.target_id,
            value=df.changed,
        )

    def _generate_dataframe(self) -> pd.DataFrame:
        """
        Convert raw sampling data to a formatted dataframe.

        Each row of the result describes one transition between two consecutive columns of
        `self.data`: the source and target years and classes, the number of samples that made the
        transition (`changed`), the number of samples in the source class that year (`total`), the
        ratio of the two (`proportion`), the class labels and colors, and the hover text
        (`link_label`). Sample rows containing any class listed in `self.hide` are dropped first.
        """
        data = self.data.copy()

        if self.hide:
            # Drop every sample that belongs to a hidden class in any year
            data = data[~data.isin(self.hide).any(axis=1)]

        # Get all class-year transitions between consecutive columns, one per sample
        transitions = []
        for source_year, target_year in utils.pairwise(data.columns):
            transitions.extend(
                (source_year, target_year, source, target)
                for source, target in zip(data[source_year], data[target_year])
            )

        df = pd.DataFrame(transitions, columns=["source_year", "target_year", "source", "target"])

        # Count the unique combinations of all four fields
        df = (
            df.groupby(["source_year", "target_year", "source", "target"])
            .size()
            .reset_index()
            .rename(columns={0: "changed"})
        )

        # Count the total number of source samples in each year
        df["total"] = df.groupby(["source_year", "source"]).changed.transform("sum")

        # Calculate what percent of the source samples went into each target class
        df["proportion"] = df["changed"] / df["total"]

        # Join the class labels and colors to the class IDs
        df["source_label"] = df.source.apply(lambda k: self.labels[k])
        df["target_label"] = df.target.apply(lambda k: self.labels[k])
        df["source_color"] = df.source.apply(lambda k: self.palette[k])
        df["target_color"] = df.target.apply(lambda k: self.palette[k])

        def build_link_label(source, target, proportion, source_label, target_label) -> str:
            verb = "remained" if source == target else "became"
            pct = f"{proportion:.0%}"
            return f"<b>{pct}</b> of <b>{source_label}</b> {verb} <b>{target_label}</b>"

        # Describe the class changes. A plain comprehension also covers the case where every class
        # is hidden and the dataframe has no rows.
        df["link_label"] = [
            build_link_label(*row)
            for row in zip(df.source, df.target, df.proportion, df.source_label, df.target_label)
        ]

        return df

    @property
    def _view_name(self):
        """When the Sankey object is displayed by IPython, render the plot"""
        return self.gui._view_name

    @property
    def _model_id(self):
        """Expose the model ID of the GUI so that displaying this object displays the GUI"""
        return self.gui._model_id

    def update_layout(self, *args, **kwargs):
        """Pass layout changes to the plot."""
        # This is primarily kept for compatibility with geemap
        self.plot.update_layout(*args, **kwargs)

    def _generate_gui(self):
        """
        Build the GUI: the plot stacked above one toggle button per class, a reset button, and a
        button that opens the plot in a browser tab.
        """
        BUTTON_HEIGHT = "24px"
        BUTTON_WIDTH = "24px"

        unique_classes = self._get_sorted_classes()

        def update_plot():
            """Swap new data into the plot."""
            new_sankey = self._generate_sankey()
            self.plot.data[0].link = new_sankey.link
            self.plot.data[0].node = new_sankey.node

        def make_toggle(class_id):
            """Create the click handler that shows or hides one specific class."""

            # The class value is bound to the handler rather than looked up from the button's
            # tooltip, because two classes may share a label while having different colors.
            def toggle_button(button):
                button.toggle()

                if button.state:
                    if class_id in self.hide:
                        self.hide.remove(class_id)
                elif class_id not in self.hide:
                    self.hide.append(class_id)

                update_plot()

            return toggle_button

        buttons = []
        active_classes = self._get_active_classes()

        for class_id in unique_classes:
            button = utils.ColorToggleButton(
                tooltip=self.labels[class_id],
                on_color=self.palette[class_id],
                state=class_id in active_classes,
            )
            button.layout.width = BUTTON_WIDTH
            button.layout.height = BUTTON_HEIGHT
            button.on_click(make_toggle(class_id))
            buttons.append(button)

        def reset_plot(_):
            """Re-enable every class that is currently toggled off."""
            for button in buttons:
                if not button.state:
                    button.click()

        def open_in_browser(_):
            """Clear the explicit plot size and show the plot with the browser renderer."""
            self.plot.update_layout(width=None, height=None).show(renderer="browser")

        reset_button = widgets.Button(
            icon="refresh",
            tooltip="Reset plot",
            layout=widgets.Layout(height=BUTTON_HEIGHT, width=BUTTON_WIDTH, padding="0 0 0 3px"),
        )
        reset_button.on_click(reset_plot)

        open_button = widgets.Button(
            icon="external-link",
            tooltip="Open in new tab",
            layout=widgets.Layout(height=BUTTON_HEIGHT, width=BUTTON_WIDTH, padding="0 0 0 3px"),
        )
        open_button.on_click(open_in_browser)

        button_box = widgets.HBox([*buttons, widgets.Label("|"), reset_button, open_button])

        gui = widgets.VBox(
            [
                self.plot,
                widgets.VBox([button_box], layout=widgets.Layout(align_items="center")),
            ]
        )

        return gui

    def _generate_sankey(self) -> go.Sankey:
        """
        Generate the Sankey trace based on the currently visible classes.

        As a side effect, `self.df` is rebuilt from the sampled data and the hidden classes.
        """
        self.df = self._generate_dataframe()

        # Explicitly return an empty Sankey plot if all classes are hidden to avoid widget update
        # errors.
        if self.df.empty:
            return go.Sankey()

        params = self._generate_plot_parameters()

        node_kwargs = dict(
            customdata=params.node_labels,
            hovertemplate="<b>%{customdata}</b><extra></extra>",
            label=[f"<span style='{self.theme.label_style}'>{s}</span>" for s in params.label],
            color=params.node_palette,
        )
        link_kwargs = dict(
            source=params.source,
            target=params.target,
            value=params.value,
            color=params.link_palette,
            customdata=params.link_labels,
            hovertemplate="%{customdata} <extra></extra>",
        )

        # Theme settings are unpacked last so they take precedence over the defaults above
        return go.Sankey(
            arrangement="snap",
            node={**node_kwargs, **self.theme.node_kwargs},
            link={**link_kwargs, **self.theme.link_kwargs},
        )

    def _generate_figurewidget(self) -> go.FigureWidget:
        """Generate the FigureWidget that wraps the Sankey plot."""
        fig = go.FigureWidget(data=[self._generate_sankey()])
        fig.update_layout(
            title_text=f"<span style='{self.theme.title_style}'>{self.title}</span>"
            if self.title
            else None,
            font_size=16,
            title_x=0.5,
            paper_bgcolor="rgba(0, 0, 0, 0)",
        )

        return fig