Model training
import warnings
warnings.filterwarnings("ignore")
from model.train import *
import random
# --- Paths ------------------------------------------------------------------
# `data_path` is the project root; `dataset` is the dataset folder beneath it.
# Both are joined by plain string concatenation further down, so the trailing
# slash on `data_path` is significant.
data_path = './SVC/'
dataset = 'data/seqfish'
ckpt_folder = f'{data_path}checkpoints/'

# --- Device -----------------------------------------------------------------
# Prefer the second GPU, but degrade gracefully so the notebook still runs on
# hosts with a single GPU or no GPU at all instead of failing at `.to(device)`.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# --- Reproducibility --------------------------------------------------------
# Seed the three random number generators visible in this script (PyTorch,
# the stdlib `random` module, NumPy) with the same value.
seed = 2025
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
Required model inputs
- train_image: inputs for gene-level subcellular localization embedding
- gene2vec_weight: inputs for gene-level functional embedding
- train_cell_morphology, train_nuclear_morphology: inputs for cell-level morphological embedding
- train_data_location: inputs for cell-level positional embedding
- train_cell_cycle_label (optional): inputs for cell-level identity embeddings
# Open the training archive; each array below is read from this single .npz.
train_seqfish = np.load(f"{data_path}{dataset}/train_seqfish.npz")

# Pull out the individual model inputs by their archive keys.
train_image = train_seqfish["data_ori"]
train_cell_morphology = train_seqfish["cell_morphology"]
train_nuclear_morphology = train_seqfish["nuclear_morphology"]
train_data_location = train_seqfish["location"]
train_cell_cycle_label = train_seqfish["identity_label"]

# Wrap the arrays in the project dataset class.
train_dataset = SVC_Dataset(
    data_ori=train_image,
    location=train_data_location,
    cell_morphology_vec=train_cell_morphology,
    nuclear_morphology_vec=train_nuclear_morphology,
    identity_vec=train_cell_cycle_label,
)

# Report dataset size; axis 1 of `train_image` is reported as the gene count.
n_training_cells = len(train_dataset)
n_genes = train_image.shape[1]
print("number of training cells:", n_training_cells, ', number of genes:', n_genes)

# Sum each sample over axes 1-3, then take the median of those totals.
# NOTE(review): presumably axis 0 indexes cells, making this the median
# per-cell total signal -- confirm against how train_SVC uses it.
per_sample_total = train_image.sum((1, 2, 3))
cell_median_train = np.median(per_sample_total)

# Shuffled mini-batches of 16, loaded by 4 worker processes.
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)
# Path of the gene2vec weight matrix stored next to the training archive.
read_dir = f'{data_path}{dataset}/gene2vec_weight_seqfish.npy'

# Load the matrix with NumPy, then hand it to PyTorch as a float32 tensor.
gene2vec_array = np.load(read_dir)
gene2vec_weight = torch.from_numpy(gene2vec_array).float()
number of training cells: 157 , number of genes: 1000
# Instantiate the model. The identity-embedding width is taken from the
# second dimension of the identity label array.
model = SVC(
    gene2vec_weight=gene2vec_weight,
    cell_identity_dim=train_cell_cycle_label.shape[1],
)

# Move the model to the configured device.
model = model.to(device)
# Run training for 200 epochs on the prepared loader.
# NOTE(review): the return value is presumably the per-epoch loss history, and
# checkpoints are presumably written under `ckpt_folder` using `ckpt_name` --
# confirm against train_SVC in model.train.
epoch_losses = train_SVC(
    model=model,
    train_loader=train_loader,
    cell_median_train=cell_median_train,
    device=device,
    num_epochs=200,
    ckpt_dir=ckpt_folder,
    ckpt_name="SVC_seqfish",
)
Training: 0%| | 0/200 [00:00<?, ?epoch/s]
Training: 100%|██████████| 200/200 [07:55<00:00, 2.38s/epoch, loss=38.6997]
Finished training at epoch 200 | best loss 38.1892 at epoch 164