Current generation of model running

It's a simple 1 look ahead RNN. Currently set to be a LSTM
Need to update the metrics to also include high-magnitude prediction accuracy

Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
This commit is contained in:
Ethan Wellenreiter 2025-06-23 13:23:12 -04:00
parent 56d78d0973
commit 552404d998
2 changed files with 927 additions and 0 deletions

734
full_regression_model.ipynb Normal file
View File

@ -0,0 +1,734 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 160,
"id": "98beda53",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import scipy.stats as stats\n",
"from scipy.stats import genextreme\n",
"from scipy.stats import genpareto\n",
"import requests\n",
"import json\n",
"\n",
"# from datetime import date\n",
"# from datetime import datetime\n",
"# from datetime import timedelta\n",
"# import pytz\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import lightning as L\n",
"\n",
"import os\n"
]
},
{
"cell_type": "markdown",
"id": "f15607d7",
"metadata": {},
"source": [
"### Import data"
]
},
{
"cell_type": "code",
"execution_count": 161,
"id": "57505637",
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('reduced_data.csv')\n",
"# df['time'] = pd.to_datetime(df['time'], format='ISO8601')\n",
"\n",
"# df['lastupdate'] = pd.to_datetime(df['lastupdate'], format='ISO8601')"
]
},
{
"cell_type": "code",
"execution_count": 162,
"id": "053cc158",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>bin</th>\n",
" <th>time_to_next_event</th>\n",
" <th>time</th>\n",
" <th>lat</th>\n",
" <th>lon</th>\n",
" <th>depth</th>\n",
" <th>mag</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>27</td>\n",
" <td>3408</td>\n",
" <td>1.262306e+09</td>\n",
" <td>37.5835</td>\n",
" <td>22.1704</td>\n",
" <td>15.0</td>\n",
" <td>2.8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>21</td>\n",
" <td>495</td>\n",
" <td>1.262310e+09</td>\n",
" <td>42.8370</td>\n",
" <td>13.0010</td>\n",
" <td>9.4</td>\n",
" <td>2.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>22</td>\n",
" <td>1377</td>\n",
" <td>1.262310e+09</td>\n",
" <td>46.1624</td>\n",
" <td>12.2834</td>\n",
" <td>10.0</td>\n",
" <td>2.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>50</td>\n",
" <td>274</td>\n",
" <td>1.262312e+09</td>\n",
" <td>26.4100</td>\n",
" <td>99.9100</td>\n",
" <td>10.0</td>\n",
" <td>5.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>25</td>\n",
" <td>11</td>\n",
" <td>1.262312e+09</td>\n",
" <td>39.0807</td>\n",
" <td>31.0495</td>\n",
" <td>7.0</td>\n",
" <td>2.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1062937</th>\n",
" <td>22</td>\n",
" <td>146</td>\n",
" <td>1.749481e+09</td>\n",
" <td>39.2261</td>\n",
" <td>28.9878</td>\n",
" <td>10.8</td>\n",
" <td>2.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1062938</th>\n",
" <td>40</td>\n",
" <td>166</td>\n",
" <td>1.749481e+09</td>\n",
" <td>17.5500</td>\n",
" <td>-94.3070</td>\n",
" <td>186.8</td>\n",
" <td>4.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1062939</th>\n",
" <td>30</td>\n",
" <td>36</td>\n",
" <td>1.749481e+09</td>\n",
" <td>35.3752</td>\n",
" <td>-3.6249</td>\n",
" <td>25.7</td>\n",
" <td>3.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1062940</th>\n",
" <td>20</td>\n",
" <td>25</td>\n",
" <td>1.749481e+09</td>\n",
" <td>39.2152</td>\n",
" <td>29.0178</td>\n",
" <td>5.4</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1062941</th>\n",
" <td>32</td>\n",
" <td>173</td>\n",
" <td>1.749481e+09</td>\n",
" <td>14.8400</td>\n",
" <td>119.4300</td>\n",
" <td>16.0</td>\n",
" <td>3.2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1062942 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" bin time_to_next_event time lat lon depth mag\n",
"0 27 3408 1.262306e+09 37.5835 22.1704 15.0 2.8\n",
"1 21 495 1.262310e+09 42.8370 13.0010 9.4 2.1\n",
"2 22 1377 1.262310e+09 46.1624 12.2834 10.0 2.3\n",
"3 50 274 1.262312e+09 26.4100 99.9100 10.0 5.0\n",
"4 25 11 1.262312e+09 39.0807 31.0495 7.0 2.5\n",
"... ... ... ... ... ... ... ...\n",
"1062937 22 146 1.749481e+09 39.2261 28.9878 10.8 2.3\n",
"1062938 40 166 1.749481e+09 17.5500 -94.3070 186.8 4.1\n",
"1062939 30 36 1.749481e+09 35.3752 -3.6249 25.7 3.0\n",
"1062940 20 25 1.749481e+09 39.2152 29.0178 5.4 2.0\n",
"1062941 32 173 1.749481e+09 14.8400 119.4300 16.0 3.2\n",
"\n",
"[1062942 rows x 7 columns]"
]
},
"execution_count": 162,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "markdown",
"id": "b17d72f3",
"metadata": {},
"source": [
"### Dataset loader"
]
},
{
"cell_type": "code",
"execution_count": 163,
"id": "e6dacf0d",
"metadata": {},
"outputs": [],
"source": [
"class TimeseriesDataset(torch.utils.data.Dataset):\n",
" def __init__(self, x, y, seq_len=100):\n",
" self.seq_len = seq_len\n",
" self.x = x\n",
" self.y = y\n",
" \n",
" def __len__(self):\n",
" return self.x.__len__() - self.seq_len\n",
" \n",
" def __getitem__(self, index):\n",
" return (self.x[index:index+self.seq_len], self.y[index+self.seq_len])\n"
]
},
{
"cell_type": "code",
"execution_count": 164,
"id": "f42c537e",
"metadata": {},
"outputs": [],
"source": [
"division_p1 = int(len(df)*0.6)\n",
"division_p2 = int(len(df)*0.8)"
]
},
{
"cell_type": "code",
"execution_count": 165,
"id": "2181b1ad",
"metadata": {},
"outputs": [],
"source": [
"x_columns = ['time_to_next_event','time','lat','lon','depth','mag']\n",
"y_columns = ['time_to_next_event', 'lat', 'lon','mag']"
]
},
{
"cell_type": "code",
"execution_count": 166,
"id": "10a8ac71",
"metadata": {},
"outputs": [],
"source": [
"x = torch.tensor(df[x_columns].values).to(torch.float)\n",
"# y1 = torch.tensor(df[['bin']].values)\n",
"y = torch.tensor(df[y_columns].values).to(torch.float)"
]
},
{
"cell_type": "code",
"execution_count": 167,
"id": "ee26bfb2",
"metadata": {},
"outputs": [],
"source": [
"train_x = x[:division_p1]\n",
"val_x = x[division_p1:division_p2]\n",
"test_x = x[division_p2:]\n",
"\n",
"train_y = y[:division_p1]\n",
"val_y = y[division_p1:division_p2]\n",
"test_y = y[division_p2:]"
]
},
{
"cell_type": "code",
"execution_count": 168,
"id": "4a739cb6",
"metadata": {},
"outputs": [],
"source": [
"train_dataset = TimeseriesDataset(train_x, train_y, seq_len=1000)\n",
"train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True)\n",
"\n",
"val_dataset = TimeseriesDataset(val_x, val_y, seq_len=1000)\n",
"val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=10, shuffle=True)"
]
},
{
"cell_type": "markdown",
"id": "119f29ee",
"metadata": {},
"source": [
"### Helper classes\n"
]
},
{
"cell_type": "code",
"execution_count": 169,
"id": "973b9c65",
"metadata": {},
"outputs": [],
"source": [
"## Copied from https://stackoverflow.com/questions/50817916/how-do-i-add-lstm-gru-or-other-recurrent-layers-to-a-sequential-in-pytorch\n",
"class SelectItem(nn.Module):\n",
" def __init__(self, item_index):\n",
" super(SelectItem, self).__init__()\n",
" self._name = 'selectitem'\n",
" self.item_index = item_index\n",
"\n",
" def forward(self, inputs):\n",
" return inputs[self.item_index]"
]
},
{
"cell_type": "markdown",
"id": "971e8319",
"metadata": {},
"source": [
"### Model design"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "92b3b995",
"metadata": {},
"outputs": [],
"source": [
"class MultiRNNLightning(L.LightningModule):\n",
" def __init__(self, input_size, hidden_size, output_size):\n",
" super().__init__()\n",
" self.model = nn.Sequential(nn.LSTM(input_size, hidden_size, batch_first=True),\n",
" SelectItem(1),\n",
" SelectItem(0),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_size, output_size))\n",
" \n",
" # self.model = nn.GRU(input_size, hidden_size, batch_first=True)\n",
" \n",
" # self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
" # self.model.to(self.device)\n",
"\n",
" def forward(self, x):\n",
" # print(\"hi\")\n",
" return torch.squeeze(self.model(x), 0)\n",
" \n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)\n",
" scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
" optimizer, mode=\"min\", factor=0.1, patience=5\n",
" )\n",
" \n",
" return {\n",
" \"optimizer\": optimizer,\n",
" \"lr_scheduler\": {\n",
" \"scheduler\": scheduler,\n",
" \"monitor\": \"val_loss\",\n",
" },\n",
" }\n",
" \n",
" def training_step(self, batch, batch_idx):\n",
" x, y = batch\n",
" \n",
" # x.to(self.device)\n",
" # y.to(self.device)\n",
" \n",
" # print(\"x shape\", x.shape)\n",
" y_hat = self(x)\n",
" # print(y_hat.shape)\n",
" # print(\"y shape\", y.shape)\n",
" # return 0\n",
" # print(\"yhat_shape\", y_hat.shape)\n",
" loss = nn.functional.mse_loss(y_hat, y)\n",
" self.log('train_loss', loss) #, on_epoch=True)\n",
" return loss\n",
" \n",
" def validation_step(self, val_batch, batch_idx):\n",
" x, y = val_batch\n",
" \n",
" # x.to(self.device)\n",
" # y.to(self.device)\n",
" \n",
" \n",
" y_hat = self(x)\n",
" loss = nn.functional.mse_loss(y_hat, y)\n",
" self.log('val_loss', loss) #, on_epoch=True)\n",
" return loss\n",
"\n",
" # def training_step(self, batch, batch_idx):\n",
" # x, y = batch\n",
" # class_y_hat, regress_y_hat = self(x)\n",
" \n",
" # loss = F.cross_entropy(y_hat, y)\n",
" # acc = (y_hat.argmax(1) == y).float().mean()\n",
" \n",
" # self.log(\"train_loss\", loss)\n",
" # self.log(\"train_acc\", acc)\n",
" # return loss\n",
" \n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "741ce0f7",
"metadata": {},
"source": [
"#### Model Params"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "886809fd",
"metadata": {},
"outputs": [],
"source": [
"input_size = len(x_columns)\n",
"hidden_size = 1000\n",
"output_size = len(y_columns)"
]
},
{
"cell_type": "markdown",
"id": "25598f51",
"metadata": {},
"source": [
"### Train model"
]
},
{
"cell_type": "code",
"execution_count": 172,
"id": "61fa862f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n"
]
}
],
"source": [
"model = MultiRNNLightning(input_size, hidden_size, output_size)\n",
"\n",
"trainer = L.Trainer(max_epochs=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e847d0c2",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
" | Name | Type | Params | Mode \n",
"---------------------------------------------\n",
"0 | model | Sequential | 167 K | train\n",
"---------------------------------------------\n",
"167 K Trainable params\n",
"0 Non-trainable params\n",
"167 K Total params\n",
"0.669 Total estimated model params size (MB)\n",
"6 Modules in train mode\n",
"0 Modules in eval mode\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "adec569b050341aab3b2c94e22152976",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\ewell\\anaconda3\\envs\\test_space\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.\n",
"c:\\Users\\ewell\\anaconda3\\envs\\test_space\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n",
"c:\\Users\\ewell\\anaconda3\\envs\\test_space\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b2c2156c3db24dadbeff025679a1e7b4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "80970987a81d4b059981dce44b21cdc8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d21a94f8276d4dea8ecd543ff7d5bdd2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5c1b8ebb789e4c80ad4ec452c9ccbbda",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "88b213664aaa4bc7bf57b4e5cfff3ba7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "26637bbce8b2455bb72a6268f2e2196c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c067cd1c645d49c8a11cae04e3d676dc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eb0dc33095f841c7a30376c0bdb0fce0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "12667695d03e4baf9ab468a6b57349e0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e65958ac8567425d9e9722f6023f8660",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c6ff7598384d45c3bd192822cd0070b0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=10` reached.\n"
]
}
],
"source": [
"trainer.fit(model, train_loader, val_loader)"
]
},
{
"cell_type": "code",
"execution_count": 174,
"id": "5d146588",
"metadata": {},
"outputs": [],
"source": [
"# %reload_ext tensorboard\n",
"# %tensorboard --logdir=lightning_logs/"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

193
rnn_regression_model.py Normal file
View File

@ -0,0 +1,193 @@
#%%
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
from scipy.stats import genextreme
from scipy.stats import genpareto
from argparse import ArgumentParser
import pathlib
import torch
import torch.nn as nn
import torch.optim as optim
import lightning as L
import os
#%%
class TimeseriesDataset(torch.utils.data.Dataset):
def __init__(self, x, y, seq_len=100):
self.seq_len = seq_len
self.x = x
self.y = y
def __len__(self):
return self.x.__len__() - self.seq_len
def __getitem__(self, index):
return (self.x[index:index+self.seq_len], self.y[index+self.seq_len])
#%%
## Copied from https://stackoverflow.com/questions/50817916/how-do-i-add-lstm-gru-or-other-recurrent-layers-to-a-sequential-in-pytorch
class SelectItem(nn.Module):
def __init__(self, item_index):
super(SelectItem, self).__init__()
self._name = 'selectitem'
self.item_index = item_index
def forward(self, inputs):
return inputs[self.item_index]
#%%
class MultiRNNLightning(L.LightningModule):
def __init__(self, input_size, output_size, args, conf):
super().__init__()
self.model = nn.Sequential(nn.LSTM(input_size, args.hidden_size, batch_first=True, bidirectional=args.b),
SelectItem(1),
SelectItem(0),
nn.ReLU(),
nn.Linear(args.hidden_size*(2 if args.b else 1), output_size))
self.save_hyperparameters(conf)
# self.model = nn.GRU(input_size, hidden_size, batch_first=True)
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.model.to(self.device)
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--hidden_size', type=int, default=100)
parser.add_argument('-b', action='store_true')
return parser
def forward(self, x):
# print("hi")
return torch.squeeze(self.model(x), 0)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.1, patience=5
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "val_loss",
},
}
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.functional.mse_loss(y_hat, y)
# print(batch)
self.log('train_loss', loss, prog_bar=True) #, on_epoch=True)
self.log('training_loss_per_unit', loss/y.shape[0])
return loss
def validation_step(self, val_batch, batch_idx):
x, y = val_batch
y_hat = self(x)
loss = nn.functional.mse_loss(y_hat, y)
self.log('val_loss', loss, prog_bar=True) #, on_epoch=True)
# print(val_batch)
self.log('val_loss_per_unit', loss/y.shape[0])
return loss
def test_step(self, test_batch, batch_idx):
x, y = test_batch
y_hat = self(x)
loss = nn.functional.mse_loss(y_hat, y)
self.log('test_loss', loss, prog_bar=True) #, on_epoch=True)
self.log('test_loss_per_unit', loss/y.shape[0])
return loss
#%%
parser = ArgumentParser()
#%%
parser.add_argument('--val_size', type=float, default=0.2)
parser.add_argument('--test_size', type=float, default=0.2)
parser.add_argument('--seq_size', type=int, default=1000)
parser.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path("./data"))
parser.add_argument('--batch_size', type=int, default=10)
parser.add_argument('-x', type=str, nargs='+', default=['time_to_next_event','time','lat','lon','depth','mag'])
parser.add_argument('-y', type=str, nargs='+', default=['time_to_next_event','lat','lon','depth','mag'])
parser = MultiRNNLightning.add_model_specific_args(parser)
# parser = L.Trainer.add_argparse_args(parser)
parser.add_argument('--max_epochs', type=int, default=1)
#%%
args = parser.parse_args()
#%%
datadir = args.data_dir / 'reduced_data.csv'
print(datadir.as_posix())
df = pd.read_csv(datadir.as_posix())
#%%
val_size = int(args.val_size*len(df))
test_size = int(args.test_size*len(df))
train_size = len(df)-val_size-test_size
assert val_size >= 0, "validation size is less than 0"
assert test_size >= 0, "test size is less than 0"
assert train_size > 0, "train size is non-positive"
#%%
x_columns = args.x
y_columns = args.y
x = torch.tensor(df[x_columns].values).to(torch.float)
y = torch.tensor(df[y_columns].values).to(torch.float)
mag_index = -1 if 'mag' not in y_columns else y_columns.find('mag')
train_x, val_x, test_x = torch.split(x, [train_size, val_size, test_size])
train_y, val_y, test_y = torch.split(y, [train_size, val_size, test_size])
train_dataset = TimeseriesDataset(train_x, train_y, seq_len=args.seq_size)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
val_dataset = TimeseriesDataset(val_x, val_y, seq_len=args.seq_size)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
test_dataset = TimeseriesDataset(test_x, test_y, seq_len=args.seq_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
#%%
input_size = len(x_columns)
output_size = len(y_columns)
#%%
model = MultiRNNLightning(input_size, output_size, args, {"input_dim":input_size, "output_dim":output_size, "hidden_dim": args.hidden_size, "bidirectional": args.b, "batch_size":args.batch_size, "seq_len":args.seq_size, "input_features": args.x, "output_features": args.y})
trainer = L.Trainer(max_epochs=args.max_epochs)
#%%
trainer.fit(model, train_loader, val_loader)
#%%
trainer.test(model, test_loader)