|
import csv |
|
import pystan |
|
import pickle |
|
import numpy as np |
|
import hashlib |
|
import matplotlib |
|
import matplotlib.pyplot as plt |
|
|
|
################ some auxiliary functions ############### |
|
|
|
def CachedStanModel(model_code, model_name=None):
    """
    Compile a Stan model, caching the compiled object on disk. See:
    https://pystan.readthedocs.io/en/latest/

    The cache file name contains the MD5 hash of ``model_code`` (prefixed by
    ``model_name`` when given), so editing the Stan program produces a new
    cache file and triggers a recompilation.

    Parameters
    ----------
    model_code : str
        Source code of the Stan program.
    model_name : str, optional
        Label inserted into the cache file name.

    Returns
    -------
    The unpickled cached model, or a freshly compiled ``pystan.StanModel``.

    Security note: the cache is read with ``pickle.load``; only run this in a
    directory whose ``cached-*.pkl`` files you trust.
    """
    ## UTF-8 encodes ASCII text identically (existing cache names stay valid)
    ## but, unlike 'ascii', does not crash on non-ASCII comments in the model
    code_hash = hashlib.md5(model_code.encode('utf-8')).hexdigest()
    if model_name is None:
        cache_fn = 'cached-model-{}.pkl'.format(code_hash)
    else:
        cache_fn = 'cached-{}-{}.pkl'.format(model_name, code_hash)
    try:
        with open(cache_fn, 'rb') as f:
            sm = pickle.load(f)
    except Exception:
        ## best-effort cache: any failure to load it (missing file, truncated
        ## or incompatible pickle, ...) just means we compile from scratch;
        ## KeyboardInterrupt/SystemExit are deliberately not caught
        sm = pystan.StanModel(model_code=model_code)
        with open(cache_fn, 'wb') as f:
            pickle.dump(sm, f)
    else:
        print("Using cached StanModel")
    return sm
|
|
|
def mkEllipse(Xt, scale=1):
    """
    Make an ellipse that covers a 2D point cloud.

    The ellipse is centred on the mean of the points and its axes follow the
    eigenvectors of the (population, i.e. divided by n) covariance matrix.
    With ``scale=1`` each semi-axis equals the standard deviation of the
    cloud along that principal direction; the scale parameter rescales both
    axes.

    Parameters
    ----------
    Xt : array_like, shape (n_points, 2)
        The point cloud, e.g. the posterior samples of one node.
    scale : float
        Multiplier applied to both axes of the ellipse.

    Returns
    -------
    matplotlib.patches.Ellipse
        The (not yet added) ellipse patch. ``width`` lies along the direction
        of smallest spread, ``height`` along the direction of largest spread.
    """
    ## import the submodule explicitly: a bare ``import matplotlib`` does not
    ## guarantee that ``matplotlib.patches`` has been loaded
    from matplotlib.patches import Ellipse

    Xt = np.asarray(Xt, dtype=float)
    ## for each point cloud, do a 'PCA'
    meanXt = np.mean(Xt, axis=0)
    ## center the cloud (vectorised; the mean is computed only once)
    Xtprime = Xt - meanXt
    C = np.dot(Xtprime.T, Xtprime) / Xt.shape[0]
    ## eigh returns eigenvalues in ascending order and the eigenvectors as
    ## the COLUMNS of U (U[:, k] belongs to eivals[k])
    eivals, U = np.linalg.eigh(C)
    ## round-off can make a zero eigenvalue slightly negative -> sqrt NaN
    eivals = np.clip(eivals, 0.0, None)
    ## the ellipse's width axis is the eigenvector U[:, 0]; arctan2 gives its
    ## direction (in degrees) independently of the eigenvector's sign
    angle = np.degrees(np.arctan2(U[1, 0], U[0, 0]))
    height = scale * 2 * np.sqrt(eivals[1])
    width = scale * 2 * np.sqrt(eivals[0])
    return Ellipse(xy=meanXt, width=width, height=height, angle=angle)
|
|
|
def getYear(lab):
    """
    Return the four-digit sampling year encoded at the end of an
    antigen/antiserum label.

    The year is written as the label's last two characters, except for
    labels ending in "REC", where the two characters just before that
    suffix are used. Two-digit years below 50 are read as 20xx, all others
    as 19xx.
    """
    stem = lab[:-3] if lab.endswith("REC") else lab
    two_digits = int(stem[-2:])
    ## a bodge for the Y2K bug: pivot the century at 50
    century = 2000 if two_digits < 50 else 1900
    return century + two_digits
|
|
|
############### compile Stan model and HI data ###############

## load the Stan program from its plain-text file and compile it;
## CachedStanModel reuses a pickled compilation of identical code if present
with open("mds_model.stan", 'r') as stan_file:
    mds_model = "".join(stan_file)

sm = CachedStanModel(model_code=mds_model)
|
|
|
## import influenza HI data
## NB: the delimiter depends on how you saved the HI data.
## The csv module documentation requires newline='' for files handed to
## csv.reader, so quoted fields and '\r\n' line endings are parsed correctly.
with open("baselinemap.csv", 'r', newline='') as f:
    reader = csv.reader(f, delimiter=';')
    ## csv.reader returns [] for blank lines (e.g. a trailing empty line);
    ## skip them so the row[0] lookup below cannot fail
    table = [row for row in reader if row]

## the header starts at the first field
## (for you, this could be the second field); empty header fields are ignored
sr_labs = ["SR_" + elt for elt in table[0] if elt != '']
## the first column contains the antigen labels,
## the remaining columns the titers against each antiserum
ag_labs = ["AG_" + row[0] for row in table[1:]]
values = [row[1:] for row in table[1:]]
|
|
|
labs = sr_labs + ag_labs
## nodes are numbered 1..N; these numbers are what Stan receives in `edges`
nodes = list(range(1, len(labs) + 1))
labDict = dict(zip(nodes, labs))
nodeDict = {l: n for n, l in labDict.items()}


def _splitTiters(values, sr_labs, ag_labs, nodeDict):
    """
    Sort the HI table cells into uncensored and left-censored log2 titers.

    Cells equal to '*', empty/whitespace-only cells, and cells missing from a
    short row are treated as missing and skipped. Cells starting with '<'
    are left-censored at the number that follows. Returns two dicts
    ``(uncensored, censored)`` keyed by edge ``(serum node, antigen node)``.
    """
    uncensored, censored = {}, {}
    for i1, l1 in enumerate(sr_labs):
        for i2, l2 in enumerate(ag_labs):
            row = values[i2]
            cell = row[i1].strip() if i1 < len(row) else ''
            if cell in ('', '*'):
                continue
            e = (nodeDict[l1], nodeDict[l2])
            if cell.startswith('<'):
                censored[e] = np.log2(float(cell[1:]))
            else:
                uncensored[e] = np.log2(float(cell))
    return uncensored, censored


def _maxTiters(uTiterDict, labDict, sr_labs):
    """
    Return {serum label: maximum uncensored log2 titer}, computed in a single
    pass over the titers. Sera without any uncensored titer are omitted.
    """
    bySerum = {l: [] for l in sr_labs}
    for (i1, _), logH in uTiterDict.items():
        bySerum[labDict[i1]].append(logH)
    return {l: np.max(ws) for l, ws in bySerum.items() if ws}


## uncensored titers, and left-censored titers (i.e. right-censored distances)
uTiterDict, cTiterDict = _splitTiters(values, sr_labs, ag_labs, nodeDict)

edges = sorted(list(uTiterDict) + list(cTiterDict))

titerDict = dict(uTiterDict)
titerDict.update(cTiterDict)

## censoring: 0 means uncensored, 2 means right-censored
## NOTE(review): these codes must match what mds_model.stan expects
censorDict = {e: 0 for e in uTiterDict}
censorDict.update({e: 2 for e in cTiterDict})

## find the maximum (uncensored) titer for each antiserum
maxTiterDict = _maxTiters(uTiterDict, labDict, sr_labs)

## every serum that has at least one titer needs a reference maximum;
## fail with a clear message rather than a bare error further down
## (sera with no titers at all simply contribute no edges)
_seraWithoutMax = sorted({labDict[i1] for (i1, _) in titerDict} - set(maxTiterDict))
if _seraWithoutMax:
    raise ValueError("antisera without any uncensored titer: "
                     + ", ".join(_seraWithoutMax))

## the distance d_{ij} is defined as log_2(H_{max,j}) - log_2(H_{i,j}),
## where j is the antiserum (first node of the edge) and i the antigen
distanceDict = {
    e: maxTiterDict[labDict[e[0]]] - w for e, w in titerDict.items()
}

distances = [distanceDict[e] for e in edges]
censoring = [censorDict[e] for e in edges]

data = {
    'D': 2,  # presumably the map dimension (the plots use 2 coordinates)
    'E': len(edges),
    'N': len(nodes),
    'distances': distances,
    'censoring': censoring,
    'edges': edges,
}
|
|
|
############# use Stan to minimize the MDS error ############

## point estimate of the map from Stan's optimizer
fit_opt = sm.optimizing(data=data)
Xs_opt = fit_opt["Xcr"]

## one row per node: column 0 is antigenic dimension 1, column 1 dimension 2
optXts = Xs_opt.transpose()

fig = plt.figure(figsize=(5, 8))
ax1 = fig.add_subplot(111, aspect='equal')

optxs = [X[0] for X in optXts]
optys = [X[1] for X in optXts]

## split the nodes into antigens (AG_...) and antisera (SR_...)
isAg = [labDict[n].startswith("AG") for n in nodes]
isSr = [labDict[n].startswith("SR") for n in nodes]

optxs_ag = [x for x, ag in zip(optxs, isAg) if ag]
optys_ag = [y for y, ag in zip(optys, isAg) if ag]

optxs_sr = [x for x, sr in zip(optxs, isSr) if sr]
optys_sr = [y for y, sr in zip(optys, isSr) if sr]

## antigens are colored by their sampling year
year_ag = [getYear(labDict[n]) for n, ag in zip(nodes, isAg) if ag]

## antigens: black-edged circles with a smaller borderless circle on top
C = ax1.scatter(optxs_ag, optys_ag, s=10, c=year_ag, cmap='viridis',
                linewidth=1, edgecolor='k', zorder=2)
ax1.scatter(optxs_ag, optys_ag, s=5, c=year_ag, cmap='viridis',
            linewidth=0, zorder=3)
fig.colorbar(C, ax=ax1, shrink=0.3)

## antisera: white squares drawn above the antigens
ax1.scatter(optxs_sr, optys_sr, s=10, c='w', marker='s',
            linewidth=1, edgecolor='k', zorder=4)

## coordinate 0 (dimension 1) is plotted on the x axis, coordinate 1 on y
ax1.set_xlabel("Antigenic dimension 1")
ax1.set_ylabel("Antigenic dimension 2")

fig.savefig("mds-iav-ag-cart.png", dpi=300, bbox_inches='tight')
|
|
|
################# now do some Bayesian MDS ################

fit = sm.sampling(data=data, chains=1, iter=1000, warmup=500)
la = fit.extract(permuted=True)
## Xs[s] holds the map of posterior sample s with one column per node
## (same layout as the optimizer output above)
Xs = la['Xcr']

## regroup the samples per node: Xts[j] has one row per sample and holds
## the coordinates of node j+1 (equivalent to Xs[i].transpose()[j] over i)
Xts = [Xs[:, :, j] for j in range(len(nodes))]
meanXts = [np.mean(Xt, axis=0) for Xt in Xts]

## one-standard-deviation ellipse per node
ells = [mkEllipse(Xt, scale=1) for Xt in Xts]

fig = plt.figure(figsize=(5, 8))
ax1 = fig.add_subplot(111, aspect='equal')

meanxs = [meanXt[0] for meanXt in meanXts]
meanys = [meanXt[1] for meanXt in meanXts]

## split the nodes into antigens (AG_...) and antisera (SR_...)
isAg = [labDict[n].startswith("AG") for n in nodes]
isSr = [labDict[n].startswith("SR") for n in nodes]

meanxs_ag = [x for x, ag in zip(meanxs, isAg) if ag]
meanys_ag = [y for y, ag in zip(meanys, isAg) if ag]

meanxs_sr = [x for x, sr in zip(meanxs, isSr) if sr]
meanys_sr = [y for y, sr in zip(meanys, isSr) if sr]

## antigens are colored by their sampling year
year_ag = [getYear(labDict[n]) for n, ag in zip(nodes, isAg) if ag]

C = ax1.scatter(meanxs_ag, meanys_ag, s=10, c=year_ag, cmap='viridis',
                linewidth=1, edgecolor='k', zorder=2)
ax1.scatter(meanxs_ag, meanys_ag, s=5, c=year_ag, cmap='viridis',
            linewidth=0, zorder=3)
fig.colorbar(C, ax=ax1, shrink=0.3)

ax1.scatter(meanxs_sr, meanys_sr, s=10, c='w', marker='s',
            linewidth=1, edgecolor='k', zorder=4)

## define a colormap to color the ellipses, matching the antigen scatter.
## plt.get_cmap works across matplotlib versions; matplotlib.cm.get_cmap
## was removed in matplotlib 3.9.
cmap = plt.get_cmap('viridis')
norm = matplotlib.colors.Normalize(vmin=np.min(year_ag), vmax=np.max(year_ag))

## draw ellipses (ells[k] belongs to node nodes[k])
for n, ell in zip(nodes, ells):
    ax1.add_artist(ell)
    ell.set_clip_box(ax1.bbox)
    ell.set_alpha(0.5)
    lab = labDict[n]
    if lab.startswith("AG"):
        ell.set_facecolor(cmap(norm(getYear(lab))))
    else:
        ell.set_facecolor("darkgray")
    ell.set_linewidth(0)

## plot individual samples
for Xt in Xts:
    ax1.scatter(Xt[:, 0], Xt[:, 1], s=0.5, color='lightgray', alpha=0.3,
                zorder=1)

## coordinate 0 (dimension 1) is plotted on the x axis, coordinate 1 on y
ax1.set_xlabel("Antigenic dimension 1")
ax1.set_ylabel("Antigenic dimension 2")

fig.savefig("bmds-iav-ag-cart.png", dpi=300, bbox_inches='tight')