Gene co-localization analysis

In this section, we extract latent representations for each gene within individual cells and perform gene co-localization analysis.

This analysis is used to assess whether SVC learns semantically meaningful and cell-specific gene representations.

import warnings
warnings.filterwarnings("ignore")
from model.train import *
import matplotlib.pyplot as plt
import seaborn as sns
# Project root. It must end with '/' because the paths below are built by plain
# string concatenation (f"{data_path}{dataset}/...").
data_path = './SVC/'
dataset = 'data/seqfish'

# The original analysis ran on the second GPU ('cuda:1'). Fall back to the
# first GPU, or to the CPU, so the notebook still runs on smaller machines.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Read every array inside `with` so the .npz archive's file handle is closed
# afterwards. Indexing an NpzFile returns a fully loaded array, so the arrays
# stay valid after the block exits.
with np.load(f"{data_path}{dataset}/train_seqfish.npz") as train_seqfish:
    # Axis 1 is the gene axis (see the print below); presumably the layout is
    # (cells, genes, H, W) -- TODO confirm.
    train_image = train_seqfish["data_ori"]
    train_cell_morphology = train_seqfish["cell_morphology"]
    train_nuclear_morphology = train_seqfish["nuclear_morphology"]
    train_data_location = train_seqfish["location"]
    # 2-D: its second dimension is used as the model's cell_identity_dim.
    train_cell_cycle_label = train_seqfish["identity_label"]
    train_cell_cycle = train_seqfish["identity"]

train_dataset = SVC_Dataset(
    data_ori=train_image,
    location=train_data_location,
    cell_morphology_vec=train_cell_morphology,
    nuclear_morphology_vec=train_nuclear_morphology,
    identity_vec=train_cell_cycle_label,
)
print("number of training cells:", len(train_dataset), ', number of genes:', train_image.shape[1])
# Median total transcript count per cell (sum over all non-cell axes).
cell_median_train = np.median(train_image.sum((1, 2, 3)))

read_dir = f'{data_path}{dataset}/gene2vec_weight_seqfish.npy'
gene_names = np.loadtxt(f'{data_path}{dataset}/gene_names.txt', dtype=str).tolist()
gene2vec_weight = torch.from_numpy(np.load(read_dir)).float()
# shuffle=False keeps the embeddings in the same cell order as train_image,
# which the downstream co-localization scoring relies on.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=False, num_workers=4)
number of training cells: 157 , number of genes: 1000

Load the trained model and extract latent representations.

# data_path already ends with '/', so no extra separator is needed here.
ckpt_dir = f"{data_path}checkpoints/"
ckpt = torch.load(ckpt_dir + 'SVC_seqfish.pth', map_location=device)

# Accept either a full training checkpoint that stores the weights under
# "model_state_dict" or a bare state dict.
state = ckpt.get("model_state_dict", ckpt)
# Strip the "module." prefix that parameter names typically carry when the
# model was trained wrapped in DataParallel / DistributedDataParallel.
state = {k.removeprefix("module."): v for k, v in state.items()}

model = SVC(
    gene2vec_weight=gene2vec_weight,
    cell_identity_dim=train_cell_cycle_label.shape[1],
).to(device)
model.load_state_dict(state)

# Inference only: put any dropout / batch-norm layers into evaluation mode
# before extracting embeddings. This is harmless if extract_latent_embeddings
# already calls eval() itself.
model.eval()

latent_np = extract_latent_embeddings(
    model=model,
    loader=train_loader,
    device=device,
)

# np.save(f"{data_path}output/seqfish/embeddings_train.npy", latent_np)

Gene co-localization analysis for 49 well-annotated genes.

# seqFISH+ genes shown in Extended Data Fig. 3, grouped by the subcellular
# localization cluster they belong to. Each list is written as a
# whitespace-separated string and split into a list of gene symbols.
protrusionList = """
    Cyb5r3 Sh3pxd2a Ddr2 Kif1c Kctd10
    Dynll2 Arhgap11a Dync1li2 Palld Naa50
""".split()

nuclearList = """
    Col1a1 Fn1 Fbln2 Col6a2 Bgn
    Nid1 Lox P4hb Aebp1 Emp1
    Col5a1 Sdc4 Postn Col3a1 Pdia6
    Col5a2 Itgb1 Calu Pdia3
""".split()

cytoplasmList = """
    Ddb1 Myh9 Actn1 Tagln2 Kpnb1
    Hnrnpf Ppp1ca Hnrnpl Pcbp1 Tagln
    Fscn1 Psat1 Cald1 Snd1 Uba1
    Hnrnpm Cap1 Ssrp1 Ugdh Caprin1
""".split()

# Additional nuclear sub-groups (not used further in this section).
nuclearSub1List = "Thbs1 Serpine1 Ptx3 Tnc".split()
nuclearSub2List = "Flna Flnb Plec Flnc Iqgap1".split()
nuclearSub3List = "Numa1 Gorasp2".split()

# Report cluster sizes: nuclear, cytoplasm, protrusion.
print(*(len(group) for group in (nuclearList, cytoplasmList, protrusionList)))
19 20 10

Calculate the gene co-localization score from the pairwise distances between each gene pair within cells.

# Genes to score, in cluster order: cytoplasm, then nuclear/perinuclear, then
# protrusion. The heatmap labels and row lookups below assume the score matrix
# follows this same order. List concatenation already builds a new list, so no
# copying comprehension is needed.
gene_interested = cytoplasmList + nuclearList + protrusionList

# Per-cell, per-gene transcript totals: sum over the two trailing (spatial)
# axes of the count images.
train_count_sum = train_image.sum(axis=(2, 3))

gene_colocal_score = cal_gene_colocal_score(latent_np, gene_names, train_count_sum, gene_interested)
  0%|          | 0/157 [00:00<?, ?it/s]
100%|██████████| 157/157 [00:01<00:00, 147.50it/s]

A heatmap is used to visualize the gene co-localization matrix inferred by SVC, revealing three distinct spatial localization clusters among the 49 genes.

# Sequential colormap: pale yellow -> green -> dark navy. The name
# 'white_to_red' does not describe these colors, but it is kept unchanged in
# case other code refers to the colormap by name.
custom_cmap = create_color_cmap('#FEFEE1', '#9bc07c', '#081D58', name='white_to_red')

fig, ax = plt.subplots(figsize=(15, 10))

# Hide the upper triangle including the diagonal, so each gene pair is drawn
# once (the score matrix is presumably symmetric -- confirm).
mask = np.triu(np.ones_like(gene_colocal_score, dtype=bool))
heatmap = sns.heatmap(
    gene_colocal_score,
    cmap=custom_cmap,
    xticklabels=gene_interested,
    yticklabels=gene_interested,
    mask=mask,
    # Shrink and reposition the color bar.
    cbar_kws={'shrink': 0.45, 'pad': -0.05, 'anchor': (-0.7, 0.9), "location": "right"},
    square=True,
)
heatmap.set_xticklabels(heatmap.get_xticklabels(), fontweight='bold', fontsize=13, fontstyle='italic')
# Restyle the y labels from the y axis (previously used get_xticklabels()).
heatmap.set_yticklabels(heatmap.get_yticklabels(), fontweight='bold', fontsize=13, fontstyle='italic')

cbar = heatmap.collections[0].colorbar
cbar.ax.set_position([1, 0.1, 0.9, 0.2])
cbar.ax.tick_params(labelsize=15)
# Author's note: the score is the proportion of cells in which the gene pair is
# among each other's 10 nearest neighbours (confirm against
# cal_gene_colocal_score).
cbar.set_label('Gene co-localization score', rotation=270, fontsize=18, loc='top', labelpad=25)
cbar.ax.get_yaxis().get_label().set_position((-0.3, 0.1))

# Dashed dividers between the three clusters. Boundaries follow the
# concatenation order of gene_interested (cytoplasm, nuclear, protrusion), so
# they are derived from the list lengths rather than hard-coded.
n_cyto = len(cytoplasmList)
n_cyto_nuc = n_cyto + len(nuclearList)
n_total = len(gene_interested)
divider_style = dict(color='lightgray', linestyle='--', linewidth=6)
ax.plot([0, n_cyto], [n_cyto, n_cyto], **divider_style)
ax.plot([0, n_cyto_nuc], [n_cyto_nuc, n_cyto_nuc], **divider_style)
ax.plot([n_cyto, n_cyto], [n_cyto, n_total], **divider_style)
ax.plot([n_cyto_nuc, n_cyto_nuc], [n_cyto_nuc, n_total], **divider_style)

# Cluster captions along the diagonal (hand-tuned data coordinates).
y_bottom = 0
ax.text(15, y_bottom + 9, 'Cytoplasm', va='center', ha='right', color='black', fontsize=22, fontweight='bold', rotation=315)
ax.text(38, y_bottom + 29, 'Nuclear/Perinuclear', va='center', ha='right', color='black', fontsize=22, fontweight='bold', rotation=315)
ax.text(48.5, y_bottom + 43, 'Protrusion', va='center', ha='right', color='black', fontsize=22, fontweight='bold', rotation=315)

plt.tight_layout()
# fig.savefig(f"{data_path}output/seqfish/figures/gene_colocalization_49.png",
#             transparent=True,
#             dpi=300,
#             bbox_inches="tight")
plt.show()
../../_images/bbfc49803616f044d6c501d03637aad6f7a91e60520d94514c6d4f17a38dcffb.png

We further visualize the observed transcript profiles within the original cells for three representative co-localized gene pairs.

# For each query gene, report the gene with the highest co-localization score
# in the query's row of the score matrix.
# NOTE(review): the diagonal (self-score) is not excluded. If a gene's
# self-score were its row maximum, the gene itself would be reported.
for position, query_gene in enumerate(['Ddb1', 'Col1a1', 'Cyb5r3']):
    if position:
        print("\n")
    score_row = gene_colocal_score[gene_interested.index(query_gene), :]
    top_partner = gene_interested[score_row.argsort()[::-1][0]]
    print(f"most co-localized gene with {query_gene}:", top_partner)
most co-localized gene with Ddb1: Myh9


most co-localized gene with Col1a1: Fn1


most co-localized gene with Cyb5r3: Sh3pxd2a
def _read_seqfish_pickle(file_name):
    """Unpickle *file_name* from the seqFISH dataset directory.

    NOTE: pd.read_pickle can execute arbitrary code while loading; only use it
    on trusted, locally produced files.
    """
    return pd.read_pickle(f"{data_path}{dataset}/{file_name}")


# Transcript-level table (one row per detected spot, with x, y, gene, cell, ...).
seqfish_data = _read_seqfish_pickle("seqfish_data_dict.pkl")['data_df']
print(seqfish_data.head())

# Pre-processed cell-mask contour points (one row per contour pixel).
df_cell_contour = _read_seqfish_pickle("cell_mask_contour_preprocessed.pkl")
print(df_cell_contour.head())

# Pre-processed nuclear-mask contour points.
df_nuclear_contour = _read_seqfish_pickle("nuclear_mask_contour_preprocessed.pkl")
print(df_nuclear_contour.head())
             x           y           gene cell nucleus batch  umi  centerX  \
0  1217.437557  557.583252  4933401b06rik  5-0      -1     0    1     1003   
1  1096.190309  394.835294  4933401b06rik  5-0       5     0    1     1003   
2  1093.189494  572.832405  4933401b06rik  5-0      -1     0    1     1003   
3  1005.120220  297.196271  4933401b06rik  5-0      -1     0    1     1003   
4  1142.815026  378.376491  4933401b06rik  5-0      -1     0    1     1003   

   centerY        type  sc_total  
0      425  fibroblast     32224  
1      425  fibroblast     32224  
2      425  fibroblast     32224  
3      425  fibroblast     32224  
4      425  fibroblast     32224  
  cell    x    y  centerX  centerY  direction_vec  distance_to_center
0  0-0  521  496     1079      724         -158.0          602.783543
1  0-0  519  496     1079      724         -158.0          604.635427
2  0-0  517  494     1079      724         -157.5          607.242950
3  0-0  515  494     1079      724         -158.0          609.094410
4  0-0  513  492     1079      724         -157.5          611.702542
  cell     x    y  centerX  centerY
0  0-0  1051  631     1079      724
1  0-0  1050  632     1079      724
2  0-0  1049  632     1079      724
3  0-0  1048  632     1079      724
4  0-0  1047  632     1079      724
batch = '0'


def _in_batch(frame, batch_id):
    """Return a boolean NumPy mask of the rows of *frame* belonging to *batch_id*.

    Cell ids have the form '<cell>-<batch>' (e.g. '5-0'); the batch is the
    second '-'-separated field. Ids without that field are excluded instead of
    raising IndexError.
    """
    return (frame['cell'].str.split('-').str[1] == batch_id).to_numpy()


# Keep only transcripts and contour points from the selected batch.
preserve_idx_seqfish_data = _in_batch(seqfish_data, batch)
preserve_idx_cell_mask_contour = _in_batch(df_cell_contour, batch)
preserve_idx_nuclear_mask_contour = _in_batch(df_nuclear_contour, batch)

seqfish_data_i = seqfish_data[preserve_idx_seqfish_data]
cell_mask_contour_i = df_cell_contour[preserve_idx_cell_mask_contour]
nuclear_mask_contour_i = df_nuclear_contour[preserve_idx_nuclear_mask_contour]

fig, ax = plt.subplots(1, 3, figsize=(19, 6), gridspec_kw={'hspace': 0., 'wspace': 0.01})

# One representative co-localized pair per cluster. These match the top
# partners printed in the previous cell's output.
gene_pair0 = ['Ddb1', 'Myh9']
gene_pair1 = ['Col1a1', 'Fn1']
gene_pair2 = ['Cyb5r3', 'Sh3pxd2a']

# First gene of each pair in red, second in blue.
pair_colors = ('#d6594c', '#7dabcf')

# The contour table repeats each cell's centre on every contour point (see the
# printed head above), so deduplicate before drawing the '+' markers instead of
# stacking thousands of identical markers.
cell_centers = cell_mask_contour_i[['centerX', 'centerY']].drop_duplicates()

for panel, gene_pair in zip(ax, (gene_pair0, gene_pair1, gene_pair2)):
    # Transcript spots for both genes of the pair.
    for gene, color in zip(gene_pair, pair_colors):
        spots = seqfish_data_i[seqfish_data_i.gene == gene]
        panel.scatter(spots['x'], spots['y'], label=gene, alpha=0.8, s=3, color=color)

    # Nuclear and cell outlines plus cell centres, drawn on top of the spots.
    panel.axis('off')
    panel.scatter(nuclear_mask_contour_i['x'], nuclear_mask_contour_i['y'], alpha=0.6, s=0.5, color='darkgrey')
    panel.scatter(cell_mask_contour_i['x'], cell_mask_contour_i['y'], color='grey', s=2)
    panel.scatter(cell_centers['centerX'], cell_centers['centerY'], marker='+', color='dimgray', s=100)

    # Only the two gene scatters carry labels, so the legend lists just the pair.
    leg = panel.legend(fontsize=20, markerscale=5, bbox_to_anchor=(0.45, 0.72), frameon=False, handletextpad=0.1)
    for legend_text in leg.get_texts():
        legend_text.set_fontstyle('italic')
    panel.set_aspect('equal')

# plt.savefig(f'{data_path}output/seqfish/figures/gene_colocalization_examples_batch{batch}.png', dpi=300, bbox_inches='tight', transparent=True)
plt.show()
../../_images/eb4775232dfd4704c8de0202ef4783209123112bc3f7f9e13a1033c484050a7d.png