An initial and unfinished attempt at creating an RNN model for the problem
Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
This commit is contained in:
parent
c3524eda21
commit
e52ee45261
440
model_training.ipynb
Normal file
440
model_training.ipynb
Normal file
@ -0,0 +1,440 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"id": "b4cc996f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import scipy.stats as stats\n",
|
||||
"from scipy.stats import genextreme\n",
|
||||
"from scipy.stats import genpareto\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"# from datetime import date\n",
|
||||
"# from datetime import datetime\n",
|
||||
"# from datetime import timedelta\n",
|
||||
"# import pytz\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"import lightning as L\n",
|
||||
"\n",
|
||||
"import os\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"id": "b0fe3fe6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = pd.read_csv('reduced_data.csv')\n",
|
||||
"# df['time'] = pd.to_datetime(df['time'], format='ISO8601')\n",
|
||||
"\n",
|
||||
"# df['lastupdate'] = pd.to_datetime(df['lastupdate'], format='ISO8601')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"id": "083c39ae",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>bin</th>\n",
|
||||
" <th>time_to_next_event</th>\n",
|
||||
" <th>time</th>\n",
|
||||
" <th>lat</th>\n",
|
||||
" <th>lon</th>\n",
|
||||
" <th>depth</th>\n",
|
||||
" <th>mag</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>27</td>\n",
|
||||
" <td>24</td>\n",
|
||||
" <td>1.577837e+09</td>\n",
|
||||
" <td>19.2200</td>\n",
|
||||
" <td>-67.1300</td>\n",
|
||||
" <td>12.0</td>\n",
|
||||
" <td>2.8</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>30</td>\n",
|
||||
" <td>77</td>\n",
|
||||
" <td>1.577837e+09</td>\n",
|
||||
" <td>-2.7400</td>\n",
|
||||
" <td>127.9000</td>\n",
|
||||
" <td>20.0</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>25</td>\n",
|
||||
" <td>425</td>\n",
|
||||
" <td>1.577837e+09</td>\n",
|
||||
" <td>19.0800</td>\n",
|
||||
" <td>-67.0900</td>\n",
|
||||
" <td>6.0</td>\n",
|
||||
" <td>2.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>31</td>\n",
|
||||
" <td>39</td>\n",
|
||||
" <td>1.577837e+09</td>\n",
|
||||
" <td>19.1900</td>\n",
|
||||
" <td>-67.8400</td>\n",
|
||||
" <td>28.0</td>\n",
|
||||
" <td>3.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>35</td>\n",
|
||||
" <td>69</td>\n",
|
||||
" <td>1.577837e+09</td>\n",
|
||||
" <td>-25.6400</td>\n",
|
||||
" <td>-70.5200</td>\n",
|
||||
" <td>53.0</td>\n",
|
||||
" <td>3.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548087</th>\n",
|
||||
" <td>40</td>\n",
|
||||
" <td>326</td>\n",
|
||||
" <td>1.747745e+09</td>\n",
|
||||
" <td>-4.3300</td>\n",
|
||||
" <td>132.9700</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" <td>4.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548088</th>\n",
|
||||
" <td>40</td>\n",
|
||||
" <td>401</td>\n",
|
||||
" <td>1.747745e+09</td>\n",
|
||||
" <td>-30.1300</td>\n",
|
||||
" <td>-69.4600</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" <td>4.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548089</th>\n",
|
||||
" <td>13</td>\n",
|
||||
" <td>75</td>\n",
|
||||
" <td>1.747745e+09</td>\n",
|
||||
" <td>38.9889</td>\n",
|
||||
" <td>27.9292</td>\n",
|
||||
" <td>8.9</td>\n",
|
||||
" <td>1.4</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548090</th>\n",
|
||||
" <td>31</td>\n",
|
||||
" <td>35</td>\n",
|
||||
" <td>1.747746e+09</td>\n",
|
||||
" <td>-8.0000</td>\n",
|
||||
" <td>107.0500</td>\n",
|
||||
" <td>16.0</td>\n",
|
||||
" <td>3.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548091</th>\n",
|
||||
" <td>47</td>\n",
|
||||
" <td>448</td>\n",
|
||||
" <td>1.747746e+09</td>\n",
|
||||
" <td>-23.3231</td>\n",
|
||||
" <td>-179.9220</td>\n",
|
||||
" <td>540.0</td>\n",
|
||||
" <td>4.7</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>548092 rows × 7 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" bin time_to_next_event time lat lon depth mag\n",
|
||||
"0 27 24 1.577837e+09 19.2200 -67.1300 12.0 2.8\n",
|
||||
"1 30 77 1.577837e+09 -2.7400 127.9000 20.0 3.0\n",
|
||||
"2 25 425 1.577837e+09 19.0800 -67.0900 6.0 2.5\n",
|
||||
"3 31 39 1.577837e+09 19.1900 -67.8400 28.0 3.1\n",
|
||||
"4 35 69 1.577837e+09 -25.6400 -70.5200 53.0 3.5\n",
|
||||
"... ... ... ... ... ... ... ...\n",
|
||||
"548087 40 326 1.747745e+09 -4.3300 132.9700 10.0 4.0\n",
|
||||
"548088 40 401 1.747745e+09 -30.1300 -69.4600 10.0 4.1\n",
|
||||
"548089 13 75 1.747745e+09 38.9889 27.9292 8.9 1.4\n",
|
||||
"548090 31 35 1.747746e+09 -8.0000 107.0500 16.0 3.1\n",
|
||||
"548091 47 448 1.747746e+09 -23.3231 -179.9220 540.0 4.7\n",
|
||||
"\n",
|
||||
"[548092 rows x 7 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"id": "967ec29d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# # Set threshold at the 95th percentile\n",
|
||||
"# # threshold = np.percentile(data, 95)\n",
|
||||
"# threshold = 6 ## the actual threshold value\n",
|
||||
"# data = df['mag']\n",
|
||||
"# extremes = data[data > threshold]\n",
|
||||
"# print(f\"Threshold: {threshold}, Number of extremes: {len(extremes)}\")\n",
|
||||
"\n",
|
||||
"# # Visualize the threshold and extremes\n",
|
||||
"# plt.hist(data, bins=30, edgecolor='k', alpha=0.7, label='Data')\n",
|
||||
"# plt.axvline(threshold, color='red', linestyle='--', label='Threshold')\n",
|
||||
"# plt.title('Threshold for POT')\n",
|
||||
"# plt.legend()\n",
|
||||
"# plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"id": "77c900ad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# from scipy.stats import genpareto\n",
|
||||
"\n",
|
||||
"# # Fit a GPD to the extremes\n",
|
||||
"# gpd_params = genpareto.fit(extremes - threshold) # Subtract threshold for GPD fit\n",
|
||||
"# print(f\"GPD Parameters: Shape={gpd_params[0]}, Location={gpd_params[1]}, Scale={gpd_params[2]}\")\n",
|
||||
"\n",
|
||||
"# # Visualize the GPD fit\n",
|
||||
"# x = np.linspace(min(extremes), max(extremes), 100)\n",
|
||||
"# pdf = genpareto.pdf(x - threshold, *gpd_params)\n",
|
||||
"# plt.hist(extremes, bins=10, density=True, alpha=0.7, label='Data')\n",
|
||||
"# plt.plot(x, pdf, label='GPD Fit', color='blue')\n",
|
||||
"# plt.title('GPD Fit to Extremes')\n",
|
||||
"# plt.legend()\n",
|
||||
"# plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f0c623e1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/var/folders/l9/cmw34wr13t3cp2_91pq103100000gn/T/ipykernel_1457/1713124793.py:1: RuntimeWarning: divide by zero encountered in divide\n",
|
||||
" inv_mag_frequency = 1/np.bincount(df['bin'].to_numpy())\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# inv_mag_frequency = 1/np.bincount(df['bin'].to_numpy())/\n",
|
||||
"# print(inv_mag_frequency)\n",
|
||||
"# inv_mag_frequency[inv_mag_frequency == np.inf] = 2\n",
|
||||
"# print(inv_mag_frequency)\n",
|
||||
"# print(sum(mag_frequency))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "00b89e53",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
class MultiRNN(nn.Module):
    """GRU encoder feeding a classification head and a regression head.

    The input sequence is run through a single-layer, batch-first GRU; the
    GRU output at the last time step is passed to two independent linear
    layers, one producing class scores and one producing regression values.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the GRU hidden state.
        class_count: Number of classes scored by the classification head.
        output_size: Number of values produced by the regression head.
    """

    def __init__(self, input_size, hidden_size, class_count, output_size):
        super().__init__()
        # Stored on the instance so forward() never depends on a notebook
        # global that happens to share the name.
        self.hidden_size = hidden_size
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        self.classify = nn.Linear(hidden_size, class_count)
        self.regress = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run the sequence through the GRU and both heads.

        Args:
            x: Tensor of shape (batch, seq_len, input_size).

        Returns:
            In training mode, ``(class_logits, regression)``.
            In eval mode, ``(predicted_class_indices, regression)``.
        """
        # new_zeros() creates the initial state directly on x's device and in
        # x's dtype, avoiding an allocate-then-copy and dtype mismatches.
        h0 = x.new_zeros(self.rnn.num_layers, x.size(0), self.hidden_size)
        out, _ = self.rnn(x, h0)

        # Only the final time step is used as the sequence summary.
        last_step = out[:, -1]
        classes = self.classify(last_step)
        regresses = self.regress(last_step)

        if self.training:
            return classes, regresses
        return torch.argmax(classes, dim=-1), regresses
|
||||
"\n",
|
||||
"\n",
|
||||
"input_size = 5\n",
|
||||
"hidden_size = 20\n",
|
||||
"class_count = 100\n",
|
||||
"regressor_output_size = 3\n",
|
||||
"# model = MultiRNN(input_size, hidden_size, class_count, regressor_output_size)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a35aa9f2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
class MutliRNNLightning(L.LightningModule):
    """Lightning wrapper that trains a ``MultiRNN`` on a joint objective.

    The training loss is the sum of a cross-entropy loss on the
    classification head and a mean-squared-error loss on the regression head.

    NOTE: the class name is misspelled ("Mutli"); it is kept so existing
    references keep working, and a correctly spelled alias
    ``MultiRNNLightning`` is defined right after the class.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the GRU hidden state.
        class_count: Number of classes for the classification head.
        output_size: Number of values produced by the regression head.
    """

    def __init__(self, input_size, hidden_size, class_count, output_size):
        super().__init__()
        # Use the constructor argument, not a notebook-level global.
        self.model = MultiRNN(input_size, hidden_size, class_count, output_size)

    def forward(self, x):
        """Delegate to the wrapped model (called, so module hooks still run)."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the combined classification + regression loss.

        Args:
            batch: Tuple ``(x, y1, y2)`` where ``y1`` is the classification
                target and ``y2`` the regression target.
                # assumes y1 holds integer class indices and y2 matches the
                # regression head's output shape -- TODO confirm with the
                # dataset once it exists.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The summed loss tensor used for back-propagation.
        """
        x, y1, y2 = batch
        # In training mode the wrapped model returns (logits, regression).
        class_logits, regression = self(x)

        loss1 = nn.functional.cross_entropy(class_logits, y1)
        loss2 = nn.functional.mse_loss(regression, y2)
        loss = loss1 + loss2
        acc = (class_logits.argmax(dim=-1) == y1).float().mean()

        self.log('train_loss_class', loss1)
        self.log('train_loss_regress', loss2)
        # Epoch-level aggregate; this is the metric the LR scheduler monitors.
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        self.log("train_acc", acc)
        return loss

    def configure_optimizers(self):
        """Adam optimizer with a reduce-on-plateau learning-rate schedule."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=5
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # No validation step is defined, so "val_loss" is never
                # logged; monitor the training loss that training_step logs.
                # Switch back to "val_loss" once a validation_step logs it.
                "monitor": "train_loss",
            },
        }


# Correctly spelled alias for the class above.
MultiRNNLightning = MutliRNNLightning
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b49f9d92",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sequence_length = 1000\n",
|
||||
"\n",
|
||||
"# criterion = nn.MSELoss()\n",
|
||||
"class_loss = nn.CrossEntropyLoss(weight=torch.from_numpy(inv_mag_frequency))\n",
|
||||
"regressor_loss = nn.MSELoss()\n",
|
||||
"# optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
|
||||
"\n",
|
||||
"X = torch.tensor(df.iloc[:sequence_length][['time', 'lat', 'lon', 'depth', 'mag']].values).to(torch.float)\n",
|
||||
"X = torch.stack((torch.tensor(df.iloc[sequence_length:sequence_length*2][['time', 'lat', 'lon', 'depth', 'mag']].values).to(torch.float),X))\n",
|
||||
"# print(X.shape)\n",
|
||||
"model.train()\n",
|
||||
"outputs = model(X)\n",
|
||||
"# print(outputs[0].shape)\n",
|
||||
"# print(outputs[1].shape)\n",
|
||||
"\n",
|
||||
"# model.eval()\n",
|
||||
"# outputs = model(X)\n",
|
||||
"# print(outputs[0].shape)\n",
|
||||
"# print(outputs[1].shape)\n",
|
||||
"\n",
|
||||
"# num_epochs = 100\n",
|
||||
"# for epoch in range(num_epochs):\n",
|
||||
"# model.train()\n",
|
||||
"# outputs = model(X.unsqueeze(2))\n",
|
||||
"# loss = criterion(outputs, y.unsqueeze(2))\n",
|
||||
" \n",
|
||||
"# optimizer.zero_grad()\n",
|
||||
"# loss.backward()\n",
|
||||
"# optimizer.step()\n",
|
||||
" \n",
|
||||
"# if (epoch + 1) % 10 == 0:\n",
|
||||
"# print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "sideprojects",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user