Compare commits
5 Commits
e52ee45261
...
28096a932a
| Author | SHA1 | Date | |
|---|---|---|---|
| 28096a932a | |||
| 4596d15c0a | |||
| f527929ba1 | |||
| 552404d998 | |||
| 56d78d0973 |
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
/data
|
||||
/lightning_logs
|
||||
15
README.md
15
README.md
@ -1,3 +1,18 @@
|
||||
# A repository for exploration into earthquake forecasting
|
||||
|
||||
|
||||
### Things to do
|
||||
1. Update the metric for the RNN to include an accuracy measure for high-magnitude events
|
||||
2. Try models with a farther look-ahead and/or models that predict a fixed amount of time ahead
|
||||
- can consider something like an auto-transformer type build where it feeds its prediction back into itself until its next prediction is outside of a time window?
|
||||
|
||||
#### Frame the issue
|
||||
|
||||
Need to care more about the rare events, say those above magnitude 8 or magnitude 6.
|
||||
|
||||
|
||||
##### Run command
|
||||
(current best performance)
|
||||
C:/Users/ewell/anaconda3/envs/test_space/python.exe d:/projects/earthquake_prediction_exploration/rnn_regression_model.py --max_epochs 50 -x time_to_next_event lat lon depth mag --hidden_size 1000 --seq_size 100 --batch_size 1000
|
||||
|
||||
|
||||
|
||||
1117
data_acquisition/normal_data_handler.ipynb
Normal file
1117
data_acquisition/normal_data_handler.ipynb
Normal file
File diff suppressed because one or more lines are too long
734
full_regression_model.ipynb
Normal file
734
full_regression_model.ipynb
Normal file
@ -0,0 +1,734 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 160,
|
||||
"id": "98beda53",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import scipy.stats as stats\n",
|
||||
"from scipy.stats import genextreme\n",
|
||||
"from scipy.stats import genpareto\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"# from datetime import date\n",
|
||||
"# from datetime import datetime\n",
|
||||
"# from datetime import timedelta\n",
|
||||
"# import pytz\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"import lightning as L\n",
|
||||
"\n",
|
||||
"import os\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f15607d7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Import data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 161,
|
||||
"id": "57505637",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = pd.read_csv('reduced_data.csv')\n",
|
||||
"# df['time'] = pd.to_datetime(df['time'], format='ISO8601')\n",
|
||||
"\n",
|
||||
"# df['lastupdate'] = pd.to_datetime(df['lastupdate'], format='ISO8601')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 162,
|
||||
"id": "053cc158",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>bin</th>\n",
|
||||
" <th>time_to_next_event</th>\n",
|
||||
" <th>time</th>\n",
|
||||
" <th>lat</th>\n",
|
||||
" <th>lon</th>\n",
|
||||
" <th>depth</th>\n",
|
||||
" <th>mag</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>27</td>\n",
|
||||
" <td>3408</td>\n",
|
||||
" <td>1.262306e+09</td>\n",
|
||||
" <td>37.5835</td>\n",
|
||||
" <td>22.1704</td>\n",
|
||||
" <td>15.0</td>\n",
|
||||
" <td>2.8</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>21</td>\n",
|
||||
" <td>495</td>\n",
|
||||
" <td>1.262310e+09</td>\n",
|
||||
" <td>42.8370</td>\n",
|
||||
" <td>13.0010</td>\n",
|
||||
" <td>9.4</td>\n",
|
||||
" <td>2.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>22</td>\n",
|
||||
" <td>1377</td>\n",
|
||||
" <td>1.262310e+09</td>\n",
|
||||
" <td>46.1624</td>\n",
|
||||
" <td>12.2834</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" <td>2.3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>50</td>\n",
|
||||
" <td>274</td>\n",
|
||||
" <td>1.262312e+09</td>\n",
|
||||
" <td>26.4100</td>\n",
|
||||
" <td>99.9100</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>25</td>\n",
|
||||
" <td>11</td>\n",
|
||||
" <td>1.262312e+09</td>\n",
|
||||
" <td>39.0807</td>\n",
|
||||
" <td>31.0495</td>\n",
|
||||
" <td>7.0</td>\n",
|
||||
" <td>2.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1062937</th>\n",
|
||||
" <td>22</td>\n",
|
||||
" <td>146</td>\n",
|
||||
" <td>1.749481e+09</td>\n",
|
||||
" <td>39.2261</td>\n",
|
||||
" <td>28.9878</td>\n",
|
||||
" <td>10.8</td>\n",
|
||||
" <td>2.3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1062938</th>\n",
|
||||
" <td>40</td>\n",
|
||||
" <td>166</td>\n",
|
||||
" <td>1.749481e+09</td>\n",
|
||||
" <td>17.5500</td>\n",
|
||||
" <td>-94.3070</td>\n",
|
||||
" <td>186.8</td>\n",
|
||||
" <td>4.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1062939</th>\n",
|
||||
" <td>30</td>\n",
|
||||
" <td>36</td>\n",
|
||||
" <td>1.749481e+09</td>\n",
|
||||
" <td>35.3752</td>\n",
|
||||
" <td>-3.6249</td>\n",
|
||||
" <td>25.7</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1062940</th>\n",
|
||||
" <td>20</td>\n",
|
||||
" <td>25</td>\n",
|
||||
" <td>1.749481e+09</td>\n",
|
||||
" <td>39.2152</td>\n",
|
||||
" <td>29.0178</td>\n",
|
||||
" <td>5.4</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1062941</th>\n",
|
||||
" <td>32</td>\n",
|
||||
" <td>173</td>\n",
|
||||
" <td>1.749481e+09</td>\n",
|
||||
" <td>14.8400</td>\n",
|
||||
" <td>119.4300</td>\n",
|
||||
" <td>16.0</td>\n",
|
||||
" <td>3.2</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>1062942 rows × 7 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" bin time_to_next_event time lat lon depth mag\n",
|
||||
"0 27 3408 1.262306e+09 37.5835 22.1704 15.0 2.8\n",
|
||||
"1 21 495 1.262310e+09 42.8370 13.0010 9.4 2.1\n",
|
||||
"2 22 1377 1.262310e+09 46.1624 12.2834 10.0 2.3\n",
|
||||
"3 50 274 1.262312e+09 26.4100 99.9100 10.0 5.0\n",
|
||||
"4 25 11 1.262312e+09 39.0807 31.0495 7.0 2.5\n",
|
||||
"... ... ... ... ... ... ... ...\n",
|
||||
"1062937 22 146 1.749481e+09 39.2261 28.9878 10.8 2.3\n",
|
||||
"1062938 40 166 1.749481e+09 17.5500 -94.3070 186.8 4.1\n",
|
||||
"1062939 30 36 1.749481e+09 35.3752 -3.6249 25.7 3.0\n",
|
||||
"1062940 20 25 1.749481e+09 39.2152 29.0178 5.4 2.0\n",
|
||||
"1062941 32 173 1.749481e+09 14.8400 119.4300 16.0 3.2\n",
|
||||
"\n",
|
||||
"[1062942 rows x 7 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 162,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b17d72f3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Dataset loader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 163,
|
||||
"id": "e6dacf0d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class TimeseriesDataset(torch.utils.data.Dataset):\n",
|
||||
" def __init__(self, x, y, seq_len=100):\n",
|
||||
" self.seq_len = seq_len\n",
|
||||
" self.x = x\n",
|
||||
" self.y = y\n",
|
||||
" \n",
|
||||
" def __len__(self):\n",
|
||||
" return self.x.__len__() - self.seq_len\n",
|
||||
" \n",
|
||||
" def __getitem__(self, index):\n",
|
||||
" return (self.x[index:index+self.seq_len], self.y[index+self.seq_len])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 164,
|
||||
"id": "f42c537e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"division_p1 = int(len(df)*0.6)\n",
|
||||
"division_p2 = int(len(df)*0.8)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 165,
|
||||
"id": "2181b1ad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x_columns = ['time_to_next_event','time','lat','lon','depth','mag']\n",
|
||||
"y_columns = ['time_to_next_event', 'lat', 'lon','mag']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 166,
|
||||
"id": "10a8ac71",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = torch.tensor(df[x_columns].values).to(torch.float)\n",
|
||||
"# y1 = torch.tensor(df[['bin']].values)\n",
|
||||
"y = torch.tensor(df[y_columns].values).to(torch.float)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 167,
|
||||
"id": "ee26bfb2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_x = x[:division_p1]\n",
|
||||
"val_x = x[division_p1:division_p2]\n",
|
||||
"test_x = x[division_p2:]\n",
|
||||
"\n",
|
||||
"train_y = y[:division_p1]\n",
|
||||
"val_y = y[division_p1:division_p2]\n",
|
||||
"test_y = y[division_p2:]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 168,
|
||||
"id": "4a739cb6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset = TimeseriesDataset(train_x, train_y, seq_len=1000)\n",
|
||||
"train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True)\n",
|
||||
"\n",
|
||||
"val_dataset = TimeseriesDataset(val_x, val_y, seq_len=1000)\n",
|
||||
"val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=10, shuffle=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "119f29ee",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Helper classes\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 169,
|
||||
"id": "973b9c65",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"## Copied from https://stackoverflow.com/questions/50817916/how-do-i-add-lstm-gru-or-other-recurrent-layers-to-a-sequential-in-pytorch\n",
|
||||
"class SelectItem(nn.Module):\n",
|
||||
" def __init__(self, item_index):\n",
|
||||
" super(SelectItem, self).__init__()\n",
|
||||
" self._name = 'selectitem'\n",
|
||||
" self.item_index = item_index\n",
|
||||
"\n",
|
||||
" def forward(self, inputs):\n",
|
||||
" return inputs[self.item_index]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "971e8319",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Model design"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "92b3b995",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class MultiRNNLightning(L.LightningModule):\n",
|
||||
" def __init__(self, input_size, hidden_size, output_size):\n",
|
||||
" super().__init__()\n",
|
||||
" self.model = nn.Sequential(nn.LSTM(input_size, hidden_size, batch_first=True),\n",
|
||||
" SelectItem(1),\n",
|
||||
" SelectItem(0),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Linear(hidden_size, output_size))\n",
|
||||
" \n",
|
||||
" # self.model = nn.GRU(input_size, hidden_size, batch_first=True)\n",
|
||||
" \n",
|
||||
" # self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
" # self.model.to(self.device)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" # print(\"hi\")\n",
|
||||
" return torch.squeeze(self.model(x), 0)\n",
|
||||
" \n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)\n",
|
||||
" scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
|
||||
" optimizer, mode=\"min\", factor=0.1, patience=5\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" return {\n",
|
||||
" \"optimizer\": optimizer,\n",
|
||||
" \"lr_scheduler\": {\n",
|
||||
" \"scheduler\": scheduler,\n",
|
||||
" \"monitor\": \"val_loss\",\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" x, y = batch\n",
|
||||
" \n",
|
||||
" # x.to(self.device)\n",
|
||||
" # y.to(self.device)\n",
|
||||
" \n",
|
||||
" # print(\"x shape\", x.shape)\n",
|
||||
" y_hat = self(x)\n",
|
||||
" # print(y_hat.shape)\n",
|
||||
" # print(\"y shape\", y.shape)\n",
|
||||
" # return 0\n",
|
||||
" # print(\"yhat_shape\", y_hat.shape)\n",
|
||||
" loss = nn.functional.mse_loss(y_hat, y)\n",
|
||||
" self.log('train_loss', loss) #, on_epoch=True)\n",
|
||||
" return loss\n",
|
||||
" \n",
|
||||
" def validation_step(self, val_batch, batch_idx):\n",
|
||||
" x, y = val_batch\n",
|
||||
" \n",
|
||||
" # x.to(self.device)\n",
|
||||
" # y.to(self.device)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" y_hat = self(x)\n",
|
||||
" loss = nn.functional.mse_loss(y_hat, y)\n",
|
||||
" self.log('val_loss', loss) #, on_epoch=True)\n",
|
||||
" return loss\n",
|
||||
"\n",
|
||||
" # def training_step(self, batch, batch_idx):\n",
|
||||
" # x, y = batch\n",
|
||||
" # class_y_hat, regress_y_hat = self(x)\n",
|
||||
" \n",
|
||||
" # loss = F.cross_entropy(y_hat, y)\n",
|
||||
" # acc = (y_hat.argmax(1) == y).float().mean()\n",
|
||||
" \n",
|
||||
" # self.log(\"train_loss\", loss)\n",
|
||||
" # self.log(\"train_acc\", acc)\n",
|
||||
" # return loss\n",
|
||||
" \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "741ce0f7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Model Params"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "886809fd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"input_size = len(x_columns)\n",
|
||||
"hidden_size = 1000\n",
|
||||
"output_size = len(y_columns)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "25598f51",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Train model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 172,
|
||||
"id": "61fa862f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"GPU available: True (cuda), used: True\n",
|
||||
"TPU available: False, using: 0 TPU cores\n",
|
||||
"HPU available: False, using: 0 HPUs\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = MultiRNNLightning(input_size, hidden_size, output_size)\n",
|
||||
"\n",
|
||||
"trainer = L.Trainer(max_epochs=10)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e847d0c2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
|
||||
"\n",
|
||||
" | Name | Type | Params | Mode \n",
|
||||
"---------------------------------------------\n",
|
||||
"0 | model | Sequential | 167 K | train\n",
|
||||
"---------------------------------------------\n",
|
||||
"167 K Trainable params\n",
|
||||
"0 Non-trainable params\n",
|
||||
"167 K Total params\n",
|
||||
"0.669 Total estimated model params size (MB)\n",
|
||||
"6 Modules in train mode\n",
|
||||
"0 Modules in eval mode\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "adec569b050341aab3b2c94e22152976",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Sanity Checking: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"c:\\Users\\ewell\\anaconda3\\envs\\test_space\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.\n",
|
||||
"c:\\Users\\ewell\\anaconda3\\envs\\test_space\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n",
|
||||
"c:\\Users\\ewell\\anaconda3\\envs\\test_space\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "b2c2156c3db24dadbeff025679a1e7b4",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Training: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "80970987a81d4b059981dce44b21cdc8",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "d21a94f8276d4dea8ecd543ff7d5bdd2",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "5c1b8ebb789e4c80ad4ec452c9ccbbda",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "88b213664aaa4bc7bf57b4e5cfff3ba7",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "26637bbce8b2455bb72a6268f2e2196c",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c067cd1c645d49c8a11cae04e3d676dc",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "eb0dc33095f841c7a30376c0bdb0fce0",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "12667695d03e4baf9ab468a6b57349e0",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "e65958ac8567425d9e9722f6023f8660",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c6ff7598384d45c3bd192822cd0070b0",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"`Trainer.fit` stopped: `max_epochs=10` reached.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"trainer.fit(model, train_loader, val_loader)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 174,
|
||||
"id": "5d146588",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# %reload_ext tensorboard\n",
|
||||
"# %tensorboard --logdir=lightning_logs/"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 1,
|
||||
"id": "b4cc996f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -29,9 +29,17 @@
|
||||
"import os\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "11cc2375",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Import data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 2,
|
||||
"id": "b0fe3fe6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -44,7 +52,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 3,
|
||||
"id": "083c39ae",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -140,58 +148,58 @@
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548087</th>\n",
|
||||
" <th>548217</th>\n",
|
||||
" <td>21</td>\n",
|
||||
" <td>113</td>\n",
|
||||
" <td>1.747781e+09</td>\n",
|
||||
" <td>35.9836</td>\n",
|
||||
" <td>28.1056</td>\n",
|
||||
" <td>7.1</td>\n",
|
||||
" <td>2.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548218</th>\n",
|
||||
" <td>18</td>\n",
|
||||
" <td>941</td>\n",
|
||||
" <td>1.747781e+09</td>\n",
|
||||
" <td>36.2797</td>\n",
|
||||
" <td>36.0253</td>\n",
|
||||
" <td>7.0</td>\n",
|
||||
" <td>1.9</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548219</th>\n",
|
||||
" <td>11</td>\n",
|
||||
" <td>141</td>\n",
|
||||
" <td>1.747782e+09</td>\n",
|
||||
" <td>37.2222</td>\n",
|
||||
" <td>36.9486</td>\n",
|
||||
" <td>7.3</td>\n",
|
||||
" <td>1.2</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548220</th>\n",
|
||||
" <td>32</td>\n",
|
||||
" <td>143</td>\n",
|
||||
" <td>1.747783e+09</td>\n",
|
||||
" <td>-31.1900</td>\n",
|
||||
" <td>-68.2500</td>\n",
|
||||
" <td>81.0</td>\n",
|
||||
" <td>3.3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548221</th>\n",
|
||||
" <td>40</td>\n",
|
||||
" <td>326</td>\n",
|
||||
" <td>1.747745e+09</td>\n",
|
||||
" <td>-4.3300</td>\n",
|
||||
" <td>132.9700</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" <td>790</td>\n",
|
||||
" <td>1.747783e+09</td>\n",
|
||||
" <td>17.6300</td>\n",
|
||||
" <td>-101.3340</td>\n",
|
||||
" <td>41.2</td>\n",
|
||||
" <td>4.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548088</th>\n",
|
||||
" <td>40</td>\n",
|
||||
" <td>401</td>\n",
|
||||
" <td>1.747745e+09</td>\n",
|
||||
" <td>-30.1300</td>\n",
|
||||
" <td>-69.4600</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" <td>4.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548089</th>\n",
|
||||
" <td>13</td>\n",
|
||||
" <td>75</td>\n",
|
||||
" <td>1.747745e+09</td>\n",
|
||||
" <td>38.9889</td>\n",
|
||||
" <td>27.9292</td>\n",
|
||||
" <td>8.9</td>\n",
|
||||
" <td>1.4</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548090</th>\n",
|
||||
" <td>31</td>\n",
|
||||
" <td>35</td>\n",
|
||||
" <td>1.747746e+09</td>\n",
|
||||
" <td>-8.0000</td>\n",
|
||||
" <td>107.0500</td>\n",
|
||||
" <td>16.0</td>\n",
|
||||
" <td>3.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548091</th>\n",
|
||||
" <td>47</td>\n",
|
||||
" <td>448</td>\n",
|
||||
" <td>1.747746e+09</td>\n",
|
||||
" <td>-23.3231</td>\n",
|
||||
" <td>-179.9220</td>\n",
|
||||
" <td>540.0</td>\n",
|
||||
" <td>4.7</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>548092 rows × 7 columns</p>\n",
|
||||
"<p>548222 rows × 7 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
@ -202,16 +210,16 @@
|
||||
"3 31 39 1.577837e+09 19.1900 -67.8400 28.0 3.1\n",
|
||||
"4 35 69 1.577837e+09 -25.6400 -70.5200 53.0 3.5\n",
|
||||
"... ... ... ... ... ... ... ...\n",
|
||||
"548087 40 326 1.747745e+09 -4.3300 132.9700 10.0 4.0\n",
|
||||
"548088 40 401 1.747745e+09 -30.1300 -69.4600 10.0 4.1\n",
|
||||
"548089 13 75 1.747745e+09 38.9889 27.9292 8.9 1.4\n",
|
||||
"548090 31 35 1.747746e+09 -8.0000 107.0500 16.0 3.1\n",
|
||||
"548091 47 448 1.747746e+09 -23.3231 -179.9220 540.0 4.7\n",
|
||||
"548217 21 113 1.747781e+09 35.9836 28.1056 7.1 2.1\n",
|
||||
"548218 18 941 1.747781e+09 36.2797 36.0253 7.0 1.9\n",
|
||||
"548219 11 141 1.747782e+09 37.2222 36.9486 7.3 1.2\n",
|
||||
"548220 32 143 1.747783e+09 -31.1900 -68.2500 81.0 3.3\n",
|
||||
"548221 40 790 1.747783e+09 17.6300 -101.3340 41.2 4.0\n",
|
||||
"\n",
|
||||
"[548092 rows x 7 columns]"
|
||||
"[548222 rows x 7 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 35,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -222,7 +230,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 4,
|
||||
"id": "967ec29d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -244,7 +252,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"execution_count": 5,
|
||||
"id": "77c900ad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -267,19 +275,10 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"id": "f0c623e1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/var/folders/l9/cmw34wr13t3cp2_91pq103100000gn/T/ipykernel_1457/1713124793.py:1: RuntimeWarning: divide by zero encountered in divide\n",
|
||||
" inv_mag_frequency = 1/np.bincount(df['bin'].to_numpy())\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# inv_mag_frequency = 1/np.bincount(df['bin'].to_numpy())/\n",
|
||||
"# print(inv_mag_frequency)\n",
|
||||
@ -288,9 +287,17 @@
|
||||
"# print(sum(mag_frequency))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5eea5b40",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 7,
|
||||
"id": "00b89e53",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -316,11 +323,93 @@
|
||||
"\n",
|
||||
"input_size = 5\n",
|
||||
"hidden_size = 20\n",
|
||||
"class_count = 100\n",
|
||||
"regressor_output_size = 3\n",
|
||||
"class_count = 120\n",
|
||||
"regressor_output_size = 4\n",
|
||||
"# model = MultiRNN(input_size, hidden_size, class_count, regressor_output_size)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0e346047",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create dataset/dataloader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "2744a649",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class TimeseriesDataset(torch.utils.data.Dataset):\n",
|
||||
" def __init__(self, x, class_y, regress_y, seq_len=100):\n",
|
||||
" self.seq_len = seq_len\n",
|
||||
" self.x = x\n",
|
||||
" self.class_y = class_y\n",
|
||||
" self.regress_y = regress_y\n",
|
||||
" \n",
|
||||
" def __len__(self):\n",
|
||||
" return self.x.__len__() - (self.seq_len-1)\n",
|
||||
" \n",
|
||||
" def __getitem__(self, index):\n",
|
||||
" return (self.x[index:index+self.seq_len], self.class_y[index+self.seq_len], self.regress_y[index+self.seq_len])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "5714b17d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = torch.tensor(df[['time_to_next_event','time','lat','lon','depth','mag']].values).to(torch.float)\n",
|
||||
"y1 = torch.tensor(df[['bin']].values)\n",
|
||||
"y2 = torch.tensor(df[['time_to_next_event', 'time', 'lat', 'lon']].values).to(torch.float)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "72f7b0f1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset = TimeseriesDataset(x, y1, y2, seq_len=1000)\n",
|
||||
"train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ee154b1b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0 torch.Size([10, 1000, 6]) torch.Size([10, 1]) torch.Size([10, 4])\n",
|
||||
"54723\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# for i, d in enumerate(train_loader):\n",
|
||||
"# print(i, d[0].shape, d[1].shape, d[2].shape)\n",
|
||||
"# break\n",
|
||||
"# print(len(train_loader))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2277a6ee",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create the Lightning setup for model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -332,33 +421,12 @@
|
||||
" def __init__(self, input_size, hidden_size, class_count, output_size):\n",
|
||||
" super().__init__()\n",
|
||||
" self.model = MultiRNN(input_size, hidden_size, class_count, regressor_output_size)\n",
|
||||
" self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
" self.model.to(self.device)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" return self.model.forward(x)\n",
|
||||
" \n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" x, y1, y2 = batch\n",
|
||||
" y_hat = self(x)\n",
|
||||
" loss1 = nn.functional.cross_entropy(y_hat, y1)\n",
|
||||
" loss2 = nn.functional.mse_loss(y_hat,y2)\n",
|
||||
" # loss = F.cross_entropy(y_hat, y)\n",
|
||||
" loss = loss1 + loss2\n",
|
||||
" self.log('train_loss_class', loss1)\n",
|
||||
" self.log('train_loss_regress', loss1)\n",
|
||||
" return loss\n",
|
||||
"\n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" x, y = batch\n",
|
||||
" y_hat = self(x)\n",
|
||||
" \n",
|
||||
" loss = F.cross_entropy(y_hat, y)\n",
|
||||
" acc = (y_hat.argmax(1) == y).float().mean()\n",
|
||||
" \n",
|
||||
" self.log(\"train_loss\", loss)\n",
|
||||
" self.log(\"train_acc\", acc)\n",
|
||||
" return loss\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)\n",
|
||||
" scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
|
||||
@ -369,11 +437,66 @@
|
||||
" \"optimizer\": optimizer,\n",
|
||||
" \"lr_scheduler\": {\n",
|
||||
" \"scheduler\": scheduler,\n",
|
||||
" \"monitor\": \"val_loss\",\n",
|
||||
" \"monitor\": \"val_loss_class\",\n",
|
||||
" },\n",
|
||||
" }"
|
||||
" }\n",
|
||||
" \n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" x, y1, y2 = batch\n",
|
||||
" \n",
|
||||
" x.to(self.device)\n",
|
||||
" y1.to(self.device)\n",
|
||||
" y2.to(self.device)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" class_y_hat, regress_y_hat = self(x)\n",
|
||||
" loss1 = nn.functional.cross_entropy(class_y_hat, y1)\n",
|
||||
" loss2 = nn.functional.mse_loss(regress_y_hat, y2)\n",
|
||||
" # loss = F.cross_entropy(y_hat, y)\n",
|
||||
" loss = loss1 + loss2\n",
|
||||
" self.log('train_loss_class', loss1)\n",
|
||||
" self.log('train_loss_regress', loss1) #, on_epoch=True)\n",
|
||||
" return loss\n",
|
||||
" \n",
|
||||
" def validation_step(self, val_batch, batch_idx):\n",
|
||||
" x, y1, y2 = val_batch\n",
|
||||
" \n",
|
||||
" x.to(self.device)\n",
|
||||
" y1.to(self.device)\n",
|
||||
" y2.to(self.device)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" class_y_hat, regress_y_hat = self(x)\n",
|
||||
" loss1 = nn.functional.cross_entropy(class_y_hat, y1)\n",
|
||||
" loss2 = nn.functional.mse_loss(regress_y_hat, y2)\n",
|
||||
" # loss = F.cross_entropy(y_hat, y)\n",
|
||||
" loss = loss1 + loss2\n",
|
||||
" self.log('val_loss_class', loss1)\n",
|
||||
" self.log('val_loss_regress', loss1) #, on_epoch=True)\n",
|
||||
" return loss\n",
|
||||
"\n",
|
||||
" # def training_step(self, batch, batch_idx):\n",
|
||||
" # x, y = batch\n",
|
||||
" # class_y_hat, regress_y_hat = self(x)\n",
|
||||
" \n",
|
||||
" # loss = F.cross_entropy(y_hat, y)\n",
|
||||
" # acc = (y_hat.argmax(1) == y).float().mean()\n",
|
||||
" \n",
|
||||
" # self.log(\"train_loss\", loss)\n",
|
||||
" # self.log(\"train_acc\", acc)\n",
|
||||
" # return loss\n",
|
||||
" \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "97657a45",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -418,7 +541,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "sideprojects",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -432,7 +555,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.8"
|
||||
"version": "3.11.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
1080
on_data.ipynb
1080
on_data.ipynb
File diff suppressed because one or more lines are too long
193
rnn_regression_model.py
Normal file
193
rnn_regression_model.py
Normal file
@ -0,0 +1,193 @@
|
||||
#%%
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import scipy.stats as stats
|
||||
from scipy.stats import genextreme
|
||||
from scipy.stats import genpareto
|
||||
|
||||
from argparse import ArgumentParser
|
||||
import pathlib
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import lightning as L
|
||||
|
||||
import os
|
||||
#%%
|
||||
class TimeseriesDataset(torch.utils.data.Dataset):
    """Sliding-window dataset over two sequences indexed in step.

    Item ``i`` is ``(x[i:i + seq_len], y[i + seq_len])``: a window of
    ``seq_len`` consecutive input rows paired with the target row that
    comes immediately after that window.
    """

    def __init__(self, x, y, seq_len=100):
        """Store the inputs, the targets and the window length.

        Args:
            x: Sliceable sequence of input rows.
            y: Indexable sequence of target rows, aligned with ``x``.
            seq_len: Number of consecutive ``x`` rows in each window.
        """
        self.seq_len = seq_len
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of complete (window, target) pairs."""
        # A split shorter than the window has no usable items. Clamp at zero:
        # a negative value would make ``len()`` itself raise ValueError.
        return max(0, len(self.x) - self.seq_len)

    def __getitem__(self, index):
        """Return ``(x[index:index + seq_len], y[index + seq_len])``."""
        stop = index + self.seq_len
        return (self.x[index:stop], self.y[stop])
|
||||
|
||||
|
||||
#%%
|
||||
## Copied from https://stackoverflow.com/questions/50817916/how-do-i-add-lstm-gru-or-other-recurrent-layers-to-a-sequential-in-pytorch
|
||||
class SelectItem(nn.Module):
    """Module whose output is a single element of its indexable input.

    Placed after a layer that returns a tuple (for example ``nn.LSTM``) so
    that the layer can be chained inside ``nn.Sequential``.
    """

    def __init__(self, item_index):
        """Remember which element of the input to pass through.

        Args:
            item_index: Index applied to the input in ``forward``.
        """
        super().__init__()
        self.item_index = item_index
        self._name = 'selectitem'

    def forward(self, inputs):
        """Return ``inputs[item_index]``."""
        picked = inputs[self.item_index]
        return picked
|
||||
|
||||
#%%
|
||||
class MultiRNNLightning(L.LightningModule):
    """LSTM regressor: final hidden state -> ReLU -> linear output layer.

    Training, validation and test steps all minimise / report the mean
    squared error between the prediction and the target, logged as
    ``train_loss`` / ``val_loss`` / ``test_loss``. The learning-rate
    scheduler monitors ``val_loss``.
    """

    def __init__(self, input_size, output_size, args, conf):
        """Build the network and record the run configuration.

        Args:
            input_size: Number of features per time step.
            output_size: Number of regression targets.
            args: Parsed CLI namespace; ``args.hidden_size`` and ``args.b``
                (bidirectional flag) are read here.
            conf: Hyperparameters passed to ``save_hyperparameters``.
        """
        super().__init__()
        directions = 2 if args.b else 1
        # The layout of this Sequential is kept as-is so parameter names
        # (``model.0.*``, ``model.4.*``) stay the same as before.
        self.model = nn.Sequential(
            nn.LSTM(input_size, args.hidden_size, batch_first=True, bidirectional=args.b),
            SelectItem(1),  # (h_n, c_n) out of the LSTM's (output, (h_n, c_n))
            SelectItem(0),  # h_n
            nn.ReLU(),
            nn.Linear(args.hidden_size * directions, output_size),
        )

        self.save_hyperparameters(conf)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the model options.

        Adds ``--hidden_size`` (int, default 100) and the ``-b`` flag that
        switches the LSTM to bidirectional.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_size', type=int, default=100)
        parser.add_argument('-b', action='store_true')
        return parser

    def forward(self, x):
        """Predict the targets for the input windows ``x``.

        Args:
            x: Input tensor; assumed (batch, seq, features) since the LSTM is
                built with ``batch_first=True`` -- TODO confirm with callers.

        Returns:
            The linear head applied to the LSTM's final hidden state.
        """
        lstm = self.model[0]
        if not lstm.bidirectional:
            # h_n carries a leading axis of size 1 here; drop it after the head.
            return torch.squeeze(self.model(x), 0)

        # Bidirectional: h_n stacks the forward and backward final states on
        # axis 0, so pushing it straight through the Sequential hands the
        # Linear layer ``hidden_size`` features where it expects
        # ``2 * hidden_size`` and raises a shape error. Join the two
        # directions on the feature axis instead.
        _, (h_n, _) = lstm(x)
        merged = torch.cat((h_n[0], h_n[1]), dim=-1)
        return self.model[4](self.model[3](merged))

    def configure_optimizers(self):
        """Return Adam (lr=1e-3) with a ReduceLROnPlateau schedule on ``val_loss``."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=5
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

    def _mse_step(self, batch, loss_key, per_unit_key):
        """Compute the MSE for one ``(x, y)`` batch and log it under both keys."""
        x, y = batch
        loss = nn.functional.mse_loss(self(x), y)
        self.log(loss_key, loss, prog_bar=True)
        # Second metric: the same loss divided by the size of y's first axis.
        self.log(per_unit_key, loss / y.shape[0])
        return loss

    def training_step(self, batch, batch_idx):
        """Return the MSE loss for a training batch."""
        return self._mse_step(batch, 'train_loss', 'training_loss_per_unit')

    def validation_step(self, val_batch, batch_idx):
        """Return the MSE loss for a validation batch."""
        return self._mse_step(val_batch, 'val_loss', 'val_loss_per_unit')

    def test_step(self, test_batch, batch_idx):
        """Return the MSE loss for a test batch."""
        return self._mse_step(test_batch, 'test_loss', 'test_loss_per_unit')
|
||||
|
||||
|
||||
|
||||
|
||||
#%%
|
||||
parser = ArgumentParser()

#%%
# Fractions of the rows held out for validation and for testing.
parser.add_argument('--val_size', type=float, default=0.2)
parser.add_argument('--test_size', type=float, default=0.2)

# Window length handed to TimeseriesDataset.
parser.add_argument('--seq_size', type=int, default=1000)

# Directory expected to contain reduced_data.csv.
parser.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path("./data"))

parser.add_argument('--batch_size', type=int, default=10)

# Input (-x) and target (-y) column names read from the CSV.
parser.add_argument('-x', type=str, nargs='+', default=['time_to_next_event','time','lat','lon','depth','mag'])
parser.add_argument('-y', type=str, nargs='+', default=['time_to_next_event','lat','lon','depth','mag'])

parser = MultiRNNLightning.add_model_specific_args(parser)
parser.add_argument('--max_epochs', type=int, default=1)

#%%
args = parser.parse_args()

#%%
datadir = args.data_dir / 'reduced_data.csv'
print(datadir.as_posix())
df = pd.read_csv(datadir.as_posix())

#%%
# Contiguous split in row order: train first, then validation, then test.
val_size = int(args.val_size*len(df))
test_size = int(args.test_size*len(df))
train_size = len(df)-val_size-test_size

assert val_size >= 0, "validation size is less than 0"
assert test_size >= 0, "test size is less than 0"
assert train_size > 0, "train size is non-positive"

#%%
x_columns = args.x
y_columns = args.y

x = torch.tensor(df[x_columns].values).to(torch.float)
y = torch.tensor(df[y_columns].values).to(torch.float)

# Position of the magnitude column among the targets, or -1 when absent.
# y_columns is a list, so the lookup is .index(); lists have no .find().
mag_index = y_columns.index('mag') if 'mag' in y_columns else -1

train_x, val_x, test_x = torch.split(x, [train_size, val_size, test_size])
train_y, val_y, test_y = torch.split(y, [train_size, val_size, test_size])

# Only the training loader shuffles its windows.
train_dataset = TimeseriesDataset(train_x, train_y, seq_len=args.seq_size)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

val_dataset = TimeseriesDataset(val_x, val_y, seq_len=args.seq_size)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

test_dataset = TimeseriesDataset(test_x, test_y, seq_len=args.seq_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

#%%
input_size = len(x_columns)
output_size = len(y_columns)

#%%
# Run configuration stored with the model via save_hyperparameters.
hyperparameters = {
    "input_dim": input_size,
    "output_dim": output_size,
    "hidden_dim": args.hidden_size,
    "bidirectional": args.b,
    "batch_size": args.batch_size,
    "seq_len": args.seq_size,
    "input_features": args.x,
    "output_features": args.y,
}
model = MultiRNNLightning(input_size, output_size, args, hyperparameters)
trainer = L.Trainer(max_epochs=args.max_epochs)

#%%
trainer.fit(model, train_loader, val_loader)

#%%
trainer.test(model, test_loader)
|
||||
Loading…
Reference in New Issue
Block a user