Tabnet training #11

Open
wants to merge 5 commits into main
29 changes: 29 additions & 0 deletions clustering/tabnet-classifier-tools/.bumpversion.cfg
@@ -0,0 +1,29 @@
[bumpversion]
current_version = 0.1.0-dev0
commit = True
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<dev>\d+))?
serialize =
    {major}.{minor}.{patch}-{release}{dev}
    {major}.{minor}.{patch}

[bumpversion:part:release]
optional_value = _
first_value = dev
values =
    dev
    _

[bumpversion:part:dev]

[bumpversion:file:pyproject.toml]
search = version = "{current_version}"
replace = version = "{new_version}"

[bumpversion:file:plugin.json]
[bumpversion:file:PytorchTabnet.cwl]
[bumpversion:file:ict.yaml]

[bumpversion:file:VERSION]

[bumpversion:file:src/polus/tabular/clustering/pytorch_tabnet/__init__.py]
1 change: 1 addition & 0 deletions clustering/tabnet-classifier-tools/CHANGELOG.md
@@ -0,0 +1 @@
# PyTorch TabNet tool (0.1.0-dev0)
20 changes: 20 additions & 0 deletions clustering/tabnet-classifier-tools/Dockerfile
@@ -0,0 +1,20 @@
FROM polusai/bfio:2.3.6

# environment variables defined in polusai/bfio
ENV EXEC_DIR="/opt/executables"
ENV POLUS_IMG_EXT=".ome.tif"
ENV POLUS_TAB_EXT=".arrow"
ENV POLUS_LOG="INFO"

# Work directory defined in the base container
WORKDIR ${EXEC_DIR}

COPY pyproject.toml ${EXEC_DIR}
COPY VERSION ${EXEC_DIR}
COPY README.md ${EXEC_DIR}
COPY src ${EXEC_DIR}/src

RUN pip3 install ${EXEC_DIR} --no-cache-dir

ENTRYPOINT ["python3", "-m", "polus.tabular.clustering.pytorch_tabnet"]
CMD ["--help"]
172 changes: 172 additions & 0 deletions clustering/tabnet-classifier-tools/PytorchTabnet.cwl
@@ -0,0 +1,172 @@
class: CommandLineTool
cwlVersion: v1.2
inputs:
  batchSize:
    inputBinding:
      prefix: --batchSize
    type: string?
  catEmbDim:
    inputBinding:
      prefix: --catEmbDim
    type: string?
  classifier:
    inputBinding:
      prefix: --classifier
    type: string
  clipValue:
    inputBinding:
      prefix: --clipValue
    type: double?
  computeImportance:
    inputBinding:
      prefix: --computeImportance
    type: boolean?
  deviceName:
    inputBinding:
      prefix: --deviceName
    type: string
  dropLast:
    inputBinding:
      prefix: --dropLast
    type: boolean?
  epsilon:
    inputBinding:
      prefix: --epsilon
    type: double?
  evalMetric:
    inputBinding:
      prefix: --evalMetric
    type: string
  filePattern:
    inputBinding:
      prefix: --filePattern
    type: string?
  gamma:
    inputBinding:
      prefix: --gamma
    type: double?
  groupedFeatures:
    inputBinding:
      prefix: --groupedFeatures
    type: string?
  inpDir:
    inputBinding:
      prefix: --inpDir
    type: Directory
  lambdaSparse:
    inputBinding:
      prefix: --lambdaSparse
    type: double?
  lossFn:
    inputBinding:
      prefix: --lossFn
    type: string
  lr:
    inputBinding:
      prefix: --lr
    type: double?
  maskType:
    inputBinding:
      prefix: --maskType
    type: string
  maxEpochs:
    inputBinding:
      prefix: --maxEpochs
    type: string?
  momentum:
    inputBinding:
      prefix: --momentum
    type: double?
  nA:
    inputBinding:
      prefix: --nA
    type: string?
  nD:
    inputBinding:
      prefix: --nD
    type: string?
  nIndepDecoder:
    inputBinding:
      prefix: --nIndepDecoder
    type: string?
  nIndependent:
    inputBinding:
      prefix: --nIndependent
    type: string?
  nShared:
    inputBinding:
      prefix: --nShared
    type: string?
  nSharedDecoder:
    inputBinding:
      prefix: --nSharedDecoder
    type: string?
  nSteps:
    inputBinding:
      prefix: --nSteps
    type: string?
  numWorkers:
    inputBinding:
      prefix: --numWorkers
    type: string?
  optimizerFn:
    inputBinding:
      prefix: --optimizerFn
    type: string
  outDir:
    inputBinding:
      prefix: --outDir
    type: Directory
  patience:
    inputBinding:
      prefix: --patience
    type: string?
  preview:
    inputBinding:
      prefix: --preview
    type: boolean?
  schedulerFn:
    inputBinding:
      prefix: --schedulerFn
    type: string
  seed:
    inputBinding:
      prefix: --seed
    type: string?
  stepSize:
    inputBinding:
      prefix: --stepSize
    type: string?
  targetVar:
    inputBinding:
      prefix: --targetVar
    type: string
  testSize:
    inputBinding:
      prefix: --testSize
    type: double?
  virtualBatchSize:
    inputBinding:
      prefix: --virtualBatchSize
    type: string?
  warmStart:
    inputBinding:
      prefix: --warmStart
    type: boolean?
  weights:
    inputBinding:
      prefix: --weights
    type: string?
outputs:
  outDir:
    outputBinding:
      glob: $(inputs.outDir.basename)
    type: Directory
requirements:
  DockerRequirement:
    dockerPull: polusai/pytorch-tabnet-tool:0.1.0-dev0
  InitialWorkDirRequirement:
    listing:
    - entry: $(inputs.outDir)
      writable: true
  InlineJavascriptRequirement: {}
73 changes: 73 additions & 0 deletions clustering/tabnet-classifier-tools/README.md
@@ -0,0 +1,73 @@
# PyTorch TabNet tool (v0.1.0-dev0)

This tool uses [tabnet](https://github.com/dreamquark-ai/tabnet/tree/develop), a deep learning model designed for tabular data structured in rows and columns. TabNet is suitable for classification, regression, and multi-task learning.

## Inputs:

### Input data:
The input tabular data to train on. This plugin supports the `.csv`, `.feather`, and `.arrow` file formats.
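
For illustration only, here is a hedged sketch of reading these formats with `pandas` / `pyarrow` (the tool performs its own I/O internally; the file names are placeholders):

```python
# Illustrative only: reading the three supported input formats.
import pandas as pd
import pyarrow as pa
import pyarrow.feather as feather

df_csv = pd.read_csv("data.csv")                    # plain CSV
df_feather = feather.read_feather("data.feather")   # Feather file -> pandas DataFrame
with pa.ipc.open_file("data.arrow") as reader:      # Arrow IPC file
    df_arrow = reader.read_all().to_pandas()
```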

### Details:

PyTorch-TabNet can be employed for the following tasks (a minimal import sketch follows the list):

1. TabNetClassifier: For binary and multi-class classification problems
2. TabNetRegressor: For simple and multi-task regression problems
3. TabNetMultiTaskClassifier: For multi-task, multi-class classification problems
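
These correspond to the three estimator classes exposed by the pytorch-tabnet library. A minimal import sketch (library class names only; this tool itself is driven by the CLI options documented below):

```python
# The three pytorch_tabnet estimators listed above.
from pytorch_tabnet.tab_model import TabNetClassifier, TabNetRegressor
from pytorch_tabnet.multitask import TabNetMultiTaskClassifier

clf = TabNetClassifier()           # binary and multi-class classification
reg = TabNetRegressor()            # simple and multi-task regression
mtl = TabNetMultiTaskClassifier()  # multi-task, multi-class classification
```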


## Building

To build the Docker image for this tool, run
`./build-docker.sh`.

## Install WIPP Plugin

If WIPP is running, navigate to the plugins page and add a new plugin. Paste the contents of `plugin.json` into the pop-up window and submit.
For more information on WIPP, visit the [official WIPP page](https://isg.nist.gov/deepzoomweb/software/wipp).

## Options

This plugin takes 38 input arguments and one output argument (a hedged sketch of how these flags map onto `pytorch_tabnet` arguments follows the table):

| Name | Description | I/O | Type |
| ---------------- | --------------------------------------------------------------------------- | ------ | ------------- |
| `--inpDir` | Input tabular data | Input | genericData |
| `--filePattern` | Pattern to parse tabular files | Input | string |
| `--testSize` | Proportion of the dataset to include in the test set | Input | number |
| `--nD` | Width of the decision prediction layer | Input | integer |
| `--nA` | Width of the attention embedding for each mask | Input | integer |
| `--nSteps` | Number of steps in the architecture | Input | integer |
| `--gamma` | Coefficient for feature reuse in the masks | Input | number |
| `--catEmbDim` | List of embedding sizes for each categorical feature | Input | integer |
| `--nIndependent` | Number of independent Gated Linear Unit layers at each step | Input | integer |
| `--nShared` | Number of shared Gated Linear Unit layers at each step | Input | integer |
| `--epsilon` | Small constant for numerical stability | Input | number |
| `--seed` | Random seed for reproducibility | Input | integer |
| `--momentum` | Momentum for batch normalization | Input | number |
| `--clipValue` | Clipping of the gradient value | Input | number |
| `--lambdaSparse` | Extra sparsity loss coefficient | Input | number |
| `--optimizerFn` | Pytorch optimizer function | Input | enum |
| `--lr` | Learning rate for the optimizer | Input | number |
| `--schedulerFn` | PyTorch scheduler used to adjust the learning rate during training | Input | enum |
| `--stepSize` | Step size parameter passed to the scheduler | Input | integer |
| `--deviceName` | Platform used for training | Input | enum |
| `--maskType` | A masking function for feature selection | Input | enum |
| `--groupedFeatures` | Allows the model to share attention across features within the same group | Input | integer |
| `--nSharedDecoder` | Number of shared GLU blocks in the decoder | Input | integer |
| `--nIndepDecoder` | Number of independent GLU blocks in the decoder | Input | integer |
| `--evalMetric` | Metrics utilized for early stopping evaluation | Input | enum |
| `--maxEpochs` | Maximum number of epochs for training | Input | integer |
| `--patience` | Consecutive epochs without improvement before early stopping | Input | integer |
| `--weights` | Sampling parameter (TabNetClassifier only) | Input | integer |
| `--lossFn` | Loss function | Input | enum |
| `--batchSize` | Batch size | Input | integer |
| `--virtualBatchSize` | Size of mini-batches for Ghost Batch Normalization | Input | integer |
| `--numWorkers` | Number of workers used in torch.utils.data.DataLoader | Input | integer |
| `--dropLast` | Option to drop incomplete last batch during training | Input | boolean |
| `--warmStart` | For scikit-learn compatibility; enables fitting the same model twice | Input | boolean |
| `--targetVar` | Target feature containing classification labels | Input | string |
| `--computeImportance` | Compute feature importance | Input | boolean |
| `--classifier` | Pytorch tabnet Classifier for training | Input | enum |
| `--preview` | Generate JSON file of sample outputs | Input | boolean |
| `--outDir` | Output collection | Output | genericData |
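
The snippet below is a hedged sketch of how these flags roughly map onto `pytorch_tabnet` constructor and `fit()` arguments. The synthetic data and the values shown are placeholders, not the tool's actual defaults or internal wiring.

```python
# Hedged sketch: rough mapping of the CLI flags above onto pytorch_tabnet arguments.
# Synthetic data and parameter values are placeholders only.
import numpy as np
import torch
from pytorch_tabnet.tab_model import TabNetClassifier

rng = np.random.default_rng(0)
X_train = rng.random((256, 10), dtype=np.float32)   # from --inpDir, minus --targetVar
y_train = rng.integers(0, 2, size=256)              # the --targetVar column
X_valid = rng.random((64, 10), dtype=np.float32)    # held out via --testSize
y_valid = rng.integers(0, 2, size=64)

clf = TabNetClassifier(
    n_d=8,                                             # --nD
    n_a=8,                                             # --nA
    n_steps=3,                                         # --nSteps
    gamma=1.3,                                         # --gamma
    lambda_sparse=1e-3,                                # --lambdaSparse
    momentum=0.02,                                     # --momentum
    clip_value=1.0,                                    # --clipValue
    seed=0,                                            # --seed
    mask_type="sparsemax",                             # --maskType
    device_name="cpu",                                 # --deviceName
    optimizer_fn=torch.optim.Adam,                     # --optimizerFn
    optimizer_params={"lr": 2e-2},                     # --lr
    scheduler_fn=torch.optim.lr_scheduler.StepLR,      # --schedulerFn
    scheduler_params={"step_size": 10, "gamma": 0.9},  # --stepSize
)
clf.fit(
    X_train, y_train,
    eval_set=[(X_valid, y_valid)],
    eval_metric=["accuracy"],        # --evalMetric
    max_epochs=50,                   # --maxEpochs
    patience=10,                     # --patience
    batch_size=1024,                 # --batchSize
    virtual_batch_size=128,          # --virtualBatchSize
    num_workers=0,                   # --numWorkers
    drop_last=False,                 # --dropLast
    warm_start=False,                # --warmStart
    compute_importance=True,         # --computeImportance
)
```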
1 change: 1 addition & 0 deletions clustering/tabnet-classifier-tools/VERSION
@@ -0,0 +1 @@
0.1.0-dev0
4 changes: 4 additions & 0 deletions clustering/tabnet-classifier-tools/build-docker.sh
@@ -0,0 +1,4 @@
#!/bin/bash

version=$(<VERSION)
docker build . -t polusai/pytorch-tabnet-tool:${version}
@@ -0,0 +1,16 @@
{
"capital-gain": 0.2208,
"marital-status": 0.1949,
"educational-num": 0.1724,
"hours-per-week": 0.1041,
"age": 0.0789,
"gender": 0.0756,
"occupation": 0.0635,
"relationship": 0.0264,
"capital-loss": 0.0225,
"native-country": 0.0223,
"race": 0.0102,
"fnlwgt": 0.0061,
"education": 0.0016,
"workclass": 0.001
}
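
The file above appears to be a per-feature importance summary (the keys match the UCI Adult census columns). For context, here is a hedged sketch of how `pytorch_tabnet` exposes such importances after training; the data, feature names, batch settings, and rounding below are illustrative and not necessarily what this tool does:

```python
# Hedged sketch: global feature importances from a fitted pytorch_tabnet model.
# Dataset, feature names, and rounding are illustrative only.
import json
import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier

rng = np.random.default_rng(0)
X = rng.random((512, 4), dtype=np.float32)
y = rng.integers(0, 2, size=512)
feature_names = ["age", "hours-per-week", "capital-gain", "capital-loss"]

clf = TabNetClassifier()
clf.fit(X, y, max_epochs=5, batch_size=256, drop_last=False)  # compute_importance defaults to True
importances = dict(zip(feature_names, np.round(clf.feature_importances_, 4).tolist()))
print(json.dumps(importances, indent=2))
```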
Binary file not shown.