Source code for cneuromax.fitting.deeplearning.litmodule.nnmodule.mlp
""":class:`.MLP` & its config."""fromdataclassesimportdataclassfromtypingimportAnnotatedasAnfromeinopsimportrearrangefromjaxtypingimportFloatfromomegaconfimportMISSINGfromtorchimportTensor,nnfromcneuromax.utils.beartypeimportge,lt
@dataclass
class MLPConfig:
    """Configuration values for :class:`MLP`.

    Args:
        dims: Sizes of the successive layers, from input through\
            output.
        p_dropout: Probability used by the dropout layers inserted\
            between hidden layers.
    """

    dims: list[int] = MISSING
    p_dropout: An[float, ge(0), lt(1)] = 0.0
class MLP(nn.Module):
    """Multi-layer perceptron (MLP).

    Builds a configurable stack of linear layers with an activation
    function (and optional dropout) between each pair of consecutive
    layers.

    Args:
        config
        activation_fn: The singular activation function to use in\
            between each layer.

    Attributes:
        model (torch.nn.Sequential): The internal\
            :class:`torch.nn.Module` that holds the MLP's layers.
    """

    def __init__(
        self: "MLP",
        config: MLPConfig,
        activation_fn: nn.Module,
    ) -> None:
        super().__init__()
        self.model = nn.Sequential()
        num_linear = len(config.dims) - 1
        # Pair consecutive dims to get each layer's (in, out) sizes.
        for i, (d_in, d_out) in enumerate(
            zip(config.dims, config.dims[1:]),
        ):
            self.model.add_module(
                name=f"fc_{i}",
                module=nn.Linear(d_in, d_out),
            )
            # The output layer gets neither activation nor dropout.
            if i == num_linear - 1:
                continue
            # NOTE: the same ``activation_fn`` instance is registered
            # for every hidden layer (harmless for stateless
            # activations such as ReLU).
            self.model.add_module(name=f"act_{i}", module=activation_fn)
            if config.p_dropout:  # > 0.0
                self.model.add_module(
                    name=f"drop_{i}",
                    module=nn.Dropout(config.p_dropout),
                )

    def forward(
        self: "MLP",
        x: Float[Tensor, " batch_size *d_input"],
    ) -> Float[Tensor, " batch_size output_size"]:
        """Flattens input's dimensions and passes it through the model.

        Note:
            This MLP is currently only capable of returning 1D \
            :class:`torch.Tensor` batches.

        TODO: Add support for returning 2D+ :class:`torch.Tensor` \
            batches.

        Args:
            x: The input data batch.

        Returns:
            The output batch.
        """
        # Collapse all non-batch dimensions into one (a 1-D input of
        # shape ``(B,)`` becomes ``(B, 1)``, matching the previous
        # einops ``rearrange`` behavior).
        batch_size = x.shape[0]
        flattened: Float[Tensor, " batch_size flattened_d_input"] = (
            x.reshape(batch_size, -1)
        )
        out: Float[Tensor, " batch_size output_size"] = self.model(
            flattened,
        )
        return out