Spatial domain detection using subcellularly informed representations

import warnings
warnings.filterwarnings("ignore")
import torch 
import numpy as np
import random
import pandas as pd
import os
from tqdm import tqdm
from scipy.sparse import lil_matrix
from scipy.sparse import save_npz, load_npz
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix 
from scipy.sparse.linalg import eigsh 
from sklearn.cluster import KMeans
import scanpy as sc
from model.utils import *
# Root folder that holds both the input data and the generated outputs.
data_path = './SVC/' 
# Dataset sub-folder, relative to data_path.
dataset = 'data/xenium_mouse_brain' 
# Torch device used further down for the pairwise-distance computation.
device = 'cuda:0'
# Gene names, read as strings from a plain-text file.
gene_names = np.loadtxt(f'{data_path}{dataset}/gene_names.txt', dtype=str)
# Cell coordinates of the two replicates. They are stacked in this order
# (replicate 1 first, then replicate 2) when the spatial graph is built, and
# columns 0 and 1 are used as the plotting coordinates.
location_rep1 =  np.load(f"{data_path}{dataset}/cell_location_rep1.npy") 
location_rep2 = np.load(f"{data_path}{dataset}/cell_location_rep2.npy")

Due to the large size of tissue-level subcellular ST data, we construct the sparse latent cell-cell adjacency matrix in chunks to avoid running out of memory.

# Load the three chunks of replicate 1: the "cosine_simi_scaling" array stored
# in each per-chunk .npz archive.
cosine_simi_rep_1_part_1 = np.load(f'{data_path}output/xenium_mouse_brain/cosine_simi_scaling_rep1_part_1.npz')["cosine_simi_scaling"]
cosine_simi_rep_1_part_2 = np.load(f'{data_path}output/xenium_mouse_brain/cosine_simi_scaling_rep1_part_2.npz')["cosine_simi_scaling"]
cosine_simi_rep_1_part_3 = np.load(f'{data_path}output/xenium_mouse_brain/cosine_simi_scaling_rep1_part_3.npz')["cosine_simi_scaling"]

# Replicate 2 chunks are disabled here. The integration step further down
# loads six diagonal blocks (n_blocks = 6), so these lines (and the matching
# conversions/saves) must be enabled unless csr_D_4 ... csr_D_6 already exist on disk.
# cosine_simi_rep_2_part_1 = np.load(f'{data_path}output/xenium_mouse_brain/cosine_simi_scaling_rep2_part_1.npz')["cosine_simi_scaling"]
# cosine_simi_rep_2_part_2 = np.load(f'{data_path}output/xenium_mouse_brain/cosine_simi_scaling_rep2_part_2.npz')["cosine_simi_scaling"]
# cosine_simi_rep_2_part_3 = np.load(f'{data_path}output/xenium_mouse_brain/cosine_simi_scaling_rep2_part_3.npz')["cosine_simi_scaling"]
### Convert each chunk to a sparse matrix: csr_D_1, csr_D_2, csr_D_3 for replicate 1; csr_D_4, csr_D_5, csr_D_6 for replicate 2
# NOTE(review): convert_to_csr1 is presumably provided by the star import of
# model.utils; its exact behavior is not visible in this file.

csr_D_1 = convert_to_csr1(cosine_simi_rep_1_part_1)
csr_D_2 = convert_to_csr1(cosine_simi_rep_1_part_2)
csr_D_3 = convert_to_csr1(cosine_simi_rep_1_part_3)

# csr_D_4 = convert_to_csr1(cosine_simi_rep_2_part_1)
# csr_D_5 = convert_to_csr1(cosine_simi_rep_2_part_2)
# csr_D_6 = convert_to_csr1(cosine_simi_rep_2_part_3)

### Convert chunk pairs to sparse matrices: csr_D_12, csr_D_13, csr_D_23 within replicate 1; csr_D_45, csr_D_46, csr_D_56 within replicate 2

# csr_D_12 = convert_to_csr2(cosine_simi_rep_1_part_1, cosine_simi_rep_1_part_2)
# csr_D_13 = convert_to_csr2(cosine_simi_rep_1_part_1, cosine_simi_rep_1_part_3)
# csr_D_23 = convert_to_csr2(cosine_simi_rep_1_part_2, cosine_simi_rep_1_part_3)

# csr_D_45 = convert_to_csr2(cosine_simi_rep_2_part_1, cosine_simi_rep_2_part_2)
# csr_D_46 = convert_to_csr2(cosine_simi_rep_2_part_1, cosine_simi_rep_2_part_3)
# csr_D_56 = convert_to_csr2(cosine_simi_rep_2_part_2, cosine_simi_rep_2_part_3)

### Convert cross-replicate chunk pairs to sparse matrices: csr_D_14, csr_D_15, csr_D_16, csr_D_24, csr_D_25, csr_D_26, csr_D_34, csr_D_35, csr_D_36

# csr_D_14 = convert_to_csr2(cosine_simi_rep_1_part_1, cosine_simi_rep_2_part_1)
# csr_D_15 = convert_to_csr2(cosine_simi_rep_1_part_1, cosine_simi_rep_2_part_2)
# csr_D_16 = convert_to_csr2(cosine_simi_rep_1_part_1, cosine_simi_rep_2_part_3)

# csr_D_24 = convert_to_csr2(cosine_simi_rep_1_part_2, cosine_simi_rep_2_part_1)
# csr_D_25 = convert_to_csr2(cosine_simi_rep_1_part_2, cosine_simi_rep_2_part_2)
# csr_D_26 = convert_to_csr2(cosine_simi_rep_1_part_2, cosine_simi_rep_2_part_3)

# csr_D_34 = convert_to_csr2(cosine_simi_rep_1_part_3, cosine_simi_rep_2_part_1)
# csr_D_35 = convert_to_csr2(cosine_simi_rep_1_part_3, cosine_simi_rep_2_part_2)
# csr_D_36 = convert_to_csr2(cosine_simi_rep_1_part_3, cosine_simi_rep_2_part_3)
100%|██████████| 30848/30848 [02:03<00:00, 249.21it/s] 
k: 10
Processed 0/30848 rows
Processed 5000/30848 rows
Processed 10000/30848 rows
Processed 15000/30848 rows
Processed 20000/30848 rows
Processed 25000/30848 rows
Processed 30000/30848 rows
100%|██████████| 30720/30720 [02:06<00:00, 243.72it/s] 
k: 10
Processed 0/30720 rows
Processed 5000/30720 rows
Processed 10000/30720 rows
Processed 15000/30720 rows
Processed 20000/30720 rows
Processed 25000/30720 rows
Processed 30000/30720 rows
100%|██████████| 31879/31879 [02:12<00:00, 241.46it/s] 
k: 10
Processed 0/31879 rows
Processed 5000/31879 rows
Processed 10000/31879 rows
Processed 15000/31879 rows
Processed 20000/31879 rows
Processed 25000/31879 rows
Processed 30000/31879 rows
# Persist the replicate-1 diagonal blocks. The file names are given without an
# extension; scipy's save_npz appends ".npz" when it is missing, which matches
# the "csr_D_{i}.npz" names the integration step loads.
save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_1', csr_D_1)
save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_2', csr_D_2)
save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_3', csr_D_3)
# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_4', csr_D_4)
# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_5', csr_D_5)
# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_6', csr_D_6)

# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_12', csr_D_12)
# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_13', csr_D_13)
# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_23', csr_D_23)
# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_14', csr_D_14)
# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_15', csr_D_15)
# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_16', csr_D_16)

# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_24', csr_D_24)
# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_25', csr_D_25)
# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_26', csr_D_26)

# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_34', csr_D_34)
# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_35', csr_D_35)
# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_36', csr_D_36)

# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_45', csr_D_45)
# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_46', csr_D_46)
# save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/csr_D_56', csr_D_56)

Integrate all the CSR blocks to construct the full cell-cell adjacency matrix.

from pathlib import Path
from scipy.sparse import load_npz, bmat

# Folder holding every chunk-level block written by the previous steps.
base = Path(data_path) / "output/xenium_mouse_brain/domain_detection"

# Number of chunks per side of the block matrix (three per replicate).
n_blocks = 6

# Blocks are keyed by their 1-based (row chunk, column chunk) position.
blocks_dict = {}

# Diagonal blocks first: csr_D_1.npz ... csr_D_6.npz.
for i in range(1, n_blocks + 1):
    blocks_dict[(i, i)] = load_npz(base / f"csr_D_{i}.npz")

# Then the upper triangle: csr_D_12.npz, csr_D_13.npz, ..., csr_D_56.npz.
for i in range(1, n_blocks + 1):
    for j in range(i + 1, n_blocks + 1):
        blocks_dict[(i, j)] = load_npz(base / f"csr_D_{i}{j}.npz")

# Lay the blocks out on an n_blocks x n_blocks grid for bmat.
blocks = [[None] * n_blocks for _ in range(n_blocks)]

for i in range(1, n_blocks + 1):
    for j in range(1, n_blocks + 1):
        if i <= j:
            # Diagonal and upper-triangular blocks are used as loaded.
            blocks[i - 1][j - 1] = blocks_dict[(i, j)]
        else:
            # Only the upper triangle is stored on disk; a lower-triangular
            # block is the transpose of its mirrored upper block.
            blocks[i - 1][j - 1] = blocks_dict[(j, i)].T

csr_matrices = bmat(blocks, format="csr")

print("shape of csr_matrices after integration:", csr_matrices.shape)
shape of csr_matrices after integration: (182067, 182067)
def find_min_positive_per_row(matrix: csr_matrix, top_k: int = 10):
    """Build a mutual nearest-neighbour graph from a sparse matrix.

    For every row, the ``top_k`` smallest strictly positive stored entries are
    selected (all of them when a row has fewer than ``top_k``). An edge
    ``(i, j)`` is kept only when ``j`` was selected for row ``i`` AND ``i`` was
    selected for row ``j``.

    The selection is done for all rows at once with array operations instead
    of a per-row Python loop writing into a ``lil_matrix``, so no progress bar
    is needed.

    Args:
        matrix: Sparse matrix (converted to CSR internally). It must be square,
            because the result is multiplied element-wise with its transpose.
        top_k: Number of smallest positive entries kept per row; expected to
            be non-negative.

    Returns:
        A sparse matrix of the same shape whose stored entries are 1.0 exactly
        at the mutually selected ``(i, j)`` pairs.
    """
    matrix = matrix.tocsr()
    n_rows = matrix.shape[0]

    # Row index of every stored entry, recovered from the CSR row pointer.
    entry_rows = np.repeat(np.arange(n_rows), np.diff(matrix.indptr))

    # Only strictly positive entries are candidates (zeros, negatives and NaN
    # are ignored).
    positive_mask = matrix.data > 0
    rows = entry_rows[positive_mask]
    cols = matrix.indices[positive_mask]
    values = matrix.data[positive_mask]

    # lexsort uses the LAST key as the primary one: group by row, then order
    # each row's entries by ascending value.
    order = np.lexsort((values, rows))
    rows = rows[order]
    cols = cols[order]

    # Rank of each entry inside its row (0 = smallest positive value).
    per_row = np.bincount(rows, minlength=n_rows)
    row_start = np.cumsum(per_row) - per_row
    rank = np.arange(rows.size) - row_start[rows]
    keep = rank < top_k

    S = csr_matrix(
        (np.ones(int(keep.sum())), (rows[keep], cols[keep])),
        shape=matrix.shape,
    )
    # Duplicate (row, col) entries in the input would have been summed by the
    # constructor; the graph is unweighted, so reset every stored value to 1.
    S.data[:] = 1.0

    # Element-wise product with the transpose keeps only mutual selections.
    S = S.multiply(S.T)
    return S

# Latent (SVC) graph: for each cell keep the top_k (default 10) smallest
# positive entries of the integrated matrix and retain only mutual pairs.
S = find_min_positive_per_row(csr_matrices)
# Saved so the clustering step can reload it without recomputing.
save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/adjacency_matrix_SVC.npz', S)
100%|██████████| 182067/182067 [00:07<00:00, 23361.38it/s]

Build a spatial cell-cell adjacency matrix based on the Euclidean distances between the cells’ actual spatial coordinates.

# Stack both replicates so that row i of the graph refers to the i-th cell of
# the concatenated ordering (replicate 1 first, then replicate 2).
location = np.vstack((location_rep1, location_rep2))
N = len(location)
k = 10
S_spatial0 = lil_matrix((N, N))
location_tensor = torch.Tensor(location).to(device)
for i in tqdm(range(N)):
    # Euclidean distance from cell i to every cell, itself included.
    offsets = location_tensor[i] - location_tensor
    distances = offsets.pow(2).sum(dim=1).sqrt()
    # Take k + 1 candidates because the cell itself (distance 0) is among
    # them, then drop the cell's own index.
    nearest = distances.argsort()[:k + 1]
    indices = nearest[nearest != i].cpu().numpy().copy()
    S_spatial0[i, indices] = 1

# Keep an edge only when both cells list each other as neighbours.
S_spatial0 = S_spatial0.tocsr()
S_spatial = S_spatial0.multiply(S_spatial0.transpose())
print("shape of S_spatial:", S_spatial.shape)
save_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/adjacency_matrix_spatial.npz', S_spatial)
100%|██████████| 182067/182067 [00:40<00:00, 4478.51it/s]
shape of S_spatial: (182067, 182067)

Combine S and S_spatial and run spectral clustering to detect spatial domains.

# Reload both mutual-neighbour graphs so this step can run on its own.
S = load_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/adjacency_matrix_SVC.npz')
S_spatial = load_npz(f'{data_path}output/xenium_mouse_brain/domain_detection/adjacency_matrix_spatial.npz')
# k is both the number of eigenvectors and the number of clusters.
k = 20
seed = 888
np.random.seed(seed)

# Combined graph: an edge present in both graphs gets weight 2.
S_total = S + S_spatial
D_total = np.array(S_total.sum(axis=1)).flatten()
N = S_total.shape[0]
# Unnormalised graph Laplacian L = D - S.
L = csr_matrix((D_total, (np.arange(N), np.arange(N)))) - S_total

# Cells with no edge have degree 0. Computing 1 / D for them divides by zero
# and yields inf, which only stays harmless if the Laplacian stores no entry
# in those rows/columns. Use 0 for isolated cells instead, so the normalised
# Laplacian is finite regardless of how the sparse matrices store zeros.
connected = D_total != 0
D_inv_sqrt = np.zeros(N)
D_inv_sqrt[connected] = np.sqrt(1.0 / D_total[connected])
D_inv_sqrt_matrix = csr_matrix((D_inv_sqrt, (np.arange(N), np.arange(N))), shape=(N, N))
# Symmetric normalised Laplacian D^{-1/2} L D^{-1/2}.
L_norm = D_inv_sqrt_matrix @ L @ D_inv_sqrt_matrix

# Fixed start vector so the eigen-decomposition is reproducible.
v0 = np.random.rand(L_norm.shape[0])
# NOTE(review): which='SM' asks for the k smallest-magnitude eigenvalues; the
# SciPy docs suggest shift-invert mode (sigma=...) for small eigenvalues, which
# may be faster here - left unchanged to keep results identical.
eigenvalues, eigenvectors = eigsh(L_norm, k=k, which='SM', v0=v0)

# Cluster only the connected cells, on row-normalised spectral embeddings.
U = eigenvectors[connected]
H_total = U / np.linalg.norm(U, axis=1, keepdims=True)
kmeans = KMeans(n_clusters=k, random_state=seed)
labels_total = np.zeros(N)
# Isolated cells are not clustered and are marked with label -1.
labels_total[~connected] = -1
labels_total[connected] = kmeans.fit_predict(H_total)

np.savetxt(f'{data_path}output/xenium_mouse_brain/domain_detection/domain_labels.txt', labels_total)
isolation: 4
labels: [ 4.  4.  4. ... 19. 19. 19.]
# Reload the saved domain labels and split them back into the two replicates.
labels_total = np.loadtxt(f'{data_path}output/xenium_mouse_brain/domain_detection/domain_labels.txt')
# The first len(location_rep1) labels belong to replicate 1 and the remainder
# to replicate 2, matching the order in which the coordinates were stacked.
labels_total_rep1 = labels_total[:len(location_rep1)]
labels_total_rep2 = labels_total[len(location_rep1):]
np.savetxt(f'{data_path}output/xenium_mouse_brain/domain_detection/domain_labels_rep1.txt', labels_total_rep1)
np.savetxt(f'{data_path}output/xenium_mouse_brain/domain_detection/domain_labels_rep2.txt', labels_total_rep2)
labels: [ 4.  4.  4. ... 19. 19. 19.]
import matplotlib.pyplot as plt
# NOTE(review): cmap1 is not used by any code visible in this file.
cmap1 = plt.get_cmap('Dark2')
# One hex colour per spatial domain (20 entries); the domain plot indexes this
# tuple with the integer domain label.
colors = ("#F7BEEC",'#284d88',"#BA9B6C","#AA87CD","#cbc0d3",
          '#83A759','#D1C301',"#ADA78F",'#E0A28C',"#748092",
          "#A9CBCA", "#ffb703","#FE8317",'#2A9D8F',"#d96256",
          '#BEE3A6',"#FAD471","#219ebc","#8ed7e6",'#0B5E53')

# Preview the palette as a row of equally sized coloured bars.
plt.figure(figsize=(10, 1.5))
plt.bar(range(len(colors)), height=1, color=colors)
__ = plt.axis('off')
../../../_images/c6db866147dafb9938cf5f1a8b1ac57821ec610ebc95b049fb2458d389b50d74.png
# Side-by-side scatter plots of the two sections, coloured by spatial domain.
fig, ax = plt.subplots(1, 2, figsize=(16, 6), gridspec_kw={'wspace': 0.05})
# Domains actually present; -1 marks isolated cells that were not clustered.
unique_labels = np.unique(labels_total[labels_total != -1])

# Iterate over the label values themselves rather than range(len(...)), so
# the plot stays correct even if the label ids are not exactly 0..n-1.
for label in unique_labels:
    domain = int(label)
    in_rep1 = labels_total_rep1 == label
    in_rep2 = labels_total_rep2 == label
    ax[0].scatter(location_rep1[in_rep1, 0], location_rep1[in_rep1, 1],
                  s=1, alpha=0.7, label=f'{domain}', c=colors[domain])
    ax[1].scatter(location_rep2[in_rep2, 0], location_rep2[in_rep2, 1],
                  s=1, alpha=0.7, label=f'{domain}', c=colors[domain])

# Fit each panel tightly to the extent of its own section.
ax[0].set_title('Section 1', fontsize=20, pad=10)
ax[0].set_xlim([location_rep1[:, 0].min(), location_rep1[:, 0].max()])
ax[0].set_ylim([location_rep1[:, 1].min(), location_rep1[:, 1].max()])

ax[1].set_title('Section 2', fontsize=20, pad=10)
ax[1].set_xlim([location_rep2[:, 0].min(), location_rep2[:, 0].max()])
ax[1].set_ylim([location_rep2[:, 1].min(), location_rep2[:, 1].max()])

ax[0].axis('off')
ax[1].axis('off')

# Attach the legend explicitly to the right-hand panel instead of relying on
# whichever axes happens to be "current".
legend = ax[1].legend(title="Spatial domain", loc='upper right', bbox_to_anchor=(1.25, 1.1),
                      markerscale=8, fontsize=11, frameon=False)
legend.get_title().set_fontsize(12)

plt.show()
93447
88620
../../../_images/5ac0163fd40abc0a9374ad6e7dca3fc80b5b5a727d97b92e42863055314ef721.png

Visualize marker gene expression in each spatial domain.

# Load the per-cell expression matrices of both replicates.
adata_rep1 = sc.read_10x_h5(f"{data_path}{dataset}/cell_feature_matrix_rep1.h5")
adata_rep2 = sc.read_10x_h5(f"{data_path}{dataset}/cell_feature_matrix_rep2.h5")
print(adata_rep1)
print(adata_rep2)

# Cell names, paired positionally with the per-replicate domain labels.
cell_names_rep1 =  np.loadtxt(f"{data_path}{dataset}/cell_names_rep1.txt", dtype=str)
cell_names_rep2 =  np.loadtxt(f"{data_path}{dataset}/cell_names_rep2.txt", dtype=str) 
cell_to_labels_rep1 = dict(zip(cell_names_rep1, labels_total_rep1))
cell_to_labels_rep2 = dict(zip(cell_names_rep2, labels_total_rep2))
# Attach the domain label to each cell by name. Cells whose name is not in the
# lookup presumably end up with a missing label (NaN) - verify; the matrices
# printed above contain more cells than there are labels.
adata_rep1.obs['labels'] = adata_rep1.obs.index.map(cell_to_labels_rep1)
adata_rep2.obs['labels'] = adata_rep2.obs.index.map(cell_to_labels_rep2)
# NOTE(review): AnnData.concatenate is reported as deprecated in recent anndata
# releases in favour of anndata.concat - confirm before upgrading.
adata = adata_rep1.concatenate(adata_rep2, batch_key='replicate', batch_categories=['replicate1', 'replicate2'])
print(adata)
AnnData object with n_obs × n_vars = 162033 × 248
    var: 'gene_ids', 'feature_types', 'genome'
AnnData object with n_obs × n_vars = 154654 × 248
    var: 'gene_ids', 'feature_types', 'genome'
AnnData object with n_obs × n_vars = 316687 × 248
    obs: 'labels', 'replicate'
    var: 'gene_ids', 'feature_types', 'genome'
# Marker genes to summarise; one row of the result per gene, one column per
# spatial domain (labels 0..19).
marker_gene_names = ['Rasgrf2','Rorb','Fezf2','Rprm','Cplx3','Wfs1','Prss35','Prox1']
mean_expression_count_per_label = np.zeros((len(marker_gene_names),20))
for j, gene in enumerate(marker_gene_names):
    # Restrict the data to the single column of this gene.
    gene_view = adata[:, adata.var.index == gene]
    for i in range(20):
        # Cells assigned to domain i; cells without a label never match.
        in_domain = gene_view.obs['labels'] == i
        # Mean of log1p-transformed counts over the cells of this domain.
        mean_expression_count_per_label[j, i] = np.log1p(gene_view[in_domain].X).mean()
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec
# One lollipop panel per marker gene (4 rows x 2 columns = 8 genes): a stem per
# spatial domain whose height is the mean log1p expression of that gene.
fig = plt.figure(figsize=(9, 12)) #12.5, 18))
gs = GridSpec(4, 2, width_ratios=[1, 1], wspace=0.3, hspace=0.5)
selected_domains = np.arange(20)
# k indexes the current marker gene / panel (it reuses the name k).
k = 0
axes = []
x_positions = np.arange(20)
# plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# plt.get_cmap is the supported equivalent and is already used in this file.
cmap = plt.get_cmap('tab20b')
# One colour per panel (this rebinds the name `colors`).
colors = [cmap(i) for i in range(10)]


for i in range(4):
    for j in range(2):

        ax = fig.add_subplot(gs[i, j])
        axes.append(ax)
        heights = mean_expression_count_per_label[k]
        # Index of the domain with the highest mean expression, and all others.
        heights_top = np.argsort(heights)[::-1][:1]
        heights_not_top = np.argsort(heights)[::-1][1:]
        # Draw the vertical stems.
        ax.vlines(
            x=x_positions,
            ymin=0,
            ymax=heights,
            colors=colors[k],
            linewidth=4,
            label='Ground truth'
        )

        # Hollow markers for every domain except the top one, which is filled.
        ax.scatter(x=x_positions[heights_not_top], y=heights[heights_not_top], edgecolor=colors[k], s=150, marker='^', facecolor='none', linewidth=1.5)
        ax.scatter(x=x_positions[heights_top], y=heights[heights_top], color=colors[k], s=150, marker='^')

        # Label only a subset of the domain positions on the x axis.
        ax.set_xticks([0,1,2,3,4,5,6,7,8,10,12,14,16,18])
        ax.set_xticklabels([0,1,2,3,4,5,6,7,8,10,12,14,16,18], fontsize=16)
        # Keep only the non-negative automatic y ticks, with fixed tick
        # positions for specific genes.
        y_labels = ax.get_yticks()
        y_labels_positive = [tick for tick in y_labels if tick >= 0]
        if marker_gene_names[k] == 'Prss35':
            y_labels_positive = [0,0.1,0.2,0.3,0.4,0.5,0.6]
        elif marker_gene_names[k] == 'Myl4':
            y_labels_positive = [0,0.2,0.4,0.6,0.8,1.0,1.2]
        ax.set_yticks(y_labels_positive)
        ax.set_yticklabels(y_labels_positive, fontsize=16)

        ax.set_title(f'{marker_gene_names[k]}', fontsize=22, fontstyle='italic')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_linewidth(2)
        ax.spines['bottom'].set_linewidth(2)
        k += 1

plt.show()
../../../_images/b60e664e511da65d25f48cc4ec8b76aa5afb2f2a00e0a2958693f819b0fe9f3c.png