{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Model training\n", "\n", "Train the `SVC` model on the seqFISH training split (`train_seqfish.npz`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import random\n", "import warnings\n", "\n", "# Silence library warnings; set before the remaining imports so import-time warnings are hidden too.\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "import numpy as np\n", "import torch\n", "from torch.utils.data import DataLoader\n", "\n", "# Local package, presumably providing SVC and SVC_Dataset used below.\n", "# TODO: replace the star import with explicit names.\n", "from model.train import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_path = './SVC/'\n", "dataset = 'data/seqfish'\n", "ckpt_folder = f'{data_path}checkpoints/'\n", "\n", "# Prefer GPU 1 as before, but fall back instead of crashing on machines with fewer GPUs.\n", "if torch.cuda.device_count() > 1:\n", "    device = 'cuda:1'\n", "elif torch.cuda.is_available():\n", "    device = 'cuda:0'\n", "else:\n", "    device = 'cpu'\n", "\n", "# Seed every RNG in use (torch.manual_seed also seeds all CUDA devices).\n", "seed = 2025\n", "torch.manual_seed(seed)\n", "random.seed(seed)\n", "np.random.seed(seed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Required model inputs**\n", "\n", "- `train_image`: inputs for *gene-level subcellular localization embedding*\n", "\n", "- `gene2vec_weight`: inputs for *gene-level functional embedding*\n", "\n", "- `train_cell_morphology`, `train_nuclear_morphology`: inputs for *cell-level morphological embedding*\n", "\n", "- `train_data_location`: inputs for *cell-level positional embedding*\n", "\n", "- `train_cell_cycle_label` (optional): inputs for *cell-level identity embedding*" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "number of training cells: 157 , number of genes: 1000\n" ] } ], "source": [ "train_seqfish = np.load(f\"{data_path}{dataset}/train_seqfish.npz\")\n", "train_image = train_seqfish[\"data_ori\"]\n", "train_cell_morphology = train_seqfish[\"cell_morphology\"]\n", "train_nuclear_morphology = train_seqfish[\"nuclear_morphology\"]\n", "train_data_location = train_seqfish[\"location\"]\n", "# Stored under \"identity_label\"; feeds the optional cell-identity input (presumably cell-cycle labels here).\n", "train_cell_cycle_label = train_seqfish[\"identity_label\"]\n", "\n", "train_dataset = SVC_Dataset(\n", "    data_ori=train_image,\n", "    location=train_data_location,\n", "    cell_morphology_vec=train_cell_morphology,\n", "    nuclear_morphology_vec=train_nuclear_morphology,\n", "    identity_vec=train_cell_cycle_label,\n", ")\n",
"print(\"number of training cells:\", len(train_dataset),', number of genes:', train_image.shape[1])\n", "cell_median_train = np.median(train_image.sum((1,2,3)))\n", "train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)\n", "read_dir =f'{data_path}{dataset}/gene2vec_weight_seqfish.npy' \n", "gene2vec_weight = torch.from_numpy(np.load(read_dir)).float()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "model = SVC(\n", " gene2vec_weight = gene2vec_weight,\n", " cell_identity_dim = train_cell_cycle_label.shape[1],\n", ").to(device)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Training: 0%| | 0/200 [00:00