HoloNet.predicting.mgc_repeat_training
- HoloNet.predicting.mgc_repeat_training(X, adj, target, repeat_num=50, train_set_ratio=0.85, val_set_ratio=0.15, hidden_num=None, max_epoch=300, lr=0.1, weight_decay=0.0005, step_size=10, gamma=0.9, display_loss=False, only_cell_type=False, hide_repeat_tqdm=False, device='cpu')
Using the cell-type tensor and the normalized adjacency matrix as inputs, repeatedly trains a graph neural network (GNN) to generate the expression of the target gene.
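The following is a minimal pure-Python sketch of how a GNN can combine the two inputs. It assumes a single GCN-style propagation step, prediction = adj · X · W. The weight matrix W and the helpers `matmul`, `graph_conv_predict` and `cell_type_only_predict` are hypothetical and for illustration only. HoloNet's ‘MultiGraphConvolution_Layer’ may use several adjacency matrices and differ in detail. The baseline mirrors what `only_cell_type=True` describes: the adjacency term is dropped and X is used on its own.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Nested-list matrix product: rows of a times columns of b."""
    if len(a[0]) != len(b):
        raise ValueError("inner dimensions do not match")
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def graph_conv_predict(x: Matrix, adj: Matrix, w: Matrix) -> Matrix:
    """Hypothetical single GCN-style step: aggregate neighbours' cell-type features via adj, then project with w.

    x:   cell_num x cell_type_num cell-type matrix
    adj: cell_num x cell_num normalized adjacency matrix
    w:   cell_type_num x 1 weights (learned during training in a real model)
    """
    return matmul(matmul(adj, x), w)


def cell_type_only_predict(x: Matrix, w: Matrix) -> Matrix:
    """Baseline in the spirit of only_cell_type=True: ignore adj and use x alone."""
    return matmul(x, w)


# Three cells, two cell types (one-hot), and a toy normalized adjacency whose rows sum to 1.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
adj = [[0.5, 0.5, 0.0], [0.5, 0.0, 0.5], [0.0, 0.5, 0.5]]
w = [[2.0], [4.0]]
print(graph_conv_predict(x, adj, w))  # [[3.0], [2.0], [3.0]]
print(cell_type_only_predict(x, w))  # [[2.0], [4.0], [2.0]]
```

In the toy output, cell 0 moves from 2.0 (its own type only) to 3.0 because half of its adjacency weight falls on a cell of the second type.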
- Parameters
- X
Tensor A tensor (cell_num * cell_type_num) with cell-type information, derived from the ‘get_continuous_cell_type_tensor’ or ‘get_one_hot_cell_type_tensor’ function.
- adj
Tensor A normalized adjacency matrix derived from the ‘adj_normalize’ function.
- target
Tensor The scaled expression tensor of one target gene (cell_num * 1), derived from the ‘get_one_case_expr’ function.
- repeat_num
int(default:50) The number of repeated training runs.
- train_set_ratio
float(default:0.85) A value between 0 and 1: the fraction of cells used as the training set.
- val_set_ratio
float(default:0.15) A value between 0 and 1: the fraction of cells used as the validation set.
- hidden_num
Optional[int] (default:None) The output dimension of ‘MultiGraphConvolution_Layer’. Set it to either 1 or the same value as feature_num.
- max_epoch
int(default:300) The maximum number of training epochs.
- lr
float(default:0.1) The learning rate.
- weight_decay
float(default:0.0005) The weight decay (L2 penalty).
- step_size
int(default:10) Period of learning rate decay.
- gamma
float(default:0.9) Multiplicative factor of learning rate decay.
- display_loss
bool(default:False) If true, display the loss during training.
- only_cell_type
bool(default:False) If true, the model is trained on the cell-type feature matrix X alone, without the adjacency matrix, serving as a baseline model.
- hide_repeat_tqdm
bool(default:False) If true, hide the tqdm for repeated training.
- use_gpu
bool If true, the model is trained on the GPU when one is available.
- device
str(default:'cpu') The device on which to train the models.
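The split ratios and the optimizer settings above can be sketched as follows. This sketch assumes a random split of cell indices and step decay with the semantics of PyTorch's `StepLR` (learning rate multiplied by `gamma` once every `step_size` epochs). The wording of `step_size` and `gamma` suggests this, but this page does not confirm it. `split_cells`, `step_lr` and `l2_penalized_step` are hypothetical helpers. The plain gradient-descent step only illustrates how an L2 penalty enters the gradient; it is not necessarily the optimizer HoloNet uses.

```python
import random
from typing import List, Tuple


def split_cells(cell_num: int, train_set_ratio: float = 0.85,
                val_set_ratio: float = 0.15, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Randomly assign cell indices to a training set and a validation set (hypothetical helper)."""
    if not (0.0 <= train_set_ratio <= 1.0 and 0.0 <= val_set_ratio <= 1.0):
        raise ValueError("ratios must lie between 0 and 1")
    if train_set_ratio + val_set_ratio > 1.0 + 1e-9:
        raise ValueError("train_set_ratio + val_set_ratio must not exceed 1")
    idx = list(range(cell_num))
    random.Random(seed).shuffle(idx)  # reproducible shuffle of cell indices
    n_train = round(cell_num * train_set_ratio)
    n_val = round(cell_num * val_set_ratio)
    return idx[:n_train], idx[n_train:n_train + n_val]


def step_lr(lr: float, epoch: int, step_size: int = 10, gamma: float = 0.9) -> float:
    """Step decay: the learning rate is multiplied by gamma once every step_size epochs."""
    return lr * gamma ** (epoch // step_size)


def l2_penalized_step(params: List[float], grads: List[float], lr: float,
                      weight_decay: float = 0.0005) -> List[float]:
    """One plain gradient-descent update with an L2 penalty (weight_decay * p) added to each gradient."""
    return [p - lr * (g + weight_decay * p) for p, g in zip(params, grads)]


train_idx, val_idx = split_cells(100)
print(len(train_idx), len(val_idx))  # 85 15
print(round(step_lr(0.1, 299), 4))  # learning rate at the last of 300 epochs: 0.0047
```

With the defaults, the learning rate stays at 0.1 for epochs 0 to 9, drops to 0.09 at epoch 10, and keeps shrinking by a factor of 0.9 every 10 epochs.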
- Return type
List[MGC_Model]
- Returns
A list of trained MGC models for generating the expression of one target gene.
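One plausible way to use the returned list is to average the predictions of all repeated models. Whether HoloNet's downstream functions do exactly this is an assumption here. The sketch below is a toy: each "model" is a one-parameter least-squares fit (a hypothetical stand-in for `MGC_Model`) trained on a different random subset of cells. `fit_slope`, `repeat_training` and `ensemble_predict` are hypothetical names, not HoloNet functions.

```python
import random
from statistics import mean
from typing import Callable, List, Sequence

Model = Callable[[float], float]


def fit_slope(xs: Sequence[float], ys: Sequence[float]) -> Model:
    """Toy stand-in for one trained model: least-squares slope through the origin."""
    denom = sum(x * x for x in xs)
    slope = sum(x * y for x, y in zip(xs, ys)) / denom if denom else 0.0
    return lambda x: slope * x


def repeat_training(xs: Sequence[float], ys: Sequence[float], repeat_num: int = 50,
                    train_set_ratio: float = 0.85, seed: int = 0) -> List[Model]:
    """Train repeat_num toy models, each on a different random subset of cells."""
    rng = random.Random(seed)
    n_train = round(len(xs) * train_set_ratio)
    models = []
    for _ in range(repeat_num):
        idx = rng.sample(range(len(xs)), n_train)  # a fresh random training subset per repeat
        models.append(fit_slope([xs[i] for i in idx], [ys[i] for i in idx]))
    return models


def ensemble_predict(models: List[Model], x: float) -> float:
    """Average the predictions of all repeated models."""
    return mean(m(x) for m in models)


xs = list(range(1, 21))
ys = [3 * x for x in xs]  # noise-free toy "expression": y = 3x
models = repeat_training(xs, ys, repeat_num=50)
print(len(models))  # 50
print(ensemble_predict(models, 2.0))  # 6.0
```

Because every toy model sees a different random subset, averaging over repeats smooths out the variation that comes from any single split. With noise-free data, as here, all models agree exactly.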