Model training

import warnings

# Globally silence warnings *before* importing the project module so that
# import-time warning noise is suppressed as well. Note this hides every
# warning for the rest of the session.
warnings.filterwarnings("ignore")

# NOTE(review): the star import presumably provides SVC, SVC_Dataset,
# DataLoader and train_SVC used in later cells — consider importing those
# names explicitly.
from model.train import *

# Explicit imports for names this notebook uses directly. They come after the
# star import on purpose, so these bindings win over anything it re-exports.
import random

import numpy as np
import torch

data_path = './SVC/'
dataset = 'data/seqfish'
ckpt_folder = f'{data_path}checkpoints/'

# The experiment was set up for the second GPU; fall back to the first GPU,
# then to the CPU, so the notebook still runs on smaller machines instead of
# failing later at `.to(device)`.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Seed the torch, Python and NumPy RNGs for repeatable shuffling/initialisation.
seed = 2025
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

Required model inputs:

  • train_image: inputs for gene-level subcellular localization embedding

  • gene2vec_weight: inputs for gene-level functional embedding

  • train_cell_morphology, train_nuclear_morphology: inputs for cell-level morphological embedding

  • train_data_location: inputs for cell-level positional embedding

  • train_cell_cycle_label (optional): inputs for cell-level identity embedding

# Open the pre-split seqFISH training archive and pull out every model input,
# in the same order as the keys listed below.
train_seqfish = np.load(f"{data_path}{dataset}/train_seqfish.npz")
(
    train_image,               # gene-level subcellular localisation input
    train_cell_morphology,     # cell-level morphology input
    train_nuclear_morphology,  # nuclear morphology input
    train_data_location,       # cell-level positional input
    train_cell_cycle_label,    # optional cell-level identity input
) = (
    train_seqfish[key]
    for key in (
        "data_ori",
        "cell_morphology",
        "nuclear_morphology",
        "location",
        "identity_label",
    )
)

# Wrap the arrays in the project dataset; keyword order is irrelevant here.
train_dataset = SVC_Dataset(
    identity_vec=train_cell_cycle_label,
    nuclear_morphology_vec=train_nuclear_morphology,
    cell_morphology_vec=train_cell_morphology,
    location=train_data_location,
    data_ori=train_image,
)
print(
    "number of training cells:",
    len(train_dataset),
    ', number of genes:',
    train_image.shape[1],
)

# Median over cells of each cell's total summed over axes 1-3 of the image
# array; passed to train_SVC below (presumably for library-size scaling).
cell_median_train = np.median(train_image.sum(axis=(1, 2, 3)))

train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

# Pre-computed gene2vec embedding matrix, converted to a float32 tensor.
read_dir = f'{data_path}{dataset}/gene2vec_weight_seqfish.npy'
gene2vec_weight = torch.from_numpy(np.load(read_dir)).float()
number of training cells: 157 , number of genes: 1000
# Build SVC with the gene2vec embedding and an identity head sized to the
# second dimension of the identity-label array, then move it to `device`.
model = SVC(
    cell_identity_dim=train_cell_cycle_label.shape[1],
    gene2vec_weight=gene2vec_weight,
).to(device)

# Train for a fixed 200 epochs; checkpoints are written under ckpt_folder
# with the "SVC_seqfish" name. Returns the per-epoch losses (assumed from the
# variable name — confirm against train_SVC).
epoch_losses = train_SVC(
    model=model,
    device=device,
    train_loader=train_loader,
    cell_median_train=cell_median_train,
    num_epochs=200,
    ckpt_name="SVC_seqfish",
    ckpt_dir=ckpt_folder,
)
Training:   0%|          | 0/200 [00:00<?, ?epoch/s]
Training: 100%|██████████| 200/200 [07:55<00:00,  2.38s/epoch, loss=38.6997]
Finished training at epoch 200 | best loss 38.1892 at epoch 164