Predicting SDG indicators using Deep Learning¶
Previous work¶
I will be working with the DHS dataset introduced as part of SustainBench, a suite of sustainability benchmarks targeting the SDGs. The dataset consists of geographical data points, each containing data on poverty, child mortality, women's educational attainment, women's BMI, water quality and access to sanitation, as well as one satellite image and multiple street-level images of each location. The socio-economic data is sourced from the DHS Program, a US-government-funded project which conducts household-level surveys in various countries. The daylight satellite images are sourced from the NASA satellites Landsat 5, 7 and 8; the nighttime satellite images are sourced from the US-operated DMSP and VIIRS-equipped satellites. The street-level images are sourced from Mapillary, an online service for sharing crowdsourced, geotagged photos. The documentation of the dataset is available on the SustainBench website.
The authors provide a simple k-NN baseline model for predicting each of the mentioned features given the mean night-light value of the corresponding satellite image. The performance of the baseline models varies from terrible (squared Pearson correlation $r^2 = 0.01$ for child mortality) to okay ($r^2 = 0.63$ for the asset index). The only other model trained on this dataset, according to the leaderboard on the aforementioned SustainBench website, uses street-level images to predict women's average BMI in India. It outperforms the baseline model with a squared correlation coefficient of $r^2 = 0.57$, compared to $r^2 = 0.42$ for the baseline model (Lee et al., 2021).
The general task of predicting socio-economic data from images using machine learning has been attempted outside of the SustainBench framework. Jean et al. (2016) use transfer learning with a pre-trained CNN to estimate economic well-being given satellite imagery. Hall et al. (2023) review the many recent attempts at predicting poverty from satellite images. Abitbol et al. (2020) achieved promising results when predicting socio-economic indicators from satellite images, yet note the lack of interpretability of the model. Burke et al. (2021) conclude that satellite-based approaches to quantifying indicators should not replace data collected on the ground, but should rather enhance it as part of a combined approach.
There is thus a considerable body of research in this field. Nonetheless, many questions remain regarding the ethics and reliability of using machine learning to infer such indicators.
SustainBench has not been widely adopted as a uniform benchmark task. Five out of six features within the DHS dataset have no projects competing with the baseline model in the leaderboard.
Exploratory data analysis¶
Overview of the data¶
We can get a brief overview of the data with the pandas function describe().
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import contextily as ctx
from shapely.geometry import Point
data = pd.read_csv('dhs_final_labels.csv')
data.describe()
year | lat | lon | n_asset | asset_index | n_water | water_index | n_sanitation | sanitation_index | under5_mort | n_under5_mort | women_edu | women_bmi | n_women_edu | n_women_bmi | cluster_id | adm1dhs | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 117644.000000 | 117644.000000 | 117644.000000 | 86936.000000 | 86936.000000 | 87938.000000 | 87938.000000 | 89271.000000 | 89271.000000 | 105582.000000 | 105582.000000 | 117062.000000 | 94866.000000 | 117062.000000 | 94866.000000 | 1.176440e+05 | 117644.000000 |
mean | 2010.964894 | 10.875259 | 29.263579 | 23.914558 | 0.174589 | 23.937615 | 3.763723 | 24.009242 | 3.086101 | 18.163958 | 18.345021 | 6.354988 | 23.296365 | 24.861065 | 18.778098 | 6.487475e+05 | 1053.883071 |
std | 5.301742 | 16.276815 | 54.533060 | 7.779958 | 1.848209 | 7.806283 | 1.123908 | 7.765566 | 1.282027 | 46.747577 | 12.160344 | 3.468181 | 2.946691 | 11.069406 | 9.611864 | 4.907284e+06 | 3037.622967 |
min | 1996.000000 | -30.588811 | -92.176053 | 5.000000 | -3.823164 | 5.000000 | 1.000000 | 5.000000 | 1.000000 | 0.000000 | 5.000000 | 0.000000 | 15.758333 | 5.000000 | 5.000000 | 1.000000e+00 | 0.000000 |
25% | 2007.000000 | 0.057170 | -0.137666 | 20.000000 | -1.451730 | 20.000000 | 3.037037 | 20.000000 | 2.037037 | 0.000000 | 10.000000 | 3.750000 | 21.148019 | 18.000000 | 11.000000 | 1.900000e+02 | 5.000000 |
50% | 2013.000000 | 11.982408 | 32.822764 | 22.000000 | 0.179011 | 22.000000 | 3.933333 | 22.000000 | 2.958333 | 0.000000 | 15.000000 | 6.333333 | 22.804907 | 24.000000 | 18.000000 | 4.660000e+02 | 13.000000 |
75% | 2015.000000 | 24.617686 | 77.497416 | 27.000000 | 1.842407 | 27.000000 | 4.826087 | 27.000000 | 4.272727 | 0.000000 | 23.000000 | 8.916667 | 24.898071 | 30.000000 | 24.000000 | 5.098325e+04 | 31.000000 |
max | 2019.000000 | 48.436031 | 126.842321 | 108.000000 | 3.607050 | 108.000000 | 5.000000 | 108.000000 | 5.000000 | 692.307692 | 166.000000 | 17.800000 | 48.111667 | 130.000000 | 118.000000 | 7.571352e+07 | 9999.000000 |
We can see that there are more than 100'000 data points, although not all data points contain values for all the columns. There are 63'329 rows without NaN values, which will become relevant later. According to the SustainBench paper (Yeh et al., 2021), the columns describe the following data:
- year: Year of data collection.
- lat: Latitude coordinate of the data point.
- lon: Longitude coordinate of the data point.
- n_asset: Number of asset-index-related observations.
- asset_index: Index or score representing the asset wealth of the average household. High is rich.
- n_water: Number of water-quality-index-related observations.
- water_index: Water quality index. High is good.
- n_sanitation: Number of sanitation-index-related observations.
- sanitation_index: Index describing the location's average access to sanitation facilities. High is good.
- under5_mort: Mortality rate of children under 5 years of age per 1000.
- n_under5_mort: Number of observations related to under-5 mortality.
- women_edu: Average years of educational attainment of women in the surveyed area.
- women_bmi: Average Body Mass Index (BMI) measurement of women.
- n_women_edu: Number of women surveyed regarding their education.
- n_women_bmi: Number of women surveyed regarding their BMI.
- cluster_id: Identifier for the cluster or group that the data point belongs to.
- adm1dhs: Administrative division identifier, specific to the Demographic and Health Surveys (DHS) system.
The columns we are interested in are asset_index, water_index, sanitation_index, under5_mort, women_edu and women_bmi, as they align with indicators of various SDGs.
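As a quick check of the row count mentioned above, the number of rows without any NaN value can be counted directly from the dataframe loaded earlier (a minimal snippet):
# Count the rows that contain no NaN value in any column
print(len(data.dropna()))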
Spatial representation of the dataset¶
The first plot shows the spatial distribution of the dataset, i.e., which parts of the world the dataset contains data on. We can observe that quite a large part of the world is represented, but coverage is still limited, with barely any data on countries in the Global North. Countries close to the equator are clearly overrepresented. This limits the comparisons we can make with regard to the features.
# Convert pandas dataframe to a geodataframe which is useful for plotting the location on a map
geo_data = gpd.GeoDataFrame(data, geometry=gpd.points_from_xy(data.lon, data.lat))
# The raw coordinates are WGS84 latitude/longitude (EPSG:4326)
geo_data.set_crs(epsg=4326, inplace=True)
# Convert to the Web Mercator projection (EPSG:3857), the standard projection for online map tiles
gdf = geo_data.to_crs(epsg=3857)
fig, ax = plt.subplots(figsize=(10, 10))
# Plot a point denoting the location on a map for each entry of the dataset
gdf.plot(ax=ax, color='black', alpha=0.1)
# Add world map backdrop
ctx.add_basemap(ax, source=ctx.providers.OpenStreetMap.Mapnik)
plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.title('Spatial representation of dataset')
# Such that we can see the whole world
ax.set_xlim(-2e7, 2e7)
ax.set_ylim(-1e7, 1e7)
plt.show()
Spatial representation of the features¶
Next, we are going to look at how the different features are spatially distributed. For this, I have generated one plot per feature of interest. We can observe some interesting phenomena in these plots:
- We can already infer correlation, as regions tend to sit at similar ends of the spectrum for multiple features.
- We can infer variance within countries, such as Colombia's sanitation index being significantly lower on the coast of the Pacific Ocean than in the Andes region of the country.
- Child mortality is unusually high in Nepal. This is possibly a measurement error.
- The European countries seem to be wealthy and the central African countries poor (with respect to the asset index). There is a very similar tendency with regard to the sanitation and water quality indices.
from matplotlib import colormaps
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import random
def plot_geo_data(data, key, title, cmap='random'):
"""
Creates a plot over a world map.
Args:
data (pandas dataframe): The dataframe containing the data to be plotted.
key (str): The key of the feature (i.e., the column in the dataframe) which is to be plotted.
title (str): For the title of the plot.
cmap (str, optional): Name of the matplotlib colormap to use; 'random' picks a random colormap.
"""
# Create a matplotlib figure and axis with a geo projection
fig, ax = plt.subplots(figsize=(15, 10), subplot_kw={'projection': ccrs.PlateCarree()})
# Add features to the map: coastlines, borders, and land
ax.add_feature(cfeature.COASTLINE)
ax.add_feature(cfeature.BORDERS, linestyle=':')
ax.add_feature(cfeature.LAND, edgecolor='black')
# Make sure the whole map is visible
ax.set_extent([-180, 180, -90, 90], crs=ccrs.PlateCarree())
cmap_input = cmap
cmap = random.choice(colormaps()) if cmap_input == 'random' else cmap_input
if cmap_input == 'random':
print(cmap)
scatter = ax.scatter(data['lon'], data['lat'], c=data[key], marker='o', cmap=cmap, transform=ccrs.PlateCarree(), alpha=0.4)
plt.colorbar(scatter, ax=ax, label=title, shrink=0.6)
plt.title(f'Spatial Distribution of {title}')
plt.show()
plot_geo_data(data, 'women_bmi', 'Women\'s Average BMI', cmap='inferno_r')
plot_geo_data(data, 'under5_mort', 'Child Mortality Rate', cmap='inferno_r')
plot_geo_data(data, 'asset_index', 'Asset Index', cmap='RdYlGn')
plot_geo_data(data, 'sanitation_index', 'Sanitation Index', cmap='RdYlGn')
plot_geo_data(data, 'water_index', 'Water Quality Index', cmap='RdYlGn')
plot_geo_data(data, 'women_edu', 'Women\'s Educational Attainment', cmap='RdYlGn')
# Dictionary mapping country codes to countries as mentioned on page 23 of the SustainBench paper
country_mapping = {
'AL': 'Albania',
'AM': 'Armenia',
'AO': 'Angola',
'BD': 'Bangladesh',
'BF': 'Burkina Faso',
'BJ': 'Benin',
'BO': 'Bolivia',
'BU': 'Burundi',
'CD': 'Congo Democratic Republic',
'CI': 'Cote d’Ivoire',
'CM': 'Cameroon',
'CO': 'Colombia',
'DR': 'Dominican Republic',
'EG': 'Egypt',
'ET': 'Ethiopia',
'GA': 'Gabon',
'GH': 'Ghana',
'GN': 'Guinea',
'GU': 'Guatemala',
'GY': 'Guyana',
'HN': 'Honduras',
'HT': 'Haiti',
'IA': 'India',
'ID': 'Indonesia',
'JO': 'Jordan',
'KE': 'Kenya',
'KH': 'Cambodia',
'KM': 'Comoros',
'KY': 'Kyrgyz Republic',
'LB': 'Liberia',
'LS': 'Lesotho',
'MA': 'Morocco',
'MB': 'Moldova',
'MD': 'Madagascar',
'ML': 'Mali',
'MM': 'Myanmar',
'MW': 'Malawi',
'MZ': 'Mozambique',
'NG': 'Nigeria',
'NI': 'Niger',
'NM': 'Namibia',
'NP': 'Nepal',
'PE': 'Peru',
'PH': 'Philippines',
'PK': 'Pakistan',
'RW': 'Rwanda',
'SL': 'Sierra Leone',
'SN': 'Senegal',
'SZ': 'Eswatini',
'TD': 'Chad',
'TG': 'Togo',
'TJ': 'Tajikistan',
'TZ': 'Tanzania',
'UG': 'Uganda',
'ZM': 'Zambia',
'ZW': 'Zimbabwe'
}
Mean value of features over countries¶
The next set of plots shows the average of each feature across the different countries. Interestingly, the variance of the features differs significantly: while women's BMI is fairly similar across countries, the asset index and women's educational attainment differ substantially between countries. Again, we can observe Nepal's unusually high child mortality rate. For future work, it would be interesting to use boxplots to visualise the variance of the features within each country; a sketch of such a plot follows after the bar charts.
column_names = ['water_index', 'sanitation_index', 'under5_mort', 'asset_index', 'women_edu', 'women_bmi']
features_format = ['Water Quality Index', 'Sanitation Index', 'Child Mortality per 1000', 'Asset Index', 'Women\'s educational attainment in years', 'Women\'s BMI']
# Create a 3x2 grid of subplots
fig, axs = plt.subplots(3, 2, figsize=(10, 25))
plt.subplots_adjust(wspace=0.6, hspace=0.2)
# Flatten the 3x2 grid for easy iteration
axs = axs.ravel()
for i, feature in enumerate(column_names):
average_by_nation = data.groupby('cname')[feature].mean()
average_by_nation = average_by_nation.dropna()
sorted_average_by_nation = average_by_nation.sort_values(ascending=False)
countries = []
for nation in sorted_average_by_nation.keys():
countries.append(country_mapping[nation])
ax = axs[i]
ax.set_title(f'Average {features_format[i]} by country')
ax.barh(sorted_average_by_nation.keys(), sorted_average_by_nation)
ax.set_xlabel(features_format[i])
ax.set_yticks(range(len(countries)), countries, fontsize=7)
# Print the mean and standard deviation over the means for all countries
print(f'{feature}: {np.round(np.mean(sorted_average_by_nation), 2)}±{np.round(np.std(sorted_average_by_nation), 2)}')
water_index: 3.71±0.62 sanitation_index: 3.01±0.79 under5_mort: 20.03±35.72 asset_index: 0.09±1.23 women_edu: 6.11±2.77 women_bmi: 23.54±1.94
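As mentioned above, boxplots would make the within-country variance visible. The following is only a sketch of what such a plot could look like for a single feature, grouping by the same cname column used for the bar charts:
# Sketch: distribution of the asset index within each country (not part of the main analysis)
fig, ax = plt.subplots(figsize=(6, 12))
data.boxplot(column='asset_index', by='cname', vert=False, ax=ax)
ax.set_xlabel('Asset Index')
ax.set_ylabel('Country code')
ax.set_title('Asset index distribution by country')
plt.suptitle('')  # remove the automatic "grouped by" suptitle added by pandas
plt.show()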
Correlation between features¶
We use the Pearson correlation coefficient $r \in [-1, 1]$ to represent the linear correlation between the features. Each pair of features has a coefficient $r$. $r = 1$ implies a perfect positive linear relationship, such as a pair of identical features; $r = -1$ denotes a perfect negative linear relationship; $r \approx 0$ implies little to no linear correlation. The coefficient is computed as
$r = \frac{\sum (X_i - \overline{X})(Y_i - \overline{Y})}{\sqrt{\sum (X_i - \overline{X})^2 \sum (Y_i - \overline{Y})^2}} = \frac{Cov(X, Y)}{\sigma_X \sigma_Y}$
where $\overline{X}$, $\overline{Y}$ are the mean values and $X_i$, $Y_i$ the individual samples of features $X$ and $Y$, and $\sigma_X$, $\sigma_Y$ denote the standard deviations of $X$ and $Y$. This formula shows that the Pearson correlation coefficient is essentially the covariance normalised by the standard deviations, which squashes it into $[-1, 1]$.
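As a sanity check, the definition above can be translated directly into a few lines of numpy and compared against numpy's built-in np.corrcoef (a minimal sketch using the dataframe loaded above):
import numpy as np

def pearson_r(x, y):
    """Pearson correlation coefficient, computed directly from the definition above."""
    x_centred = x - np.mean(x)
    y_centred = y - np.mean(y)
    return np.sum(x_centred * y_centred) / np.sqrt(np.sum(x_centred ** 2) * np.sum(y_centred ** 2))

# Cross-check on one pair of features; both printed values should agree
pair = data[['asset_index', 'sanitation_index']].dropna()
print(pearson_r(pair['asset_index'].to_numpy(), pair['sanitation_index'].to_numpy()))
print(np.corrcoef(pair['asset_index'], pair['sanitation_index'])[0, 1])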
The correlation matrix shows the correlation between each pair of features. Note that correlation is symmetric in its arguments, thus the correlation matrix is symmetric. We can observe that the asset index and access to sanitation are highly correlated, with a Pearson correlation coefficient of 0.89: the asset index tends to be lower where access to sanitation is lower, and vice-versa. Many other feature pairs are correlated to a noteworthy degree. Only the child mortality rate is consistently uncorrelated with all other features, suggesting that the other features carry little information on child mortality.
# Nice format of the feature labels
features_format = ['Water Quality', 'Access to\nSanitation', 'Child\nMortality', 'Assets', 'Women\'s\neducation', 'Women\'s BMI']
# Compute the correlation matrix using the pearson correlation coefficient
correlation_matrix = data[column_names].corr(method='pearson')
import seaborn as sns
# Display the correlation matrix using Seaborn
plt.figure(figsize=(10, 8))
ax = sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt=".2f")
ax.set_xticklabels(features_format, rotation=0)
ax.set_yticklabels(features_format, rotation=0)
plt.title('Correlation Matrix')
plt.show()
Visualising the covariance between features¶
Since the correlation coefficient is proportional to the covariance (i.e., $r \propto Cov(X, Y)$), we can visualise the correlations by plotting each pair of features against each other. We can see how certain pairs of features have a distinct linear shape, such as the highly correlated sanitation and asset indices. All the plots involving child mortality look noisy and random, reflecting its very low correlation with the other features. Note the perfectly correlated identical features along the diagonal.
# Plot covariance, as it displays many correlations intuitively
# Get separate dataframe without NaN values
data_no_nan = data.dropna(inplace=False)
fig, axes = plt.subplots(6, 6, figsize=(10, 10))
plt.subplots_adjust(wspace=0.4, hspace=0.4)
# Insert context regarding the numbers
features_format = ['Water\nQuality\nIndex', 'Access to\nSanitation', 'Child\nMortality\nper 1000', 'Asset Index', 'Women\'s\neducational\nAttainment\nin years', 'Women\'s\naverage\nBMI']
for i1, feat1 in enumerate(column_names):
for i2, feat2 in enumerate(column_names):
ax = axes[i1, i2]
if i2 == 0:
ax.set_ylabel(features_format[i1], labelpad=30, rotation=0, va='center')
if i1 == len(column_names) - 1:
ax.set_xlabel(features_format[i2], labelpad=10)
ax.scatter(data_no_nan[feat1], data_no_nan[feat2], s=0.5, c='black')
The strong correlation between certain features confirms our assumption that the SDGs and their indicators are strongly correlated and should not be tackled individually.
Unfortunately, we cannot perform a temporal analysis, as the data for each country was surveyed in exactly one year. Thus, a temporal analysis would not give us representative information on temporal development.
Satellite image data¶
The satellite image data consists of 117'644 images of shape $8 \times 255 \times 255$, stored as numpy arrays, one per row of the DHS dataset. This dataset is large, so I will refrain from applying augmentations to artificially increase its size. The first seven channels are "RED, GREEN, BLUE, NIR (Near Infrared), SWIR1 (Shortwave Infrared 1), SWIR2 (Shortwave Infrared 2), and TEMP1 (Thermal)". The eighth channel is a lower-resolution image which captures the spatial distribution of nighttime lighting. The first three channels thus form a standard RGB image; the fourth, fifth and sixth channels measure different infrared wavelengths (which relate, among other things, to vegetation and surface moisture); the seventh channel captures thermal radiation; and the eighth channel shows where lighting is visible at night. We can visualise a randomly selected image as follows:
import numpy as np
import matplotlib.pyplot as plt
# Load one of the images
full_image = np.load('dhs/satellite/AL-2008-5#/AL-2008-5#-00000433.npz')['x']
# Create a figure with subplots in a 2x3 grid
fig, axs = plt.subplots(2, 3, figsize=(15, 10)) # Adjust figsize to your needs
# Multiply by 3 to make the image brighter
# RGB representation
axs[0, 0].imshow(full_image[:3].transpose(1, 2, 0) * 3)
axs[0, 0].set_title('RGB representation')
axs[0, 0].axis('off') # Turn off axis
# NIR representation
axs[0, 1].imshow(full_image[3])
axs[0, 1].set_title('NIR representation')
axs[0, 1].axis('off')
# SWIR1 representation
axs[0, 2].imshow(full_image[4])
axs[0, 2].set_title('SWIR1 representation')
axs[0, 2].axis('off')
# SWIR2 representation
axs[1, 0].imshow(full_image[5])
axs[1, 0].set_title('SWIR2 representation')
axs[1, 0].axis('off')
# TEMP1 representation
axs[1, 1].imshow(full_image[6])
axs[1, 1].set_title('TEMP1 representation')
axs[1, 1].axis('off')
# Nightlight representation
axs[1, 2].imshow(full_image[7])
axs[1, 2].set_title('Nightlight representation')
axs[1, 2].axis('off')
plt.tight_layout()
plt.show()
Street images¶
The street-level images consist of 1'073'455 .jpeg images of various sizes with three RGB channels. The mapping from street-level images to DHS data points (each of which has exactly one satellite image) is surjective, with at least one, but usually multiple, street-level images per data point. The images were sourced from Mapillary, a service for sharing geotagged images. We can sample some of the images and visualise them to build intuition. As with the satellite data, the number of images is large, so I will not use data augmentation to artificially increase the size of the dataset.
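To get a feel for how many street-level images map to each DHS location, we can count the images per DHSID_EA in the Mapillary metadata (a small sketch; it assumes the directory dhs/mapillary/metadata only contains the metadata csv files, as in the preprocessing code further below):
import os
import pandas as pd

# Number of street-level images per DHS location
meta_dir = 'dhs/mapillary/metadata'
metadata = pd.concat([pd.read_csv(os.path.join(meta_dir, f)) for f in os.listdir(meta_dir)], ignore_index=True)
print(metadata.groupby('DHSID_EA').size().describe())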
import matplotlib.image as mpimg
# Create a figure with subplots in a 1x3 grid
fig, axs = plt.subplots(1, 3, figsize=(20, 10))
# Some random image paths
imgs = ['dhs/mapillary/imagery/CO/CO-2010-6#-00001349/1972919636196838.jpeg', 'dhs/mapillary/imagery/MM/MM-2016-7#-00000019/1244428652656362.jpeg', 'dhs/mapillary/imagery/RW/RW-2015-7#-00000028/960400064499984.jpeg']
# Display the image in each subplot
for i, ax in enumerate(axs):
img = mpimg.imread(imgs[i])
ax.imshow(img)
ax.axis('off')
plt.show()
Task and evaluation¶
Plan of action¶
The task of predicting features given other features would not be particularly challenging due to the high degree of correlation between most features. I want to perform the more challenging task of using street-level and satellite images to predict the features.
SustainBench provides a baseline model which predicts the six features given the mean nightlight value of each satellite image. I want to extend this and use both the satellite images and the street-level images to predict the features, with the goal of outperforming the baseline model. To do this, I am going to make use of transfer learning on a pre-trained ResNet34, a residual convolutional neural network. A residual neural network is a type of neural network in which certain connections skip layers in order to mitigate the vanishing gradient problem; this allows for deeper networks than non-residual architectures. Transfer learning, i.e., fine-tuning a model pre-trained on a large image dataset, should allow for faster convergence. ResNet34 has 34 layers and has been shown to perform well on image tasks, so I assume it can also be used for the regression task of predicting the feature values.
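To make the idea of skip connections concrete, a minimal residual block could look as follows. This is only an illustrative sketch, not the exact basic block used inside torchvision's ResNet34:
import torch
import torch.nn as nn

class BasicResidualBlock(nn.Module):
    """Minimal residual block: the input is added back onto the output of two conv layers."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # The skip connection: the gradient can flow directly through the addition,
        # which mitigates the vanishing gradient problem in deep networks
        return torch.relu(out + x)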
I will train three ResNet34 networks with three different inputs:
- The satellite images as inputs (8 channels).
- The mapillary street-level images (3 channels).
- Both (8 + 3 = 11 channels).
These will be compared with regard to their training, validation and test losses. I will then compute the squared Pearson correlation coefficient of the predictions on the test dataset and compare it to the results of the baseline model. I will also compare the models' mean squared error losses to those of random predictions. If a model does not outperform random guessing, it is useless and will not be considered further.
Finally, I want to visualise the features which the network has learned. Deep learning is a useful subset of machine learning for many reasons, one of them being that it performs feature extraction automatically. We are, at the end of the day, interested in the type of features the model learns to base its predictions on. This may allow for a deeper understanding of the correlation between visual and socio-economic features, but could also be useful for interpreting how the model makes its predictions, allowing for more transparent AI.
In an applied setting, the idea behind the models that use street images is to employ them as part of a meta-learning approach: each location can have $n > 1$ street images, so we would feed each of the $n$ inputs for a location into the model and use the mean of all predictions as the final prediction. When using only the satellite images, the mapping between images and labels is bijective, so this aggregation is not needed.
Hardware¶
A model to be used in an applied setting should not only perform well, but also be deployable on feasible hardware. This model will be trained on the following consumer-grade hardware, implying the possibility of relatively inexpensive (< £2500) and uncomplicated deployment:
- Intel i7 13700K CPU
- NVIDIA RTX 4080 GPU (16GB VRAM)
- 32GB RAM with 32GB swap
- 2TB M.2 SSD
Design and build an ML system¶
I describe and justify my choice of model in the previous section. I will start this section by preprocessing the data.
import os
import copy
from PIL import Image
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Dataset
import torchvision.models as models
from torchvision.models import ResNet34_Weights
Preprocess data¶
I need to preprocess the data for several reasons. Firstly, to get rid of NaN values: I drop all rows which contain at least one NaN value, as we want to avoid NaN values being propagated back through the network. I can do this because the dataset is very large and remains large (> 60'000 rows) after removing those rows. Secondly, to join the label data with the image metadata, which significantly speeds up data loading later on. I split the dataset into mutually exclusive test and training/validation sets as part of the preprocessing, so as to guarantee that they do not leak into each other. Each processed, valid data point is put into the test set with probability $p = 0.15$ and into the training/validation set with probability $1 - p = 0.85$. I do not execute these functions within this notebook, as they take a very long time to complete due to the large amount of data; instead, I executed them as detached processes overnight.
column_names = ['water_index', 'sanitation_index', 'under5_mort', 'asset_index', 'women_edu', 'women_bmi']
test_prob = 0.15
def preprocess_satellite_images():
data = pd.read_csv('dhs_final_labels.csv')[column_names + ['DHSID_EA']]
available_ids = set(data.dropna()['DHSID_EA'])  # use a set for fast membership tests
root_path = 'dhs/satellite'
drop_indices = []
image_id_path_pairs = []
for i, dir in enumerate(os.listdir(root_path)):
if not os.path.isdir(os.path.join(root_path, dir)):
continue
for image_file_name in os.listdir(os.path.join(root_path, dir)):
# First entry of the pair is the name of the image as it appears in the csv
# The second entry is the relative path to the file from the base directory
image_id_path_pairs.append(
(os.path.splitext(image_file_name)[0], os.path.join(root_path, dir, image_file_name)))
print('start iteration')
for i, (id, path) in enumerate(image_id_path_pairs):
if id not in available_ids:
drop_indices.append(i)
if i % 10000 == 0:
print(f'{i}/{len(image_id_path_pairs)} rows checked')
print(f'{len(drop_indices)}/{len(image_id_path_pairs)}')
clean_images_test = []
clean_images_train = []
# Split probabilistically
for i in range(len(image_id_path_pairs)):
if i not in drop_indices:
if random.random() < test_prob:
clean_images_test.append(image_id_path_pairs[i])
else:
clean_images_train.append(image_id_path_pairs[i])
# Assert that the sets are mutually exclusive
assert(len(clean_images_test) + len(clean_images_train) == len(image_id_path_pairs) - len(drop_indices))
# Save the preprocessed data as a list of pairs (satellite_image_id (string), path_to_image (string))
np.save('clear_satellite_image_id_path_pairs_test.npy', clean_images_test)
np.save('clear_satellite_image_id_path_pairs_train.npy', clean_images_train)
I convert the metadata csv file containing information on the street images into two csv files (one for training/validation, one for testing) which also contain the corresponding satellite image path and labels; this speeds up data loading later on.
def preprocess_street_images():
available_ids = set(data.dropna()['DHSID_EA'])  # use a set for fast membership tests
root_path = 'dhs/mapillary'
# Combine all the metadata to one single dataframe for easy lookup
metadatas = []
for csv_metadata in os.listdir(os.path.join(root_path, 'metadata')):
metadatas.append(pd.read_csv(os.path.join(root_path, 'metadata', csv_metadata)))
metadata = pd.concat(metadatas, ignore_index=True)
drop_indices = []
# Remove all the metadata which does not correspond to a legal row of DHS data
print('starting first part')
for index, row in metadata.iterrows():
row_id = row['DHSID_EA']
if row_id not in available_ids:
drop_indices.append(index)
if index % 100000 == 0:
print(f'{index}/{len(metadata)} rows checked')
cleaned_metadata = metadata.drop(drop_indices)
print(f'{len(drop_indices)} values dropped')
test_indexes = []
print('finished first part')
# Add the features to the metadata for easier data loading later on and split into train/val and test data
print('starting the second part')
for index, row in cleaned_metadata.iterrows():
if index % 100000 == 0:
print(f'{index}/{len(cleaned_metadata)} rows checked')
row_id = row['DHSID_EA']
folder = row_id[:row_id.rindex('-')]
satellite_image_path = os.path.join('dhs/satellite', folder, row_id+'.npz')
cleaned_metadata.at[index, 'satellite_image_path'] = satellite_image_path
label = []
if len(data[data['DHSID_EA'] == row_id]) == 0:
# Error, shouldn't happen
print('no mapping')
elif len(data[data['DHSID_EA'] == row_id]) == 2:
# Error, shouldn't happen
print('why are there two mappings?')
else:
# Extract the label from the DHS dataset
label = data[data['DHSID_EA'] == row_id].iloc[0][column_names].to_numpy()
# Add label to combined dataset
for column_index, column in enumerate(column_names):
cleaned_metadata.at[index, column] = label[column_index]
# Add to test set with 15% probability
if random.random() < test_prob:
test_indexes.append(index)
print(f'test set size: {len(test_indexes)}')
# Create new dataframes from the evaluated indices
cleaned_metadata_test = cleaned_metadata.loc[test_indexes].copy()
cleaned_metadata_train = cleaned_metadata.drop(test_indexes)
# Assert mutual exclusiveness
assert(len(cleaned_metadata_train) + len(cleaned_metadata_test) == len(cleaned_metadata))
# Convert dataframes to csv files
cleaned_metadata_test.to_csv('metadata_clean_test.csv', index=False)
cleaned_metadata_train.to_csv('metadata_clean_train.csv', index=False)
print('done!')
Defining the model¶
I define the model as a ResNet34, importing its architecture and pretrained weights. I need to adapt the architecture, as the ResNet is designed for classification, yet we need it for a regression task. Also, the different datasets have different numbers of input channels, so the first convolutional layer has to be adapted dynamically. I also define the hooks used for plotting the learned features later.
class ResNetRegressor(nn.Module):
def __init__(self, input_size=8, output_size=6):
super(ResNetRegressor, self).__init__()
# Load a pre-trained ResNet model
self.resnet = models.resnet34(weights=ResNet34_Weights.DEFAULT)
# Modify the first convolution layer to accept a specified number of input channels
self.resnet.conv1 = nn.Conv2d(input_size, 64, kernel_size=3, stride=2, padding=3, bias=False)
# Replace the last fully connected layer for regression of six values (output_size=6)
# ResNet34 uses 512 features before the final FC layer
num_features = self.resnet.fc.in_features
self.resnet.fc = nn.Linear(num_features, output_size)
def forward(self, x):
return self.resnet(x)
# Register the hooks for learned feature visualisation later on
def register_hook_to_last_layer(self):
activations = []
def get_activation(name):
def hook(model, input, output):
activations.append(output.detach())
return hook
# Attach the hook to the last convolutional layer of the last block
self.resnet.layer4[-1].conv2.register_forward_hook(get_activation('last_conv'))
return activations
def register_hook_to_first_layer(self):
activations = []
def get_activation(name):
def hook(model, input, output):
activations.append(output.detach())
return hook
# Attach the hook to the last convolutional layer of the first block
self.resnet.layer1[-1].conv1.register_forward_hook(get_activation('last_conv_layer1'))
return activations
def register_hook_to_second_layer(self):
activations = []
def get_activation(name):
def hook(model, input, output):
activations.append(output.detach())
return hook
# Attach the hook to the last convolutional layer of the second block
self.resnet.layer2[-1].conv2.register_forward_hook(get_activation('last_conv_layer2'))
return activations
Defining the datasets¶
I now define the datasets, which inherit from the PyTorch Dataset class. This format allows for defining data loaders down the line, which provide very convenient data loading with parallel processing, shuffling, etc. I define one dataset per model, thus a total of three datasets, as they need to load different data.
class DHSSatelliteDataset(Dataset):
def __init__(self, root_path, csv_file, column_names, transform=None, test=False):
"""
Args:
root_path (string): Path to the satellite image data.
csv_file (string): Path to the dhs data csv file.
column_names (list): List of relevant columns in the dhs dataset.
transform (callable, optional): Optional transform to be applied on a sample.
test (bool, optional): True if this is a test set.
"""
self.column_names = column_names
self.data = pd.read_csv(csv_file)[self.column_names + ['DHSID_EA']]
self.transform = transform
self.root_path = root_path
# Load test or train/val data depending on the boolean value test
if test:
self.image_id_path_pairs = np.load('clear_satellite_image_id_path_pairs_test.npy')
else:
self.image_id_path_pairs = np.load('clear_satellite_image_id_path_pairs_train.npy')
def __len__(self):
# The length is simply the number of images
return len(self.image_id_path_pairs)
def __getitem__(self, idx):
image_id, image_path = self.image_id_path_pairs[idx]
satellite_image = np.load(image_path)['x']
# The eight channels correspond to different measurements performed by the satellite
assert(satellite_image.shape == (8, 255, 255))
label = -1
# This is really just for debugging
if len(self.data[self.data['DHSID_EA'] == image_id]) == 0:
print('no mapping')
return
elif len(self.data[self.data['DHSID_EA'] == image_id]) == 2:
print('why are there two mappings?')
return
else:
label = self.data[self.data['DHSID_EA'] == image_id].iloc[0][self.column_names].to_numpy()
if self.transform:
satellite_image = self.transform(satellite_image)
new_label = []
# This is a bodge, but it works
for l in label:
new_label.append(l)
return satellite_image.permute(1, 0, 2).float(), torch.tensor(new_label).float()
class DHSCombinedDataset(Dataset):
def __init__(self, root_path, csv_file, satellite_transform=None, street_transform=None, test=False):
"""
Args:
root_path (string): Path to the satellite image data.
csv_file (string): Path to the dhs data csv file.
satellite_transform (callable, optional): Optional transform to be applied on the satellite data.
street_transform (callable, optional): Optional transform to be applied on the street data.
test (bool, optional): True if this is a test set.
"""
self.data = pd.read_csv(csv_file)[column_names + ['DHSID_EA', 'cname']]
self.data = self.data.dropna()
self.satellite_transform = satellite_transform
self.street_transform = street_transform
self.root_path = root_path
if test:
self.metadata = pd.read_csv('metadata_clean_test.csv')
else:
self.metadata = pd.read_csv('metadata_clean_train.csv')
def __len__(self):
# The length is simply the number of images
return len(self.metadata)
def __getitem__(self, idx):
metadata_row = self.metadata.iloc[idx]
street_image = Image.open(os.path.join(self.root_path, 'imagery', metadata_row['img_path']))
satellite_image = np.load(metadata_row['satellite_image_path'])['x']
label = metadata_row[column_names].to_numpy()
if self.street_transform:
street_image_transformed = self.street_transform(street_image)
if self.satellite_transform:
satellite_image_transformed = self.satellite_transform(satellite_image).permute(1, 0, 2)
final_image = torch.cat((street_image_transformed, satellite_image_transformed), dim=0)
new_label = []
for l in label:
new_label.append(l)
new_tensor_label = torch.tensor(new_label)
return final_image.float(), new_tensor_label.float()
class DHSStreetDataset(Dataset):
def __init__(self, root_path, csv_file, transform=None, test=False):
"""
Args:
root_path (string): Path to the satellite image data.
csv_file (string): Path to the dhs data csv file.
transform (callable, optional): Optional transform to be applied on the street image.
test (bool, optional): True if this is a test set.
"""
self.data = pd.read_csv(csv_file)[column_names + ['DHSID_EA', 'cname']]
self.data = self.data.dropna()
self.transform = transform
self.root_path = root_path
if test:
self.metadata = pd.read_csv('metadata_clean_test.csv')
else:
self.metadata = pd.read_csv('metadata_clean_train.csv')
def __len__(self):
# The length is simply the number of images
return len(self.metadata)
def __getitem__(self, idx):
metadata_row = self.metadata.iloc[idx]
image = Image.open(os.path.join(self.root_path, 'imagery', metadata_row['img_path']))
id = metadata_row['DHSID_EA']
if len(self.data[self.data['DHSID_EA'] == id]) == 0:
print('no mapping')
elif len(self.data[self.data['DHSID_EA'] == id]) == 2:
print('why are there two mappings?')
else:
label = self.data[self.data['DHSID_EA'] == id].iloc[0][column_names].to_numpy()
if self.transform:
image = self.transform(image)
new_label = []
for l in label:
new_label.append(l)
new_tensor_label = torch.tensor(new_label)
return image.float(), new_tensor_label.float()
Defining the train function¶
Here I define the hyperparameters as well as the function for training the model. The function loads the training and validation data, creates the data loaders and runs the training loop. The data loader automatically loads the correct batch; each batch is fed through the network, the loss is computed and propagated back through the network in order to tune the weights. The training and validation losses per epoch are tracked in lists which are returned together with the best model. The best model is not necessarily the final model, but rather the model which performs best on the validation data; it is also saved to a file in the models folder.
# Hyperparameters
num_epochs = 50
batch_size = 256
learning_rate = 0.000005
train_split = 0.8  # fraction of the non-test data used for training (the rest is validation)
# Use the GPU if it's available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.MSELoss()
def train_model(model, dataset, name):
training_losses = []
validation_losses = []
# Split the dataset into training, validation, and test sets
train_size = int(train_split * len(dataset))
validation_size = len(dataset) - train_size
train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size])
num_workers = 8
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
# Initialize the model, move it to the GPU
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
print(f'Total training data: {len(train_loader.dataset)}')
lowest_val_loss = math.inf
best_model = None
# Main training loop
for epoch in range(num_epochs):
model.train()
total_epoch_training_loss = 0
for inputs, targets in train_loader:
# PyTorch makes training very easy
optimizer.zero_grad()
inputs = inputs.to(device)
targets = targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
total_epoch_training_loss += loss
loss.backward()
optimizer.step()
normalised_epoch_training_loss = total_epoch_training_loss / len(train_dataset)
training_losses.append(normalised_epoch_training_loss)
# Validation loop
model.eval()
total_epoch_val_loss = 0
with torch.no_grad():
for inputs, targets in validation_loader:
inputs = inputs.to(device)
targets = targets.to(device)
outputs = model(inputs)
total_epoch_val_loss += criterion(outputs, targets)
normalised_epoch_val_loss = total_epoch_val_loss / len(validation_dataset)
validation_losses.append(normalised_epoch_val_loss)
# If the loss on the validation data is lower than the previously lowest loss, a new
# best model has been found
if total_epoch_val_loss < lowest_val_loss:
lowest_val_loss = total_epoch_val_loss
torch.save(model.state_dict(), f'models/{name}_best.pth')
best_model = copy.deepcopy(model)  # deep copy so later epochs do not overwrite the best weights
# Print the normalised training and validation losses for this epoch
print(f'Epoch {epoch+1}/{num_epochs}, Normalised Training Loss: {normalised_epoch_training_loss:.2f} Normalised Validation Loss: {normalised_epoch_val_loss:.2f}')
return best_model, training_losses, validation_losses
Defining the transforms, models and datasets¶
The transforms take the raw inputs and convert them to the format the model requires. These transforms differ for the street and satellite images due to their differing sizes. Both end up as 255x255 normalised tensors. I then define the models and datasets.
# To make this reproducible
torch.manual_seed(42)
street_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((255, 255), antialias=False),
transforms.Normalize((0.5,), (0.5,)),
])
satellite_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)) # Normalize the images
])
satellite_dataset = DHSSatelliteDataset('dhs/satellite', 'dhs_final_labels.csv', column_names, satellite_transform, test=False)
satellite_model = ResNetRegressor(input_size=8)
street_dataset = DHSStreetDataset('dhs/mapillary', 'dhs_final_labels.csv', street_transform, test=False)
street_model = ResNetRegressor(input_size=3)
combined_dataset = DHSCombinedDataset('dhs/mapillary', 'dhs_final_labels.csv', satellite_transform=satellite_transform, street_transform=street_transform, test=False)
combined_model = ResNetRegressor(input_size=11)
Training the model¶
I am not actually training the models in this notebook, as Jupyter is finicky and training 50 epochs takes up to a day. Instead, I trained the models as a script within a detached process on my desktop computer. The next cell would execute the training, but I have commented out nearly the entire cell so I can run the notebook without having to wait more than a day for it to finish. I save the losses to pickle files to import again for the evaluation.
import pickle
#satellite_model, satellite_train_losses, satellite_val_losses = train_model(satellite_model, satellite_dataset, 'satellite')
#with open('satellite_train_losses.pkl', 'wb') as file:
# pickle.dump(satellite_train_losses, file)
#with open('satellite_val_losses.pkl', 'wb') as file:
# pickle.dump(satellite_val_losses, file)
#street_model, street_train_losses, street_val_losses = train_model(street_model, street_dataset, 'street')
#with open('street_train_losses.pkl', 'wb') as file:
# pickle.dump(street_train_losses, file)
#with open('street_val_losses.pkl', 'wb') as file:
# pickle.dump(street_val_losses, file)
#combined_model, combined_train_losses, combined_val_losses = train_model(combined_model, combined_dataset, 'combined')
#with open('combined_train_losses.pkl', 'wb') as file:
# pickle.dump(combined_train_losses, file)
#with open('combined_val_losses.pkl', 'wb') as file:
# pickle.dump(combined_val_losses, file)
Experimental analysis (performance & scalability)¶
Tuning the hyperparameters¶
I experimented with different models in a PyCharm environment, which is a lot easier to work with than a Jupyter notebook. I tested various CNN architectures and established that ResNet34 performs best given the limitations of the RTX 4080 GPU I run it on. A larger ResNet may further increase performance. The other hyperparameters I found to perform well were
- The Adam optimiser, which is a popular choice. It uses adaptive learning rates and includes momentum, i.e., information on past gradients, which helps navigate noisy gradients.
- A learning rate of 0.000005, which is surprisingly small.
- A batch size of 256, which I would have preferred to be higher, yet 16GB of VRAM did not allow for more. A larger batch size gives more accurate estimates of the gradient, yet can lead to overfitting on small datasets; our datasets are very large, so a larger batch size would have been desirable.
- The mean squared error as a loss function. The MSE is a good choice for regression tasks and due to the errors being squared, it penalises outliers more strongly. I also tested the mean absolute error but it did not perform as well.
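As a small illustration of the last point (the numbers are made up and not from the experiments): a single large prediction error dominates the MSE far more than the MAE.
import torch
import torch.nn as nn

# One outlier prediction (10 instead of 4); the other predictions are exact
pred = torch.tensor([1.0, 2.0, 3.0, 10.0])
target = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(nn.MSELoss()(pred, target))  # mean squared error: (0 + 0 + 0 + 36) / 4 = 9.0
print(nn.L1Loss()(pred, target))   # mean absolute error: (0 + 0 + 0 + 6) / 4 = 1.5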
Plotting the loss over the test and validation data¶
First of all, I will plot the training and validation loss over the training epochs. The training and validation data are mutually exclusive, randomly split into 80% training and 20% validation of the non-test data. We expect the training loss to shrink consistently, as the model should fit the training data increasingly well with each epoch. Plotting the training loss is therefore more of a sanity check than anything else, as only the loss on the validation data, which the model has not trained on, indicates performance on new data. We want a model which performs maximally well on new data, regardless of how well it performs on the training data: if the training loss is low, yet the validation loss is high, we have overfit the model. In the next cell, I plot the training and validation losses of each model over the training epochs, which provides an intuitive visualisation of the training process and the potential quality of each model.
# Load the best models from the file saved during training
satellite_model.load_state_dict(torch.load('models/satellite_best.pth'))
street_model.load_state_dict(torch.load('models/street_best.pth'))
combined_model.load_state_dict(torch.load('models/combined_best.pth'))
# Load the losses for plotting
with open('satellite_train_losses.pkl', 'rb') as file:
satellite_train_losses = [x.item() for x in pickle.load(file)]
with open('satellite_val_losses.pkl', 'rb') as file:
satellite_validation_losses = [x.item() for x in pickle.load(file)]
with open('street_train_losses.pkl', 'rb') as file:
street_train_losses = [x.item() for x in pickle.load(file)]
with open('street_val_losses.pkl', 'rb') as file:
street_validation_losses = [x.item() for x in pickle.load(file)]
with open('combined_train_losses.pkl', 'rb') as file:
combined_train_losses = [x.item() for x in pickle.load(file)]
with open('combined_val_losses.pkl', 'rb') as file:
combined_validation_losses = [x.item() for x in pickle.load(file)]
# Plotting
plt.figure(figsize=(12, 7))
plt.plot(list(range(0, len(satellite_train_losses))), satellite_train_losses, color='b', linestyle='--', label='Satellite Train')
plt.plot(list(range(0, len(satellite_validation_losses))), satellite_validation_losses, color='b', label='Satellite Validation')
plt.plot(list(range(0, len(street_train_losses))), street_train_losses, color='r', linestyle='--', label='Street Train')
plt.plot(list(range(0, len(street_validation_losses))), street_validation_losses, color='r', label='Street Validation')
plt.plot(list(range(0, len(combined_train_losses))), combined_train_losses, color='g', linestyle='--', label='Satellite + Street Train')
plt.plot(list(range(0, len(combined_validation_losses))), combined_validation_losses, color='g', label='Satellite + Street Validation')
plt.title('Normalised Loss per Epoch')
plt.xlabel('Epoch')
plt.ylabel('Normalised Loss')
plt.grid(True)
plt.legend()
plt.show()
We observe the following:
- The training error decreases consistently for all three models, thus we passed the sanity check.
- The model using both the satellite and the street images outperforms the other two models on the validation data by a significant margin.
- The model trained on the street images starts overfitting the training data after about 10 epochs: thereafter the training loss keeps decreasing, yet the validation loss starts rising again.
This is a satisfying result and it seems like the combined model has managed to learn to fit the training data while still managing to generalise on the validation data.
Performance on the test data¶
Reasoning about the quality of a model is incomplete until we have evaluated its performance on the test data. The validation loss gives us information on how well the model generalises, yet I have used the validation data to tune the hyperparameters; thus, there is a possibility that I overfit the model to the validation data. The test data is data on which neither I nor the model has any information. Thus, if the model performs well on the test data, we can infer that it will perform well in practice.
The evaluation metrics are, at first, the mean squared error loss and the squared Pearson correlation coefficient. Both metrics indicate how the predictions relate to the ground truth. The squared Pearson correlation coefficient, already described in the exploratory data analysis, measures the degree of correlation (non-negative, as it is squared) between the predictions and the ground truth test data. A high correlation coefficient is desirable.
from scipy.stats import pearsonr
print(f'device: {device}')
satellite_dataset_test = DHSSatelliteDataset('dhs/satellite', 'dhs_final_labels.csv', column_names, satellite_transform, test=True)
street_dataset_test = DHSStreetDataset('dhs/mapillary', 'dhs_final_labels.csv', street_transform, test=True)
combined_dataset_test = DHSCombinedDataset('dhs/mapillary', 'dhs_final_labels.csv', satellite_transform=satellite_transform, street_transform=street_transform, test=True)
# Test each of the three models
for name, test_dataset, model in [('satellite', satellite_dataset_test, satellite_model), ('street', street_dataset_test, street_model), ('combined', combined_dataset_test, combined_model)]:
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False, num_workers=6)
model.eval()
model.to(device)
total_test_loss = 0
predictions = []
actuals = []
# We only need to compute the gradient when training
# When testing, we can disable the gradient which saves time and computational resources
with torch.no_grad():
for inputs, targets in test_loader:
inputs = inputs.to(device)
targets = targets.to(device)
outputs = model(inputs)
total_test_loss += criterion(outputs, targets)
predictions.extend(outputs.cpu().numpy())
actuals.extend(targets.cpu().numpy())
# Compute Pearson correlation coefficient for each feature
for i in range(6):
corr, _ = pearsonr([p[i] for p in predictions], [a[i] for a in actuals])
print(f'Squared Pearson correlation coefficient of {name} model for {column_names[i]}: r^2 = {round(corr**2, 5)}')
print(f'\n{name} normalised test loss: {round(total_test_loss.detach().item() / len(test_dataset), 5)}\n')
device: cuda
Squared Pearson correlation coefficient of satellite model for water_index: r^2 = 0.18964
Squared Pearson correlation coefficient of satellite model for sanitation_index: r^2 = 0.35702
Squared Pearson correlation coefficient of satellite model for under5_mort: r^2 = 0.112
Squared Pearson correlation coefficient of satellite model for asset_index: r^2 = 0.36497
Squared Pearson correlation coefficient of satellite model for women_edu: r^2 = 0.26019
Squared Pearson correlation coefficient of satellite model for women_bmi: r^2 = 0.02987

satellite normalised test loss: 0.74031

Squared Pearson correlation coefficient of street model for water_index: r^2 = 0.2137
Squared Pearson correlation coefficient of street model for sanitation_index: r^2 = 0.26379
Squared Pearson correlation coefficient of street model for under5_mort: r^2 = 0.10356
Squared Pearson correlation coefficient of street model for asset_index: r^2 = 0.31326
Squared Pearson correlation coefficient of street model for women_edu: r^2 = 0.382
Squared Pearson correlation coefficient of street model for women_bmi: r^2 = 1e-05

street normalised test loss: 0.44135

Squared Pearson correlation coefficient of combined model for water_index: r^2 = 0.81794
Squared Pearson correlation coefficient of combined model for sanitation_index: r^2 = 0.85588
Squared Pearson correlation coefficient of combined model for under5_mort: r^2 = 0.99724
Squared Pearson correlation coefficient of combined model for asset_index: r^2 = 0.95624
Squared Pearson correlation coefficient of combined model for women_edu: r^2 = 0.97702
Squared Pearson correlation coefficient of combined model for women_bmi: r^2 = 0.95991

combined normalised test loss: 0.00162
We immediately observe the very high correlation of the combined model's predictions. I did not expect the results to be this good; when I first encountered them, I suspected they were too good, suggesting a data leak. I therefore re-implemented my datasets so as to create two completely separate lists of cleaned and preprocessed data, one for training and validation, the other for testing, such that the data is guaranteed to be separate. After re-training, the results remained as good as before.
The other two models on their own are unable to produce strong correlations. We can observe that, as with the validation data, the combined model significantly outperforms the other two with regard to the test loss.
The baseline model developed by the SustainBench authors achieves the following squared Pearson correlation coefficients between the predictions and the true data:
- Water Quality Index: $r^2 = 0.4$
- Sanitation Index: $r^2 = 0.36$
- Child Mortality: $r^2 = 0.01$
- Asset Index: $r^2 = 0.63$
- Women's Education: $r^2 = 0.26$
- Women's BMI: $r^2 = 0.42$
The combined model thus seems to wildly outperform the baseline model. The baseline model mostly outperforms the satellite-only and street-only models, the exceptions being the child mortality rate, for which all three of my models are superior, and women's educational attainment, where the street model also beats the baseline.
Comparison to random predictions¶
We will now only consider the combined model, as it is clearly the superior one. So far, we have only evaluated the relative performance of the models compared to each other. If the combined model does not outperform random guesses, it is useless. Thus, I will compare the test loss of the combined model to the loss obtained when making random guesses.
mean = data_no_nan[column_names].mean()
total_benchmark_loss = 0
random_benchmark_loss = 0
test_loader = DataLoader(combined_dataset_test, batch_size=512, shuffle=False, num_workers=4)
random_predictions = []
actuals = []
for inputs, targets in test_loader:
random_guess = torch.tensor((np.random.rand(targets.shape[0], 6)-0.5)*10)
random_benchmark_loss += criterion(random_guess, targets)
random_predictions.extend(random_guess.numpy())
actuals.extend(targets.numpy())
# Compute Pearson correlation coefficient for each feature
for i in range(6):
corr, _ = pearsonr([p[i] for p in random_predictions], [a[i] for a in actuals])
print(f'Squared Pearson correlation coefficient of random benchmark for {column_names[i]}: r^2 = {corr**2}')
print(f'Random benchmark loss: {random_benchmark_loss / len(combined_dataset_test)}')
Squared Pearson correlation coefficient of random benchmark for water_index: r^2 = 4.1952912785140774e-05
Squared Pearson correlation coefficient of random benchmark for sanitation_index: r^2 = 9.294217353504598e-09
Squared Pearson correlation coefficient of random benchmark for under5_mort: r^2 = 7.390261732725432e-06
Squared Pearson correlation coefficient of random benchmark for asset_index: r^2 = 4.040565161363265e-05
Squared Pearson correlation coefficient of random benchmark for women_edu: r^2 = 8.574681306376353e-06
Squared Pearson correlation coefficient of random benchmark for women_bmi: r^2 = 2.3161961569539365e-06
Random benchmark loss: 0.7936563413002258
As expected, the random benchmark does not produce predictions with any significant correlation to the ground truth. We can observe that the combined model outperforms random guessing by a significant margin with regard to both the test loss and the correlation. Thus, we have established that the combined model makes more meaningful predictions than random guessing or the SustainBench baseline model.
Performing inference¶
Having concluded that the combined model outperforms the benchmarks, we can use it to infer some data. In practice, we would use a meta-learning approach, whereby we would predict the features given all the street-level images of a location and then use the mean of those predictions as the final prediction. For the purpose of demonstration, I only generate a prediction for a single street image instead of simulating the meta-learner.
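In such a deployment, the per-location aggregation could look roughly like the following sketch. The helper predict_location is hypothetical and not part of the trained pipeline; it assumes a list of preprocessed 11-channel input tensors belonging to one location:
def predict_location(model, location_inputs):
    """Average the model's predictions over all images belonging to one DHS location.

    location_inputs: list of preprocessed 11-channel input tensors for the same location.
    """
    model.eval()
    with torch.no_grad():
        predictions = [model(torch.unsqueeze(x, dim=0)) for x in location_inputs]
    # Each prediction has shape (1, 6); average them into a single 6-dimensional prediction
    return torch.mean(torch.cat(predictions, dim=0), dim=0)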
# Switch to evaluation mode and move the model to the CPU; a GPU is only needed when we want to exploit its parallelism.
combined_model.eval()
combined_model.to('cpu')
# The gradient is only needed during training.
# Disabling it during inference saves time and computational resources.
with torch.no_grad():
    datapoint = combined_dataset_test[233]
    image = datapoint[0]
    fig, axs = plt.subplots(1, 2, figsize=(20, 10))
    axs[0].set_title('Street input as RGB image')
    # Scale the RGB channels by 2 to brighten the plot; values above 1 are clipped (hence the warning below)
    axs[0].imshow(image[:3].permute(1, 2, 0) * 2)
    axs[1].set_title('Satellite input as THERM1 representation')
    axs[1].imshow(image[9])
    # Add a batch dimension and predict the six features for this datapoint
    some_input = torch.unsqueeze(image, dim=0)
    outputs = list(combined_model(some_input)[0])
    predictions = list(zip(column_names, outputs))
    for i in range(6):
        print(f'Predicted {predictions[i][0]}: {predictions[i][1].item()}, ground truth {datapoint[1][i]}')
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Predicted water_index: 4.3228020668029785, ground truth 4.0
Predicted sanitation_index: 2.4591689109802246, ground truth 2.75
Predicted under5_mort: -0.2389991134405136, ground truth 0.0
Predicted asset_index: -0.1608094573020935, ground truth -0.31347164511680603
Predicted women_edu: 6.261411666870117, ground truth 8.875
Predicted women_bmi: 26.67802619934082, ground truth 24.003999710083008
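For reference, here is a minimal sketch of the per-location aggregation described above: the model predicts the features once per street-level input of a location, and the mean of those predictions is the final prediction. The helper and the list `inputs_for_location` are hypothetical and only illustrate the idea; they are not part of the trained pipeline.

def predict_location(model, street_inputs):
    """Average the model's predictions over all street-level inputs of one location (sketch)."""
    model.eval()
    with torch.no_grad():
        # One forward pass per street-level image of the location
        per_image_preds = [model(torch.unsqueeze(x, dim=0))[0] for x in street_inputs]
    # The final prediction for the location is the mean of the per-image predictions
    return torch.stack(per_image_preds).mean(dim=0)

# Hypothetical usage, assuming a list of combined input tensors for one location:
# location_prediction = predict_location(combined_model, inputs_for_location)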
Visualising the learned features¶
It could be interesting to observe how the models make their predictions. The following cell defines a function which visualises the features the network has learned at different depths, using so-called hooks. Hooks allow us to access the intermediate activations within the network.
import math

# Get the intermediate activations by registering the defined hooks during the forward pass,
# then plot them as grids of feature maps
def plot_feature_maps(activations, title, n_maps):
    # Arrange the first grid_size^2 feature maps in a square grid
    grid_size = math.floor(math.sqrt(n_maps))
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(20, 20))
    fig.suptitle(title)
    for i, ax in enumerate(axes.flat):
        # Remove the batch dimension and select the i-th channel
        feature_map = activations[0].squeeze(0)[i]
        # Normalise the feature map for visualisation
        feature_map = feature_map - feature_map.min()
        feature_map = feature_map / feature_map.max()
        ax.imshow(feature_map, cmap='gray')
        ax.axis('off')
    plt.show()
    print('\n')

def plot_feature_extraction(model, dataset, name):
    print(f'Extracted features for the {name} model:')
    model = model.to('cpu')
    activations1 = model.register_hook_to_first_layer()
    activations2 = model.register_hook_to_second_layer()
    activations4 = model.register_hook_to_last_layer()
    # A single forward pass on the CPU fills the registered hooks with activations
    output = model(torch.unsqueeze(dataset[0][0], dim=0))
    plot_feature_maps(activations1, 'Learned features after first block', 64)
    plot_feature_maps(activations2, 'Learned features after second block', 64)
    plot_feature_maps(activations4, 'Learned features after last block', 512)
plot_feature_extraction(combined_model, combined_dataset, 'combined')
Extracted features for the combined model:
The three grids of images visualise how the ResNet34 learns increasingly abstract representations of the data at increasing depth in order to make its predictions. These results are hard to interpret: we cannot make out anything meaningful beyond vague textures, with a somewhat rougher texture in the top-left corner. A likely reason is that the 11 input channels do not form a coherent visual image, i.e., there is no sensible way of plotting them together in 2D, which would explain why the learned features are not clearly interpretable.
Plotting the same for the street model adds little analytical value, but it is still interesting, as one can observe how the network represents a street at different levels of abstraction.
plot_feature_extraction(street_model, street_dataset, 'street')
Extracted features for the street model:
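As a side note, the `register_hook_to_first_layer`, `register_hook_to_second_layer` and `register_hook_to_last_layer` methods used above are custom helpers defined earlier in this notebook. For readers unfamiliar with hooks, the following is a rough sketch of how such a helper can be built with PyTorch's `register_forward_hook`, shown here on a plain torchvision ResNet34 rather than on my actual model classes:

import torch
import torchvision

def register_activation_hook(layer):
    """Return a list that will be filled with the layer's output on the next forward pass (sketch)."""
    activations = []
    def hook(module, inputs, output):
        # Store a detached copy of the layer's output so it can be plotted later
        activations.append(output.detach())
    layer.register_forward_hook(hook)
    return activations

# Hypothetical usage on a plain torchvision ResNet34:
resnet = torchvision.models.resnet34(weights=None)
activations1 = register_activation_hook(resnet.layer1)
_ = resnet(torch.randn(1, 3, 224, 224))  # the forward pass fills activations1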
Ethical considerations¶
Interpretability¶
As observed in the previous section, the most successful model does not learn any visually interpretable features. This has two consequences. Firstly, we cannot use this model to learn which visual features are useful for predicting the socio-economic features. Secondly, using the model to make decisions could be problematic: relying on predictions without knowing how they were made leaves users exposed to whatever biases and errors the model has learned. Without being able to interpret how the model makes its predictions, it is hard to mitigate these biases, and we would have to assume that the data is unbiased, which is an assumption we cannot make.
Biased predictions¶
Biases are undoubtedly present in the dataset. It consists of data collected by the United States Agency for International Development, a government institution in the Global North, describing phenomena in the Global South. This uni-directional approach creates biased data, as the people designing the surveys invariably incorporate their subjective perspective. Positionality influencing design is inherent to humans and thus not a bad thing in and of itself, but if there is no diversity among the actors involved in designing the surveys, the resulting information will be biased. During my research, I could not find any information on how the DHS Program attempts to mitigate these biases.
In addition, it is questionable whether we in the Global North should be collecting data on the Global South in an ostensibly altruistic and arguably patronising manner. The same concern extends to the development of this model. For the sake of argument, let us assume that this model would be used by local agents in the countries present in the DHS dataset to make informed policy decisions with the SDGs in mind. Even then, using the model in such a way would be ethically questionable due to its uni-directional development. In order to develop a model which could be deployed in the affected countries in an ethical way, not only the data collection, but also the development of the model should involve many different actors from the targeted countries. The focus should shift from model accuracy towards what a representative sample of affected people need, what they can work with, and how we can support the SDGs with approaches and tools developed at least in part by them. We cannot achieve the SDGs by providing affected people with technology they never asked for.
Gerrymandering¶
Gerrymandering is a political strategy in which electoral boundaries, i.e., the boundaries of a politically relevant area, are drawn deliberately to achieve a certain outcome. It follows that aggregating geographical data over discrete areas, be it votes or socio-economic factors, is not objective: the result inherently depends on how those areas are defined. In geography-based data like ours, where each datapoint denotes one geographical location, the boundaries of the locations can be drawn arbitrarily, which influences the results. This is a factor to consider when thinking about biases within the data.
Uncertainty quantification¶
Machine learning models occasionally output nonsense, even if such occasions are rare. A rare occasion in which our model misjudges a piece of information (such as child mortality) could lead to serious and unwanted consequences in practice, for example the unnecessary allocation of scarce resources or the false confidence that a problem (such as high child mortality in an area) has been solved. A future model could incorporate Bayesian uncertainty quantification, which would allow for evaluating how confident the predictions are. Quoting Hein et al. (2019), machine learning models should "know when they don't know", allowing the user to avoid false confidence.
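One lightweight way of approximating Bayesian uncertainty quantification is Monte Carlo dropout: keeping the dropout layers active at inference time, repeating the forward pass several times and reading the spread of the predictions as an uncertainty estimate. The following is a minimal sketch of this idea, not part of the models trained above; it assumes the model contains dropout layers, which the ResNet34 backbone used here does not have by default.

import torch
import torch.nn as nn

def mc_dropout_predict(model, x, n_samples=50):
    """Monte Carlo dropout (sketch): repeat the forward pass with dropout enabled and
    return the mean prediction together with its standard deviation."""
    model.eval()
    # Re-enable only the dropout layers, keeping batch norm etc. in eval mode
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    with torch.no_grad():
        samples = torch.stack([model(x) for _ in range(n_samples)])
    return samples.mean(dim=0), samples.std(dim=0)

# Hypothetical usage: mean_pred, uncertainty = mc_dropout_predict(combined_model, some_input)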
Consequences of predictions based on mere visual observations¶
The use of mere visual observations in a model for predicting highly impactful aspects of human life (e.g. wealth, health, child mortality) is debatable. This method is akin to a human walking through a village and inferring, based on the dirt road, the number of houses and the shape of people's skulls, that around every $k$-th child dies before the age of 5. While potentially accurate and effective, such an approach raises significant ethical concerns, as it oversimplifies complex social phenomena and relies on superficial assessments that may not accurately reflect the underlying realities of these communities.
Sustainable development relevance & impact¶
Relevance of the exploratory data analysis¶
The exploratory data analysis demonstrated the interconnectedness of different SDGs. This notion is of immense importance when thinking about achieving the SDGs: if we want to find sustainable and effective solutions, we cannot look at individual targets only, but must try to understand the complex system of correlated targets. The data analysis additionally allowed us to distinguish between features with high correlation and those without. For example, water quality and access to sanitation are very strongly correlated (thus justifying their combination in SDG 6), whereas water quality and child mortality appear to be essentially uncorrelated, i.e., their covariance is insignificant. Such an approach (which is not machine learning, but closely related) can help us understand the interconnectedness of the targets.
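As a pointer, such a pairwise correlation analysis can be reproduced directly with pandas. The sketch below assumes the `data_no_nan` DataFrame and the `column_names` list defined earlier in the notebook:

# Pairwise Pearson correlations between the six socio-economic features
corr_matrix = data_no_nan[column_names].corr()
print(corr_matrix)

# Visualise the correlation matrix as a heatmap
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
ax.set_xticks(range(len(column_names)))
ax.set_xticklabels(column_names, rotation=45, ha='right')
ax.set_yticks(range(len(column_names)))
ax.set_yticklabels(column_names)
fig.colorbar(im, ax=ax)
plt.show()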
It is important to mention that correlation does not imply causation. This is a fundamental principle of applied statistics and relates to the questionable cause fallacy: the false belief that if A and B occur together (such as a low asset index and low water quality), then A must cause B. E.g., if AIDS and homosexuality are strongly correlated, this does not imply that AIDS causes homosexuality (example borrowed from here). Similarly, we must keep in mind that we cannot directly infer a causal relationship from correlated SDG indicators.
Overall, data analysis is a useful tool as a supplement to a qualitative approach to understanding the complex interconnectedness of the SDGs. It is important to note that the numbers and plots derived as part of the exploratory data analysis have no meaning in and of themselves. They need to be embedded in detailed descriptions of the context within which they are relevant in order to become meaningful and thereby useful.
Relationship of the dataset to the SDGs¶
The indicators are the primary means of measuring progress towards the SDGs, so keeping track of them is of great importance. The DHS dataset contains data which aligns with many different SDGs and can be used to infer their indicators.
Water quality index¶
The water quality index clearly aligns with SDG 6: Clean Water and Sanitation, as it measures how clean and accessible the water is. Detailed information on how it is computed can be found in the SustainBench paper (Yeh et al., 2021). It captures access to clean water and thus aligns with indicator 6.1.1 (proportion of population using safely managed drinking water services).
Sanitation index¶
The sanitation index also aligns with SDG 6: Clean Water and Sanitation. It contains information on the proportion of people who have access to sanitation facilities, as described by indicator 6.2.1 (proportion of population using safely managed sanitation services, including a hand-washing facility with soap and water). Again, the SustainBench paper contains detailed information on how the index is computed.
Child mortality rate¶
The SDG 3: Good Health and Well-being has the indicator 3.2.1 (under-five mortality rate) which corresponds exactly to this feature.
Asset index¶
This feature aligns with SDG 1: No Poverty, as a very low asset index corresponds to poverty, and with SDG 10: Reduced Inequalities, as financial inequality can be quantified by the variance of the asset index. In particular, the asset index contains information on indicator 1.4.1 (proportion of population living in households with access to basic services).
Women's average educational attainment in years¶
This feature does not correspond to any particular indicator, but is closely related to SDG 4: Quality Education and SDG 5: Gender Equality.
Women's average BMI¶
This feature aligns with SDG 3: Good Health and Well-being. An excessively high BMI is correlated with many metabolic and cardiovascular diseases, so indicator 3.4.1 (mortality rate attributed to cardiovascular disease, cancer, diabetes or chronic respiratory disease) is closely related to this feature. An excessively low BMI is associated with chronic hunger, thus relating to SDG 2: Zero Hunger.
Where machine learning can help¶
This data was, and data on indicators generally is, obtained through surveys (hence the S in DHS). Surveys are expensive and slow. A model such as the combined model could benefit the SDGs in three ways:
- Allowing for inference, i.e., surveying only half of the locations and then inferring the data for the others using machine learning, thereby reducing costs (see the sketch after this list).
- Keeping track of the features (asset index, water index, etc.) in a fully automatic way using a machine learning model and up-to-date street and satellite images.
- Analysing the feature extraction of the CNN, observing which features in the images the CNN connects to the survey results.
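As an illustration of the first point, the following is a minimal sketch of how the trained model could fill in feature values for locations that were not surveyed. The `unsurveyed_dataset`, its per-sample location ids and the resulting DataFrame are hypothetical and only serve to illustrate the workflow:

import pandas as pd
from torch.utils.data import DataLoader

# Hypothetical dataset of locations with imagery but without survey labels,
# yielding (input tensor, location id) pairs
unsurveyed_loader = DataLoader(unsurveyed_dataset, batch_size=512, shuffle=False)

combined_model.eval()
rows = []
with torch.no_grad():
    for inputs, location_ids in unsurveyed_loader:
        preds = combined_model(inputs)
        for loc_id, pred in zip(location_ids, preds):
            # One row per location: its id plus the six predicted feature values
            rows.append({'cluster_id': loc_id.item(), **dict(zip(column_names, pred.tolist()))})

# Predicted indicator values for the unsurveyed locations
predicted_labels = pd.DataFrame(rows)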
On paper, the model I trained on both street-level and satellite images could perform well at the first two tasks, as it predicts the features quite accurately from visual inputs alone. However, I highlighted some problems with this approach in the previous section.
The third aspect seems ethically much less questionable. It does not replace human competence, but rather acts as a tool to help people understand complex phenomena relating to progress towards the SDGs. A CNN, such as the ResNet34 I used for training, learns different features in the form of channels. Observing the learned features can tell us which aspects of the images the network uses to make its predictions. If a model is good at making predictions, such as the combined model, then understanding how it makes those predictions could provide valuable insights into the correlation between visual features and SDG indicators. Unfortunately, my combined model is not suited for this task, as it does not learn visually interpretable features. As previously discussed, this may be because it receives input images with 11 channels, which cannot be displayed meaningfully in 2D, and thus its learned features are not interpretable as images.
The model trained on the street-level images, on the other hand, receives three input channels, which represent the RGB channels of an image on a screen. Thus, its learned features are visually interpretable as different shapes of a road, yet they do not provide any specific information on how the predictions are made. This may be because the model does not manage to fit the data sufficiently well, as observed in the earlier evaluation steps.
Novelty of Project¶
Projects similar to this one have been attempted in the past, as mentioned in the first section. This project extends the current research in three ways:
- It extends SustainBench by proposing a model which outperforms the existing models on the given benchmark tasks.
- It combines street and satellite images to predict socio-economic factors, while previous attempts focus solely on either street or satellite images.
- It provides a visual analysis of the learned features, unfortunately concluding that, for this model, the features are not visually interpretable.
For future work, it would be interesting to use non-visual approaches to understanding how the model makes its predictions. In addition, incorporating Bayesian uncertainty quantification would allow the confidence of predictions of critical data to be assessed. Finally, future work could seek further ways of mitigating the biases mentioned in the previous section. I conclude with the assessment that machine learning can be a very powerful tool for inferring SDG indicators, with the danger of potentially opaque, biased and superficial predictions. Machine learning can thus not replace human competence, but potentially expand it.