An initial and unfinished attempt at creating an RNN model for the problem

Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
This commit is contained in:
Ethan Wellenreiter 2025-05-20 15:30:58 -04:00
parent c3524eda21
commit e52ee45261

440
model_training.ipynb Normal file
View File

@ -0,0 +1,440 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 33,
"id": "b4cc996f",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import scipy.stats as stats\n",
"from scipy.stats import genextreme\n",
"from scipy.stats import genpareto\n",
"import requests\n",
"import json\n",
"\n",
"# from datetime import date\n",
"# from datetime import datetime\n",
"# from datetime import timedelta\n",
"# import pytz\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import lightning as L\n",
"\n",
"import os\n"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "b0fe3fe6",
"metadata": {},
"outputs": [],
"source": [
"# Load the reduced event table. The preview in the next cell shows the columns\n",
"# bin, time_to_next_event, time, lat, lon, depth and mag.\n",
"df = pd.read_csv('reduced_data.csv')\n",
"# Datetime parsing is disabled: the preview shows `time` as a float\n",
"# (presumably epoch seconds -- TODO confirm).\n",
"# df['time'] = pd.to_datetime(df['time'], format='ISO8601')\n",
"\n",
"# NOTE(review): the displayed frame has no `lastupdate` column, so this line\n",
"# would raise a KeyError if re-enabled against the current CSV.\n",
"# df['lastupdate'] = pd.to_datetime(df['lastupdate'], format='ISO8601')"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "083c39ae",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>bin</th>\n",
" <th>time_to_next_event</th>\n",
" <th>time</th>\n",
" <th>lat</th>\n",
" <th>lon</th>\n",
" <th>depth</th>\n",
" <th>mag</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>27</td>\n",
" <td>24</td>\n",
" <td>1.577837e+09</td>\n",
" <td>19.2200</td>\n",
" <td>-67.1300</td>\n",
" <td>12.0</td>\n",
" <td>2.8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>30</td>\n",
" <td>77</td>\n",
" <td>1.577837e+09</td>\n",
" <td>-2.7400</td>\n",
" <td>127.9000</td>\n",
" <td>20.0</td>\n",
" <td>3.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>25</td>\n",
" <td>425</td>\n",
" <td>1.577837e+09</td>\n",
" <td>19.0800</td>\n",
" <td>-67.0900</td>\n",
" <td>6.0</td>\n",
" <td>2.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>31</td>\n",
" <td>39</td>\n",
" <td>1.577837e+09</td>\n",
" <td>19.1900</td>\n",
" <td>-67.8400</td>\n",
" <td>28.0</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>35</td>\n",
" <td>69</td>\n",
" <td>1.577837e+09</td>\n",
" <td>-25.6400</td>\n",
" <td>-70.5200</td>\n",
" <td>53.0</td>\n",
" <td>3.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>548087</th>\n",
" <td>40</td>\n",
" <td>326</td>\n",
" <td>1.747745e+09</td>\n",
" <td>-4.3300</td>\n",
" <td>132.9700</td>\n",
" <td>10.0</td>\n",
" <td>4.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>548088</th>\n",
" <td>40</td>\n",
" <td>401</td>\n",
" <td>1.747745e+09</td>\n",
" <td>-30.1300</td>\n",
" <td>-69.4600</td>\n",
" <td>10.0</td>\n",
" <td>4.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>548089</th>\n",
" <td>13</td>\n",
" <td>75</td>\n",
" <td>1.747745e+09</td>\n",
" <td>38.9889</td>\n",
" <td>27.9292</td>\n",
" <td>8.9</td>\n",
" <td>1.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>548090</th>\n",
" <td>31</td>\n",
" <td>35</td>\n",
" <td>1.747746e+09</td>\n",
" <td>-8.0000</td>\n",
" <td>107.0500</td>\n",
" <td>16.0</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>548091</th>\n",
" <td>47</td>\n",
" <td>448</td>\n",
" <td>1.747746e+09</td>\n",
" <td>-23.3231</td>\n",
" <td>-179.9220</td>\n",
" <td>540.0</td>\n",
" <td>4.7</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>548092 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" bin time_to_next_event time lat lon depth mag\n",
"0 27 24 1.577837e+09 19.2200 -67.1300 12.0 2.8\n",
"1 30 77 1.577837e+09 -2.7400 127.9000 20.0 3.0\n",
"2 25 425 1.577837e+09 19.0800 -67.0900 6.0 2.5\n",
"3 31 39 1.577837e+09 19.1900 -67.8400 28.0 3.1\n",
"4 35 69 1.577837e+09 -25.6400 -70.5200 53.0 3.5\n",
"... ... ... ... ... ... ... ...\n",
"548087 40 326 1.747745e+09 -4.3300 132.9700 10.0 4.0\n",
"548088 40 401 1.747745e+09 -30.1300 -69.4600 10.0 4.1\n",
"548089 13 75 1.747745e+09 38.9889 27.9292 8.9 1.4\n",
"548090 31 35 1.747746e+09 -8.0000 107.0500 16.0 3.1\n",
"548091 47 448 1.747746e+09 -23.3231 -179.9220 540.0 4.7\n",
"\n",
"[548092 rows x 7 columns]"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preview the loaded frame; the saved output shows the first and last five rows.\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "967ec29d",
"metadata": {},
"outputs": [],
"source": [
"# # Set threshold at the 95th percentile\n",
"# # threshold = np.percentile(data, 95)\n",
"# threshold = 6 ## the actual threshold value\n",
"# data = df['mag']\n",
"# extremes = data[data > threshold]\n",
"# print(f\"Threshold: {threshold}, Number of extremes: {len(extremes)}\")\n",
"\n",
"# # Visualize the threshold and extremes\n",
"# plt.hist(data, bins=30, edgecolor='k', alpha=0.7, label='Data')\n",
"# plt.axvline(threshold, color='red', linestyle='--', label='Threshold')\n",
"# plt.title('Threshold for POT')\n",
"# plt.legend()\n",
"# plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "77c900ad",
"metadata": {},
"outputs": [],
"source": [
"# from scipy.stats import genpareto\n",
"\n",
"# # Fit a GPD to the extremes\n",
"# gpd_params = genpareto.fit(extremes - threshold) # Subtract threshold for GPD fit\n",
"# print(f\"GPD Parameters: Shape={gpd_params[0]}, Location={gpd_params[1]}, Scale={gpd_params[2]}\")\n",
"\n",
"# # Visualize the GPD fit\n",
"# x = np.linspace(min(extremes), max(extremes), 100)\n",
"# pdf = genpareto.pdf(x - threshold, *gpd_params)\n",
"# plt.hist(extremes, bins=10, density=True, alpha=0.7, label='Data')\n",
"# plt.plot(x, pdf, label='GPD Fit', color='blue')\n",
"# plt.title('GPD Fit to Extremes')\n",
"# plt.legend()\n",
"# plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0c623e1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/l9/cmw34wr13t3cp2_91pq103100000gn/T/ipykernel_1457/1713124793.py:1: RuntimeWarning: divide by zero encountered in divide\n",
" inv_mag_frequency = 1/np.bincount(df['bin'].to_numpy())\n"
]
}
],
"source": [
"# inv_mag_frequency = 1/np.bincount(df['bin'].to_numpy())/\n",
"# print(inv_mag_frequency)\n",
"# inv_mag_frequency[inv_mag_frequency == np.inf] = 2\n",
"# print(inv_mag_frequency)\n",
"# print(sum(mag_frequency))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00b89e53",
"metadata": {},
"outputs": [],
"source": [
"class MultiRNN(nn.Module):\n",
"    \"\"\"GRU encoder feeding a classification head and a regression head.\n",
"\n",
"    Both heads read the GRU output at the last time step of each sequence.\n",
"\n",
"    Args:\n",
"        input_size: Number of features per time step.\n",
"        hidden_size: Width of the GRU hidden state.\n",
"        class_count: Number of logits produced by the classification head.\n",
"        output_size: Number of values produced by the regression head.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_size, hidden_size, class_count, output_size):\n",
"        super().__init__()\n",
"        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)\n",
"        self.classify = nn.Linear(hidden_size, class_count)\n",
"        self.regress = nn.Linear(hidden_size, output_size)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Encode ``x`` and apply both heads to the last time step.\n",
"\n",
"        Args:\n",
"            x: Batch-first input, indexed here as (batch, seq, input_size).\n",
"\n",
"        Returns:\n",
"            ``(class_logits, regression)`` while the module is in training\n",
"            mode; ``(predicted_class_index, regression)`` in eval mode, where\n",
"            the index is the argmax over the class logits.\n",
"        \"\"\"\n",
"        # Size the initial state from this instance's GRU instead of the\n",
"        # notebook-level `hidden_size`, so the model works for any hidden size\n",
"        # and does not depend on a global being defined.\n",
"        h0 = torch.zeros(\n",
"            self.rnn.num_layers,\n",
"            x.size(0),\n",
"            self.rnn.hidden_size,\n",
"            dtype=x.dtype,\n",
"            device=x.device,\n",
"        )\n",
"        out, _ = self.rnn(x, h0)\n",
"        last_step = out[:, -1]\n",
"        classes = self.classify(last_step)\n",
"        regresses = self.regress(last_step)\n",
"\n",
"        if self.training:\n",
"            return classes, regresses\n",
"        return torch.argmax(classes, dim=-1), regresses\n",
"\n",
"\n",
"# Hyperparameters for MultiRNN(input_size, hidden_size, class_count, output_size).\n",
"input_size = 5  # features per time step; the later cell selects time, lat, lon, depth, mag\n",
"hidden_size = 20  # width of the GRU hidden state\n",
"class_count = 100  # logits in the classification head (presumably one per `bin` value -- TODO confirm)\n",
"regressor_output_size = 3  # values in the regression head (targets not shown here -- TODO confirm)\n",
"# model = MultiRNN(input_size, hidden_size, class_count, regressor_output_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a35aa9f2",
"metadata": {},
"outputs": [],
"source": [
"class MutliRNNLightning(L.LightningModule):\n",
"    \"\"\"Lightning wrapper that trains MultiRNN on a joint classification + regression loss.\n",
"\n",
"    NOTE(review): the class name is misspelled (\"Mutli\"); it is kept unchanged so\n",
"    existing references keep working.\n",
"\n",
"    Args:\n",
"        input_size: Forwarded to MultiRNN.\n",
"        hidden_size: Forwarded to MultiRNN.\n",
"        class_count: Forwarded to MultiRNN.\n",
"        output_size: Forwarded to MultiRNN as the regression head width.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_size, hidden_size, class_count, output_size):\n",
"        super().__init__()\n",
"        # Use the constructor argument rather than the notebook-level\n",
"        # `regressor_output_size`, which silently overrode it before.\n",
"        self.model = MultiRNN(input_size, hidden_size, class_count, output_size)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Delegate to the wrapped MultiRNN (called as a module so hooks run).\"\"\"\n",
"        return self.model(x)\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        \"\"\"Compute the summed cross-entropy and MSE loss for one batch.\n",
"\n",
"        Args:\n",
"            batch: ``(x, y1, y2)``. Assumes ``y1`` holds integer class indices\n",
"                and ``y2`` holds regression targets shaped like the regression\n",
"                output -- TODO confirm against the dataloader.\n",
"            batch_idx: Index of the batch (unused).\n",
"\n",
"        Returns:\n",
"            Scalar tensor: classification loss plus regression loss.\n",
"        \"\"\"\n",
"        x, y1, y2 = batch\n",
"        # In training mode MultiRNN returns (class logits, regression output).\n",
"        class_logits, regression = self(x)\n",
"\n",
"        loss_class = nn.functional.cross_entropy(class_logits, y1)\n",
"        loss_regress = nn.functional.mse_loss(regression, y2)\n",
"        loss = loss_class + loss_regress\n",
"        acc = (class_logits.argmax(dim=1) == y1).float().mean()\n",
"\n",
"        self.log('train_loss_class', loss_class)\n",
"        self.log('train_loss_regress', loss_regress)\n",
"        # Epoch-level aggregate: this is the value the LR scheduler monitors.\n",
"        self.log(\"train_loss\", loss, on_step=False, on_epoch=True)\n",
"        self.log(\"train_acc\", acc)\n",
"        return loss\n",
"\n",
"    def configure_optimizers(self):\n",
"        \"\"\"Return Adam plus a ReduceLROnPlateau scheduler keyed on ``train_loss``.\"\"\"\n",
"        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)\n",
"        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
"            optimizer, mode=\"min\", factor=0.1, patience=5\n",
"        )\n",
"\n",
"        return {\n",
"            \"optimizer\": optimizer,\n",
"            \"lr_scheduler\": {\n",
"                \"scheduler\": scheduler,\n",
"                # No validation_step logs \"val_loss\", so monitor the metric that\n",
"                # training_step actually logs. NOTE(review): switch back to\n",
"                # \"val_loss\" once a validation loop exists.\n",
"                \"monitor\": \"train_loss\",\n",
"            },\n",
"        }"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b49f9d92",
"metadata": {},
"outputs": [],
"source": [
"sequence_length = 1000\n",
"feature_columns = ['time', 'lat', 'lon', 'depth', 'mag']\n",
"\n",
"# Inverse-frequency class weights, one per class. They are built here because\n",
"# the earlier cell that computed `inv_mag_frequency` is commented out, which\n",
"# left the name undefined on a fresh kernel. `minlength` pads the counts to\n",
"# `class_count` so the weight vector matches the classifier's logits.\n",
"bin_counts = np.bincount(df['bin'].to_numpy(), minlength=class_count)\n",
"assert bin_counts.size == class_count, 'bin ids must lie in [0, class_count)'\n",
"# Empty bins get weight 2 (the placeholder used by the commented-out cell);\n",
"# filling first avoids the divide-by-zero warning of a plain reciprocal.\n",
"inv_mag_frequency = np.full(class_count, 2.0, dtype=np.float32)\n",
"observed = bin_counts > 0\n",
"inv_mag_frequency[observed] = 1.0 / bin_counts[observed]\n",
"\n",
"# criterion = nn.MSELoss()\n",
"# float32 weights: CrossEntropyLoss needs them to match the logits' dtype.\n",
"class_loss = nn.CrossEntropyLoss(weight=torch.from_numpy(inv_mag_frequency))\n",
"regressor_loss = nn.MSELoss()\n",
"# optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
"\n",
"# Instantiate the model here; no other cell defines `model`.\n",
"model = MultiRNN(input_size, hidden_size, class_count, regressor_output_size)\n",
"\n",
"# Batch of two consecutive windows, stacked as (rows 1000-1999, rows 0-999),\n",
"# giving shape (2, sequence_length, len(feature_columns)).\n",
"features = torch.tensor(\n",
"    df[feature_columns].iloc[:2 * sequence_length].values, dtype=torch.float\n",
")\n",
"X = torch.stack((features[sequence_length:], features[:sequence_length]))\n",
"# print(X.shape)\n",
"model.train()\n",
"outputs = model(X)\n",
"# print(outputs[0].shape)\n",
"# print(outputs[1].shape)\n",
"\n",
"# model.eval()\n",
"# outputs = model(X)\n",
"# print(outputs[0].shape)\n",
"# print(outputs[1].shape)\n",
"\n",
"# num_epochs = 100\n",
"# for epoch in range(num_epochs):\n",
"#     model.train()\n",
"#     outputs = model(X.unsqueeze(2))\n",
"#     loss = criterion(outputs, y.unsqueeze(2))\n",
"\n",
"#     optimizer.zero_grad()\n",
"#     loss.backward()\n",
"#     optimizer.step()\n",
"\n",
"#     if (epoch + 1) % 10 == 0:\n",
"#         print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sideprojects",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}