Compare commits

...

10 Commits

Author SHA1 Message Date
28096a932a Adding gitignore so that the acquired data isn't put into the repo and also the model logs aren't either
If it's important enough, it can be placed in a different folder

Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
2025-06-23 13:29:17 -04:00
4596d15c0a The model investigative jupyter notebook. It's older
Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
2025-06-23 13:28:03 -04:00
f527929ba1 Adding/updating README
Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
2025-06-23 13:27:38 -04:00
552404d998 Current generation of model running
It's a simple 1 look ahead RNN. Currently set to be a LSTM
Need to update the metrics to also include high-magnitude prediction accuracy

Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
2025-06-23 13:24:04 -04:00
56d78d0973 Moving data related stuff to a folder
Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
2025-06-23 13:22:09 -04:00
e52ee45261 An initial and unfinished attempt at creating an RNN model for the problem
Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
2025-05-20 15:31:17 -04:00
c3524eda21 A brief look at pulling data for only extreme events
Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
2025-05-20 15:30:52 -04:00
6c2247974c A notebook for pulling seismic event data, cleaning it and saving it
Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
2025-05-20 15:30:26 -04:00
fc03e01629 Exploring Extreme Value Theory (EVT) techniques
Looking into how it may help with reducing to our problem

Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
2025-05-20 15:27:06 -04:00
6038fdef50 Initial files to configure data querying and live websocket connections
Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
2025-05-20 15:24:35 -04:00
10 changed files with 3940 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
/data
/lightning_logs

View File

@ -1,3 +1,18 @@
# A repository for exploration into earthquake forecasting
### Things to do
1. Update the metric for the RNN to include an accuracy for high magnitude stuff
2. Try farther look ahead models and/or just an amount of time ahead
- can consider something like an auto-transformer type build where it feeds it's prediction back into itself until it's next prediction is outside of a time?
#### Frame the issue
Need to care more about the rare events. Let's say above mag 8. or mag 6.
##### Run command
(current best performance)
C:/Users/ewell/anaconda3/envs/test_space/python.exe d:/projects/earthquake_prediction_exploration/rnn_regression_model.py --max_epochs 50 -x time_to_next_event lat lon depth mag --hidden_size 1000 --seq_size 100 --batch_size 1000

View File

@ -0,0 +1,338 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 18,
"id": "e06a5cf6",
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"import obspy"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "434a13eb",
"metadata": {},
"outputs": [],
"source": [
"url = \"www.seismicportal.eu/fdsnws/event/1/query?limit=10&start=2020-01-01&end=2022-01-01&format=json\""
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "e72fb8f7",
"metadata": {},
"outputs": [],
"source": [
"def geturl(url):\n",
" res = requests.get(\"https://\"+url, timeout=15)\n",
" return {'status': res.status_code,\n",
" 'content': res.text}"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "19355da2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\"type\":\"FeatureCollection\",\"metadata\":{\"count\":10},\"features\":[{\n",
" \"geometry\": {\n",
" \"type\": \"Point\",\n",
" \"coordinates\": [\n",
" 122.38,\n",
" -8.11,\n",
" -10.0\n",
" ]\n",
" },\n",
" \"type\": \"Feature\",\n",
" \"id\": \"20220101_0000155\",\n",
" \"properties\": {\n",
" \"lastupdate\": \"2022-01-02T06:11:00.0Z\",\n",
" \"magtype\": \"m\",\n",
" \"evtype\": \"ke\",\n",
" \"lon\": 122.38,\n",
" \"auth\": \"DJA\",\n",
" \"lat\": -8.11,\n",
" \"depth\": 10.0,\n",
" \"unid\": \"20220101_0000155\",\n",
" \"mag\": 2.9,\n",
" \"time\": \"2022-01-01T23:50:06.0Z\",\n",
" \"source_id\": \"1083157\",\n",
" \"source_catalog\": \"EMSC-RTS\",\n",
" \"flynn_region\": \"FLORES REGION, INDONESIA\"\n",
" }\n",
"},{\n",
" \"geometry\": {\n",
" \"type\": \"Point\",\n",
" \"coordinates\": [\n",
" 124.09,\n",
" -8.84,\n",
" -65.0\n",
" ]\n",
" },\n",
" \"type\": \"Feature\",\n",
" \"id\": \"20220101_0000153\",\n",
" \"properties\": {\n",
" \"lastupdate\": \"2022-01-02T06:11:00.0Z\",\n",
" \"magtype\": \"m\",\n",
" \"evtype\": \"ke\",\n",
" \"lon\": 124.09,\n",
" \"auth\": \"DJA\",\n",
" \"lat\": -8.84,\n",
" \"depth\": 65.0,\n",
" \"unid\": \"20220101_0000153\",\n",
" \"mag\": 3.3,\n",
" \"time\": \"2022-01-01T23:48:18.0Z\",\n",
" \"source_id\": \"1083154\",\n",
" \"source_catalog\": \"EMSC-RTS\",\n",
" \"flynn_region\": \"KEPULAUAN ALOR, INDONESIA\"\n",
" }\n",
"},{\n",
" \"geometry\": {\n",
" \"type\": \"Point\",\n",
" \"coordinates\": [\n",
" -173.92,\n",
" -21.37,\n",
" -10.0\n",
" ]\n",
" },\n",
" \"type\": \"Feature\",\n",
" \"id\": \"20220101_0000137\",\n",
" \"properties\": {\n",
" \"lastupdate\": \"2022-01-02T08:20:00.0Z\",\n",
" \"magtype\": \"mb\",\n",
" \"evtype\": \"ke\",\n",
" \"lon\": -173.92,\n",
" \"auth\": \"EMSC\",\n",
" \"lat\": -21.37,\n",
" \"depth\": 10.0,\n",
" \"unid\": \"20220101_0000137\",\n",
" \"mag\": 4.9,\n",
" \"time\": \"2022-01-01T23:47:39.0Z\",\n",
" \"source_id\": \"1083084\",\n",
" \"source_catalog\": \"EMSC-RTS\",\n",
" \"flynn_region\": \"TONGA\"\n",
" }\n",
"},{\n",
" \"geometry\": {\n",
" \"type\": \"Point\",\n",
" \"coordinates\": [\n",
" -66.64,\n",
" -23.73,\n",
" -233.0\n",
" ]\n",
" },\n",
" \"type\": \"Feature\",\n",
" \"id\": \"20220101_0000136\",\n",
" \"properties\": {\n",
" \"lastupdate\": \"2022-01-01T23:56:00.0Z\",\n",
" \"magtype\": \"m\",\n",
" \"evtype\": \"ke\",\n",
" \"lon\": -66.64,\n",
" \"auth\": \"NSNA\",\n",
" \"lat\": -23.73,\n",
" \"depth\": 233.0,\n",
" \"unid\": \"20220101_0000136\",\n",
" \"mag\": 3.1,\n",
" \"time\": \"2022-01-01T23:45:36.0Z\",\n",
" \"source_id\": \"1083085\",\n",
" \"source_catalog\": \"EMSC-RTS\",\n",
" \"flynn_region\": \"JUJUY, ARGENTINA\"\n",
" }\n",
"},{\n",
" \"geometry\": {\n",
" \"type\": \"Point\",\n",
" \"coordinates\": [\n",
" -155.4,\n",
" 19.2,\n",
" -32.0\n",
" ]\n",
" },\n",
" \"type\": \"Feature\",\n",
" \"id\": \"20220101_0000133\",\n",
" \"properties\": {\n",
" \"lastupdate\": \"2022-01-01T23:45:00.0Z\",\n",
" \"magtype\": \"md\",\n",
" \"evtype\": \"ke\",\n",
" \"lon\": -155.4,\n",
" \"auth\": \"NEIR\",\n",
" \"lat\": 19.2,\n",
" \"depth\": 32.0,\n",
" \"unid\": \"20220101_0000133\",\n",
" \"mag\": 2.2,\n",
" \"time\": \"2022-01-01T23:42:34.8Z\",\n",
" \"source_id\": \"1083079\",\n",
" \"source_catalog\": \"EMSC-RTS\",\n",
" \"flynn_region\": \"ISLAND OF HAWAII, HAWAII\"\n",
" }\n",
"},{\n",
" \"geometry\": {\n",
" \"type\": \"Point\",\n",
" \"coordinates\": [\n",
" -16.21,\n",
" 28.09,\n",
" -8.0\n",
" ]\n",
" },\n",
" \"type\": \"Feature\",\n",
" \"id\": \"20220101_0000202\",\n",
" \"properties\": {\n",
" \"lastupdate\": \"2022-01-02T20:45:00.0Z\",\n",
" \"magtype\": \"ml\",\n",
" \"evtype\": \"ke\",\n",
" \"lon\": -16.21,\n",
" \"auth\": \"MDD\",\n",
" \"lat\": 28.09,\n",
" \"depth\": 8.0,\n",
" \"unid\": \"20220101_0000202\",\n",
" \"mag\": 1.8,\n",
" \"time\": \"2022-01-01T23:40:04.8Z\",\n",
" \"source_id\": \"1083360\",\n",
" \"source_catalog\": \"EMSC-RTS\",\n",
" \"flynn_region\": \"CANARY ISLANDS, SPAIN REGION\"\n",
" }\n",
"},{\n",
" \"geometry\": {\n",
" \"type\": \"Point\",\n",
" \"coordinates\": [\n",
" -69.31,\n",
" 18.08,\n",
" -10.0\n",
" ]\n",
" },\n",
" \"type\": \"Feature\",\n",
" \"id\": \"20220101_0000134\",\n",
" \"properties\": {\n",
" \"lastupdate\": \"2022-01-01T23:51:00.0Z\",\n",
" \"magtype\": \"m\",\n",
" \"evtype\": \"ke\",\n",
" \"lon\": -69.31,\n",
" \"auth\": \"UASD\",\n",
" \"lat\": 18.08,\n",
" \"depth\": 10.0,\n",
" \"unid\": \"20220101_0000134\",\n",
" \"mag\": 3.1,\n",
" \"time\": \"2022-01-01T23:25:21.0Z\",\n",
" \"source_id\": \"1083082\",\n",
" \"source_catalog\": \"EMSC-RTS\",\n",
" \"flynn_region\": \"DOMINICAN REPUBLIC REGION\"\n",
" }\n",
"},{\n",
" \"geometry\": {\n",
" \"type\": \"Point\",\n",
" \"coordinates\": [\n",
" -74.06,\n",
" 18.95,\n",
" -10.0\n",
" ]\n",
" },\n",
" \"type\": \"Feature\",\n",
" \"id\": \"20220101_0000132\",\n",
" \"properties\": {\n",
" \"lastupdate\": \"2022-01-01T23:36:00.0Z\",\n",
" \"magtype\": \"m\",\n",
" \"evtype\": \"ke\",\n",
" \"lon\": -74.06,\n",
" \"auth\": \"UASD\",\n",
" \"lat\": 18.95,\n",
" \"depth\": 10.0,\n",
" \"unid\": \"20220101_0000132\",\n",
" \"mag\": 3.1,\n",
" \"time\": \"2022-01-01T23:19:03.0Z\",\n",
" \"source_id\": \"1083078\",\n",
" \"source_catalog\": \"EMSC-RTS\",\n",
" \"flynn_region\": \"HAITI REGION\"\n",
" }\n",
"},{\n",
" \"geometry\": {\n",
" \"type\": \"Point\",\n",
" \"coordinates\": [\n",
" -98.05,\n",
" 16.31,\n",
" -8.0\n",
" ]\n",
" },\n",
" \"type\": \"Feature\",\n",
" \"id\": \"20220101_0000141\",\n",
" \"properties\": {\n",
" \"lastupdate\": \"2022-01-02T02:01:00.0Z\",\n",
" \"magtype\": \"m\",\n",
" \"evtype\": \"ke\",\n",
" \"lon\": -98.05,\n",
" \"auth\": \"UNM\",\n",
" \"lat\": 16.31,\n",
" \"depth\": 8.0,\n",
" \"unid\": \"20220101_0000141\",\n",
" \"mag\": 3.3,\n",
" \"time\": \"2022-01-01T23:00:27.0Z\",\n",
" \"source_id\": \"1083098\",\n",
" \"source_catalog\": \"EMSC-RTS\",\n",
" \"flynn_region\": \"OAXACA, MEXICO\"\n",
" }\n",
"},{\n",
" \"geometry\": {\n",
" \"type\": \"Point\",\n",
" \"coordinates\": [\n",
" -68.86,\n",
" -21.1,\n",
" -109.0\n",
" ]\n",
" },\n",
" \"type\": \"Feature\",\n",
" \"id\": \"20220101_0000131\",\n",
" \"properties\": {\n",
" \"lastupdate\": \"2022-01-01T23:08:00.0Z\",\n",
" \"magtype\": \"m\",\n",
" \"evtype\": \"ke\",\n",
" \"lon\": -68.86,\n",
" \"auth\": \"GUC\",\n",
" \"lat\": -21.1,\n",
" \"depth\": 109.0,\n",
" \"unid\": \"20220101_0000131\",\n",
" \"mag\": 2.7,\n",
" \"time\": \"2022-01-01T22:57:18.0Z\",\n",
" \"source_id\": \"1083075\",\n",
" \"source_catalog\": \"EMSC-RTS\",\n",
" \"flynn_region\": \"TARAPACA, CHILE\"\n",
" }\n",
"}]}\n"
]
}
],
"source": [
"res = geturl(url)\n",
"print(res['content'])\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sideprojects",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,54 @@
from __future__ import unicode_literals
from tornado.websocket import websocket_connect
from tornado.ioloop import IOLoop
from tornado import gen
import logging
import json
import sys
echo_uri = 'wss://www.seismicportal.eu/standing_order/websocket'
PING_INTERVAL = 15
#You can modify this function to run custom process on the message
def myprocessing(message):
try:
data = json.loads(message)
print(data)
info = data['data']['properties']
info['action'] = data['action']
logging.info('>>>> {action:7} event from {auth:7}, unid:{unid}, T0:{time}, Mag:{mag}, Region: {flynn_region}'.format(**info))
except Exception:
logging.exception("Unable to parse json message")
@gen.coroutine
def listen(ws):
while True:
msg = yield ws.read_message()
if msg is None:
logging.info("close")
ws = None
break
myprocessing(msg)
@gen.coroutine
def launch_client():
try:
logging.info("Open WebSocket connection to %s", echo_uri)
ws = yield websocket_connect(echo_uri, ping_interval=PING_INTERVAL)
except Exception:
logging.exception("connection error")
else:
logging.info("Waiting for messages...")
listen(ws)
if __name__ == '__main__':
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
ioloop = IOLoop.instance()
launch_client()
try:
ioloop.start()
except KeyboardInterrupt:
logging.info("Close WebSocket")
ioloop.stop()

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,402 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"id": "bd523899",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import scipy.stats as stats\n",
"from scipy.stats import genextreme\n",
"from scipy.stats import genpareto\n",
"import requests\n",
"import json\n",
"\n",
"from datetime import date\n",
"from datetime import datetime\n",
"from datetime import timedelta\n",
"import pytz\n",
"\n",
"import os.path"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "c63a1b94",
"metadata": {},
"outputs": [],
"source": [
"start_date = \"2024-01-01\""
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "903ed374",
"metadata": {},
"outputs": [],
"source": [
"def geturl(url):\n",
" res = requests.get(\"https://\"+url, timeout=15)\n",
" return {'status': res.status_code,\n",
" 'content': res.text}"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "89e1979c",
"metadata": {},
"outputs": [],
"source": [
"def retrieveEvents(start, end, limit=20000, retry_limit=5, minmag=0):\n",
" events = []\n",
" moving_start = datetime.fromordinal(start.toordinal()).replace(tzinfo=pytz.utc)\n",
" end = datetime.fromordinal(end.toordinal()).replace(tzinfo=pytz.utc)\n",
" failures = 0\n",
" while moving_start <= end and failures < retry_limit:\n",
" # print(moving_start, end)\n",
" url = \"www.seismicportal.eu/fdsnws/event/1/query?orderby=time-asc&limit={limit}&start={startdate}&end={enddate}&format=json&minmag={minmag}\".format(limit=limit, startdate=moving_start.isoformat(), enddate=end.isoformat(), minmag=minmag)\n",
" # print(url)\n",
" res = geturl(url)\n",
" # print(res['status'])\n",
" if res['status'] != 200:\n",
" failures += 1\n",
" continue\n",
" content = res['content']\n",
" json_parser = json.loads(content)\n",
" temp_events = [event['properties'] for event in json_parser['features']]\n",
"\n",
" if len(temp_events) == 0:\n",
" # print(\"ending\")\n",
" break\n",
"\n",
" # temp_events = sorted(temp_events, key=lambda d: d['time'])\n",
"\n",
" if len(temp_events) == limit:\n",
" moving_start = datetime.fromisoformat(temp_events[-1]['time'])\n",
" else:\n",
" moving_start = end + timedelta(hours=1)\n",
" # print(\"ending here:\", moving_start)\n",
" events.extend(temp_events)\n",
" # print(\"hi\")\n",
" # return pd.DataFrame(events)\n",
" return events\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "8bccb27c",
"metadata": {},
"outputs": [],
"source": [
"data = retrieveEvents(date.fromisoformat(start_date), date.today(), minmag=2)\n",
"df = pd.DataFrame(data)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "902b6b1e",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>source_id</th>\n",
" <th>source_catalog</th>\n",
" <th>lastupdate</th>\n",
" <th>time</th>\n",
" <th>flynn_region</th>\n",
" <th>lat</th>\n",
" <th>lon</th>\n",
" <th>depth</th>\n",
" <th>evtype</th>\n",
" <th>auth</th>\n",
" <th>mag</th>\n",
" <th>magtype</th>\n",
" <th>unid</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1600054</td>\n",
" <td>EMSC-RTS</td>\n",
" <td>2024-01-01T00:02:51.439437Z</td>\n",
" <td>2024-01-01T00:00:29.5Z</td>\n",
" <td>CRETE, GREECE</td>\n",
" <td>35.1400</td>\n",
" <td>24.1200</td>\n",
" <td>10.0</td>\n",
" <td>ke</td>\n",
" <td>THE</td>\n",
" <td>2.3</td>\n",
" <td>ml</td>\n",
" <td>20240101_0000001</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1600055</td>\n",
" <td>EMSC-RTS</td>\n",
" <td>2024-01-01T00:14:14.3925Z</td>\n",
" <td>2024-01-01T00:03:15.0Z</td>\n",
" <td>SULAWESI, INDONESIA</td>\n",
" <td>-1.3000</td>\n",
" <td>120.5100</td>\n",
" <td>10.0</td>\n",
" <td>ke</td>\n",
" <td>BMKG</td>\n",
" <td>3.1</td>\n",
" <td>m</td>\n",
" <td>20240101_0000002</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1600058</td>\n",
" <td>EMSC-RTS</td>\n",
" <td>2024-01-01T00:24:28.774809Z</td>\n",
" <td>2024-01-01T00:03:15.14Z</td>\n",
" <td>PUERTO RICO</td>\n",
" <td>18.4087</td>\n",
" <td>-66.4270</td>\n",
" <td>105.2</td>\n",
" <td>ke</td>\n",
" <td>PR</td>\n",
" <td>3.2</td>\n",
" <td>md</td>\n",
" <td>20240101_0000004</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1600056</td>\n",
" <td>EMSC-RTS</td>\n",
" <td>2024-01-01T00:14:58.984143Z</td>\n",
" <td>2024-01-01T00:05:28.0Z</td>\n",
" <td>COLOMBIA-ECUADOR BORDER REGION</td>\n",
" <td>0.1100</td>\n",
" <td>-78.9400</td>\n",
" <td>54.0</td>\n",
" <td>ke</td>\n",
" <td>QUI</td>\n",
" <td>3.5</td>\n",
" <td>m</td>\n",
" <td>20240101_0000003</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1600057</td>\n",
" <td>EMSC-RTS</td>\n",
" <td>2024-01-02T08:04:45.107234Z</td>\n",
" <td>2024-01-01T00:10:05.6Z</td>\n",
" <td>NORWEGIAN SEA</td>\n",
" <td>72.2450</td>\n",
" <td>1.8470</td>\n",
" <td>6.1</td>\n",
" <td>ke</td>\n",
" <td>BER</td>\n",
" <td>3.7</td>\n",
" <td>mw</td>\n",
" <td>20240101_0000408</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>126136</th>\n",
" <td>1809807</td>\n",
" <td>EMSC-RTS</td>\n",
" <td>2025-05-19T21:34:26.878673Z</td>\n",
" <td>2025-05-18T23:24:13.0Z</td>\n",
" <td>SAN JUAN, ARGENTINA</td>\n",
" <td>-31.6300</td>\n",
" <td>-70.3800</td>\n",
" <td>138.0</td>\n",
" <td>ke</td>\n",
" <td>CSN</td>\n",
" <td>3.4</td>\n",
" <td>ml</td>\n",
" <td>20250518_0000270</td>\n",
" </tr>\n",
" <tr>\n",
" <th>126137</th>\n",
" <td>1809808</td>\n",
" <td>EMSC-RTS</td>\n",
" <td>2025-05-18T23:38:19.408959Z</td>\n",
" <td>2025-05-18T23:27:24.13Z</td>\n",
" <td>SOUTHERN ITALY</td>\n",
" <td>39.0343</td>\n",
" <td>16.4318</td>\n",
" <td>9.8</td>\n",
" <td>ke</td>\n",
" <td>INGV</td>\n",
" <td>2.4</td>\n",
" <td>ml</td>\n",
" <td>20250518_0000271</td>\n",
" </tr>\n",
" <tr>\n",
" <th>126138</th>\n",
" <td>1809813</td>\n",
" <td>EMSC-RTS</td>\n",
" <td>2025-05-19T00:14:20.838747Z</td>\n",
" <td>2025-05-18T23:52:33.0Z</td>\n",
" <td>SOUTHWEST OF SUMATRA, INDONESIA</td>\n",
" <td>-7.8200</td>\n",
" <td>103.8600</td>\n",
" <td>10.0</td>\n",
" <td>ke</td>\n",
" <td>BMKG</td>\n",
" <td>3.5</td>\n",
" <td>m</td>\n",
" <td>20250518_0000275</td>\n",
" </tr>\n",
" <tr>\n",
" <th>126139</th>\n",
" <td>1809809</td>\n",
" <td>EMSC-RTS</td>\n",
" <td>2025-05-19T06:33:38.794931Z</td>\n",
" <td>2025-05-18T23:58:01.56Z</td>\n",
" <td>OFF COAST OF TARAPACA, CHILE</td>\n",
" <td>-18.5159</td>\n",
" <td>-71.3039</td>\n",
" <td>25.0</td>\n",
" <td>ke</td>\n",
" <td>EMSC</td>\n",
" <td>4.2</td>\n",
" <td>mb</td>\n",
" <td>20250518_0000273</td>\n",
" </tr>\n",
" <tr>\n",
" <th>126140</th>\n",
" <td>1809814</td>\n",
" <td>EMSC-RTS</td>\n",
" <td>2025-05-19T06:33:56.681166Z</td>\n",
" <td>2025-05-18T23:58:56.73Z</td>\n",
" <td>WESTERN TURKEY</td>\n",
" <td>37.8953</td>\n",
" <td>27.6327</td>\n",
" <td>11.4</td>\n",
" <td>ke</td>\n",
" <td>EMSC</td>\n",
" <td>2.4</td>\n",
" <td>ml</td>\n",
" <td>20250518_0000287</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>126141 rows × 13 columns</p>\n",
"</div>"
],
"text/plain": [
" source_id source_catalog lastupdate \\\n",
"0 1600054 EMSC-RTS 2024-01-01T00:02:51.439437Z \n",
"1 1600055 EMSC-RTS 2024-01-01T00:14:14.3925Z \n",
"2 1600058 EMSC-RTS 2024-01-01T00:24:28.774809Z \n",
"3 1600056 EMSC-RTS 2024-01-01T00:14:58.984143Z \n",
"4 1600057 EMSC-RTS 2024-01-02T08:04:45.107234Z \n",
"... ... ... ... \n",
"126136 1809807 EMSC-RTS 2025-05-19T21:34:26.878673Z \n",
"126137 1809808 EMSC-RTS 2025-05-18T23:38:19.408959Z \n",
"126138 1809813 EMSC-RTS 2025-05-19T00:14:20.838747Z \n",
"126139 1809809 EMSC-RTS 2025-05-19T06:33:38.794931Z \n",
"126140 1809814 EMSC-RTS 2025-05-19T06:33:56.681166Z \n",
"\n",
" time flynn_region lat \\\n",
"0 2024-01-01T00:00:29.5Z CRETE, GREECE 35.1400 \n",
"1 2024-01-01T00:03:15.0Z SULAWESI, INDONESIA -1.3000 \n",
"2 2024-01-01T00:03:15.14Z PUERTO RICO 18.4087 \n",
"3 2024-01-01T00:05:28.0Z COLOMBIA-ECUADOR BORDER REGION 0.1100 \n",
"4 2024-01-01T00:10:05.6Z NORWEGIAN SEA 72.2450 \n",
"... ... ... ... \n",
"126136 2025-05-18T23:24:13.0Z SAN JUAN, ARGENTINA -31.6300 \n",
"126137 2025-05-18T23:27:24.13Z SOUTHERN ITALY 39.0343 \n",
"126138 2025-05-18T23:52:33.0Z SOUTHWEST OF SUMATRA, INDONESIA -7.8200 \n",
"126139 2025-05-18T23:58:01.56Z OFF COAST OF TARAPACA, CHILE -18.5159 \n",
"126140 2025-05-18T23:58:56.73Z WESTERN TURKEY 37.8953 \n",
"\n",
" lon depth evtype auth mag magtype unid \n",
"0 24.1200 10.0 ke THE 2.3 ml 20240101_0000001 \n",
"1 120.5100 10.0 ke BMKG 3.1 m 20240101_0000002 \n",
"2 -66.4270 105.2 ke PR 3.2 md 20240101_0000004 \n",
"3 -78.9400 54.0 ke QUI 3.5 m 20240101_0000003 \n",
"4 1.8470 6.1 ke BER 3.7 mw 20240101_0000408 \n",
"... ... ... ... ... ... ... ... \n",
"126136 -70.3800 138.0 ke CSN 3.4 ml 20250518_0000270 \n",
"126137 16.4318 9.8 ke INGV 2.4 ml 20250518_0000271 \n",
"126138 103.8600 10.0 ke BMKG 3.5 m 20250518_0000275 \n",
"126139 -71.3039 25.0 ke EMSC 4.2 mb 20250518_0000273 \n",
"126140 27.6327 11.4 ke EMSC 2.4 ml 20250518_0000287 \n",
"\n",
"[126141 rows x 13 columns]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sideprojects",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because one or more lines are too long

734
full_regression_model.ipynb Normal file
View File

@ -0,0 +1,734 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 160,
"id": "98beda53",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import scipy.stats as stats\n",
"from scipy.stats import genextreme\n",
"from scipy.stats import genpareto\n",
"import requests\n",
"import json\n",
"\n",
"# from datetime import date\n",
"# from datetime import datetime\n",
"# from datetime import timedelta\n",
"# import pytz\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import lightning as L\n",
"\n",
"import os\n"
]
},
{
"cell_type": "markdown",
"id": "f15607d7",
"metadata": {},
"source": [
"### Import data"
]
},
{
"cell_type": "code",
"execution_count": 161,
"id": "57505637",
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('reduced_data.csv')\n",
"# df['time'] = pd.to_datetime(df['time'], format='ISO8601')\n",
"\n",
"# df['lastupdate'] = pd.to_datetime(df['lastupdate'], format='ISO8601')"
]
},
{
"cell_type": "code",
"execution_count": 162,
"id": "053cc158",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>bin</th>\n",
" <th>time_to_next_event</th>\n",
" <th>time</th>\n",
" <th>lat</th>\n",
" <th>lon</th>\n",
" <th>depth</th>\n",
" <th>mag</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>27</td>\n",
" <td>3408</td>\n",
" <td>1.262306e+09</td>\n",
" <td>37.5835</td>\n",
" <td>22.1704</td>\n",
" <td>15.0</td>\n",
" <td>2.8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>21</td>\n",
" <td>495</td>\n",
" <td>1.262310e+09</td>\n",
" <td>42.8370</td>\n",
" <td>13.0010</td>\n",
" <td>9.4</td>\n",
" <td>2.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>22</td>\n",
" <td>1377</td>\n",
" <td>1.262310e+09</td>\n",
" <td>46.1624</td>\n",
" <td>12.2834</td>\n",
" <td>10.0</td>\n",
" <td>2.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>50</td>\n",
" <td>274</td>\n",
" <td>1.262312e+09</td>\n",
" <td>26.4100</td>\n",
" <td>99.9100</td>\n",
" <td>10.0</td>\n",
" <td>5.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>25</td>\n",
" <td>11</td>\n",
" <td>1.262312e+09</td>\n",
" <td>39.0807</td>\n",
" <td>31.0495</td>\n",
" <td>7.0</td>\n",
" <td>2.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1062937</th>\n",
" <td>22</td>\n",
" <td>146</td>\n",
" <td>1.749481e+09</td>\n",
" <td>39.2261</td>\n",
" <td>28.9878</td>\n",
" <td>10.8</td>\n",
" <td>2.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1062938</th>\n",
" <td>40</td>\n",
" <td>166</td>\n",
" <td>1.749481e+09</td>\n",
" <td>17.5500</td>\n",
" <td>-94.3070</td>\n",
" <td>186.8</td>\n",
" <td>4.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1062939</th>\n",
" <td>30</td>\n",
" <td>36</td>\n",
" <td>1.749481e+09</td>\n",
" <td>35.3752</td>\n",
" <td>-3.6249</td>\n",
" <td>25.7</td>\n",
" <td>3.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1062940</th>\n",
" <td>20</td>\n",
" <td>25</td>\n",
" <td>1.749481e+09</td>\n",
" <td>39.2152</td>\n",
" <td>29.0178</td>\n",
" <td>5.4</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1062941</th>\n",
" <td>32</td>\n",
" <td>173</td>\n",
" <td>1.749481e+09</td>\n",
" <td>14.8400</td>\n",
" <td>119.4300</td>\n",
" <td>16.0</td>\n",
" <td>3.2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1062942 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" bin time_to_next_event time lat lon depth mag\n",
"0 27 3408 1.262306e+09 37.5835 22.1704 15.0 2.8\n",
"1 21 495 1.262310e+09 42.8370 13.0010 9.4 2.1\n",
"2 22 1377 1.262310e+09 46.1624 12.2834 10.0 2.3\n",
"3 50 274 1.262312e+09 26.4100 99.9100 10.0 5.0\n",
"4 25 11 1.262312e+09 39.0807 31.0495 7.0 2.5\n",
"... ... ... ... ... ... ... ...\n",
"1062937 22 146 1.749481e+09 39.2261 28.9878 10.8 2.3\n",
"1062938 40 166 1.749481e+09 17.5500 -94.3070 186.8 4.1\n",
"1062939 30 36 1.749481e+09 35.3752 -3.6249 25.7 3.0\n",
"1062940 20 25 1.749481e+09 39.2152 29.0178 5.4 2.0\n",
"1062941 32 173 1.749481e+09 14.8400 119.4300 16.0 3.2\n",
"\n",
"[1062942 rows x 7 columns]"
]
},
"execution_count": 162,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "markdown",
"id": "b17d72f3",
"metadata": {},
"source": [
"### Dataset loader"
]
},
{
"cell_type": "code",
"execution_count": 163,
"id": "e6dacf0d",
"metadata": {},
"outputs": [],
"source": [
"class TimeseriesDataset(torch.utils.data.Dataset):\n",
" def __init__(self, x, y, seq_len=100):\n",
" self.seq_len = seq_len\n",
" self.x = x\n",
" self.y = y\n",
" \n",
" def __len__(self):\n",
" return self.x.__len__() - self.seq_len\n",
" \n",
" def __getitem__(self, index):\n",
" return (self.x[index:index+self.seq_len], self.y[index+self.seq_len])\n"
]
},
{
"cell_type": "code",
"execution_count": 164,
"id": "f42c537e",
"metadata": {},
"outputs": [],
"source": [
"division_p1 = int(len(df)*0.6)\n",
"division_p2 = int(len(df)*0.8)"
]
},
{
"cell_type": "code",
"execution_count": 165,
"id": "2181b1ad",
"metadata": {},
"outputs": [],
"source": [
"x_columns = ['time_to_next_event','time','lat','lon','depth','mag']\n",
"y_columns = ['time_to_next_event', 'lat', 'lon','mag']"
]
},
{
"cell_type": "code",
"execution_count": 166,
"id": "10a8ac71",
"metadata": {},
"outputs": [],
"source": [
"x = torch.tensor(df[x_columns].values).to(torch.float)\n",
"# y1 = torch.tensor(df[['bin']].values)\n",
"y = torch.tensor(df[y_columns].values).to(torch.float)"
]
},
{
"cell_type": "code",
"execution_count": 167,
"id": "ee26bfb2",
"metadata": {},
"outputs": [],
"source": [
"train_x = x[:division_p1]\n",
"val_x = x[division_p1:division_p2]\n",
"test_x = x[division_p2:]\n",
"\n",
"train_y = y[:division_p1]\n",
"val_y = y[division_p1:division_p2]\n",
"test_y = y[division_p2:]"
]
},
{
"cell_type": "code",
"execution_count": 168,
"id": "4a739cb6",
"metadata": {},
"outputs": [],
"source": [
"train_dataset = TimeseriesDataset(train_x, train_y, seq_len=1000)\n",
"train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True)\n",
"\n",
"val_dataset = TimeseriesDataset(val_x, val_y, seq_len=1000)\n",
"val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=10, shuffle=True)"
]
},
{
"cell_type": "markdown",
"id": "119f29ee",
"metadata": {},
"source": [
"### Helper classes\n"
]
},
{
"cell_type": "code",
"execution_count": 169,
"id": "973b9c65",
"metadata": {},
"outputs": [],
"source": [
"## Copied from https://stackoverflow.com/questions/50817916/how-do-i-add-lstm-gru-or-other-recurrent-layers-to-a-sequential-in-pytorch\n",
"class SelectItem(nn.Module):\n",
" def __init__(self, item_index):\n",
" super(SelectItem, self).__init__()\n",
" self._name = 'selectitem'\n",
" self.item_index = item_index\n",
"\n",
" def forward(self, inputs):\n",
" return inputs[self.item_index]"
]
},
{
"cell_type": "markdown",
"id": "971e8319",
"metadata": {},
"source": [
"### Model design"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "92b3b995",
"metadata": {},
"outputs": [],
"source": [
"class MultiRNNLightning(L.LightningModule):\n",
" def __init__(self, input_size, hidden_size, output_size):\n",
" super().__init__()\n",
" self.model = nn.Sequential(nn.LSTM(input_size, hidden_size, batch_first=True),\n",
" SelectItem(1),\n",
" SelectItem(0),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_size, output_size))\n",
" \n",
" # self.model = nn.GRU(input_size, hidden_size, batch_first=True)\n",
" \n",
" # self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
" # self.model.to(self.device)\n",
"\n",
" def forward(self, x):\n",
" # print(\"hi\")\n",
" return torch.squeeze(self.model(x), 0)\n",
" \n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)\n",
" scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
" optimizer, mode=\"min\", factor=0.1, patience=5\n",
" )\n",
" \n",
" return {\n",
" \"optimizer\": optimizer,\n",
" \"lr_scheduler\": {\n",
" \"scheduler\": scheduler,\n",
" \"monitor\": \"val_loss\",\n",
" },\n",
" }\n",
" \n",
" def training_step(self, batch, batch_idx):\n",
" x, y = batch\n",
" \n",
" # x.to(self.device)\n",
" # y.to(self.device)\n",
" \n",
" # print(\"x shape\", x.shape)\n",
" y_hat = self(x)\n",
" # print(y_hat.shape)\n",
" # print(\"y shape\", y.shape)\n",
" # return 0\n",
" # print(\"yhat_shape\", y_hat.shape)\n",
" loss = nn.functional.mse_loss(y_hat, y)\n",
" self.log('train_loss', loss) #, on_epoch=True)\n",
" return loss\n",
" \n",
" def validation_step(self, val_batch, batch_idx):\n",
" x, y = val_batch\n",
" \n",
" # x.to(self.device)\n",
" # y.to(self.device)\n",
" \n",
" \n",
" y_hat = self(x)\n",
" loss = nn.functional.mse_loss(y_hat, y)\n",
" self.log('val_loss', loss) #, on_epoch=True)\n",
" return loss\n",
"\n",
" # def training_step(self, batch, batch_idx):\n",
" # x, y = batch\n",
" # class_y_hat, regress_y_hat = self(x)\n",
" \n",
" # loss = F.cross_entropy(y_hat, y)\n",
" # acc = (y_hat.argmax(1) == y).float().mean()\n",
" \n",
" # self.log(\"train_loss\", loss)\n",
" # self.log(\"train_acc\", acc)\n",
" # return loss\n",
" \n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "741ce0f7",
"metadata": {},
"source": [
"#### Model Params"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "886809fd",
"metadata": {},
"outputs": [],
"source": [
"input_size = len(x_columns)\n",
"hidden_size = 1000\n",
"output_size = len(y_columns)"
]
},
{
"cell_type": "markdown",
"id": "25598f51",
"metadata": {},
"source": [
"### Train model"
]
},
{
"cell_type": "code",
"execution_count": 172,
"id": "61fa862f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n"
]
}
],
"source": [
"model = MultiRNNLightning(input_size, hidden_size, output_size)\n",
"\n",
"trainer = L.Trainer(max_epochs=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e847d0c2",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
" | Name | Type | Params | Mode \n",
"---------------------------------------------\n",
"0 | model | Sequential | 167 K | train\n",
"---------------------------------------------\n",
"167 K Trainable params\n",
"0 Non-trainable params\n",
"167 K Total params\n",
"0.669 Total estimated model params size (MB)\n",
"6 Modules in train mode\n",
"0 Modules in eval mode\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "adec569b050341aab3b2c94e22152976",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\ewell\\anaconda3\\envs\\test_space\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.\n",
"c:\\Users\\ewell\\anaconda3\\envs\\test_space\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n",
"c:\\Users\\ewell\\anaconda3\\envs\\test_space\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b2c2156c3db24dadbeff025679a1e7b4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "80970987a81d4b059981dce44b21cdc8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d21a94f8276d4dea8ecd543ff7d5bdd2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5c1b8ebb789e4c80ad4ec452c9ccbbda",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "88b213664aaa4bc7bf57b4e5cfff3ba7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "26637bbce8b2455bb72a6268f2e2196c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c067cd1c645d49c8a11cae04e3d676dc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eb0dc33095f841c7a30376c0bdb0fce0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "12667695d03e4baf9ab468a6b57349e0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e65958ac8567425d9e9722f6023f8660",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c6ff7598384d45c3bd192822cd0070b0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=10` reached.\n"
]
}
],
"source": [
"trainer.fit(model, train_loader, val_loader)"
]
},
{
"cell_type": "code",
"execution_count": 174,
"id": "5d146588",
"metadata": {},
"outputs": [],
"source": [
"# %reload_ext tensorboard\n",
"# %tensorboard --logdir=lightning_logs/"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

563
model_training.ipynb Normal file
View File

@ -0,0 +1,563 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "b4cc996f",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import scipy.stats as stats\n",
"from scipy.stats import genextreme\n",
"from scipy.stats import genpareto\n",
"import requests\n",
"import json\n",
"\n",
"# from datetime import date\n",
"# from datetime import datetime\n",
"# from datetime import timedelta\n",
"# import pytz\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import lightning as L\n",
"\n",
"import os\n"
]
},
{
"cell_type": "markdown",
"id": "11cc2375",
"metadata": {},
"source": [
"### Import data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b0fe3fe6",
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('reduced_data.csv')\n",
"# df['time'] = pd.to_datetime(df['time'], format='ISO8601')\n",
"\n",
"# df['lastupdate'] = pd.to_datetime(df['lastupdate'], format='ISO8601')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "083c39ae",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>bin</th>\n",
" <th>time_to_next_event</th>\n",
" <th>time</th>\n",
" <th>lat</th>\n",
" <th>lon</th>\n",
" <th>depth</th>\n",
" <th>mag</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>27</td>\n",
" <td>24</td>\n",
" <td>1.577837e+09</td>\n",
" <td>19.2200</td>\n",
" <td>-67.1300</td>\n",
" <td>12.0</td>\n",
" <td>2.8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>30</td>\n",
" <td>77</td>\n",
" <td>1.577837e+09</td>\n",
" <td>-2.7400</td>\n",
" <td>127.9000</td>\n",
" <td>20.0</td>\n",
" <td>3.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>25</td>\n",
" <td>425</td>\n",
" <td>1.577837e+09</td>\n",
" <td>19.0800</td>\n",
" <td>-67.0900</td>\n",
" <td>6.0</td>\n",
" <td>2.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>31</td>\n",
" <td>39</td>\n",
" <td>1.577837e+09</td>\n",
" <td>19.1900</td>\n",
" <td>-67.8400</td>\n",
" <td>28.0</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>35</td>\n",
" <td>69</td>\n",
" <td>1.577837e+09</td>\n",
" <td>-25.6400</td>\n",
" <td>-70.5200</td>\n",
" <td>53.0</td>\n",
" <td>3.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>548217</th>\n",
" <td>21</td>\n",
" <td>113</td>\n",
" <td>1.747781e+09</td>\n",
" <td>35.9836</td>\n",
" <td>28.1056</td>\n",
" <td>7.1</td>\n",
" <td>2.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>548218</th>\n",
" <td>18</td>\n",
" <td>941</td>\n",
" <td>1.747781e+09</td>\n",
" <td>36.2797</td>\n",
" <td>36.0253</td>\n",
" <td>7.0</td>\n",
" <td>1.9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>548219</th>\n",
" <td>11</td>\n",
" <td>141</td>\n",
" <td>1.747782e+09</td>\n",
" <td>37.2222</td>\n",
" <td>36.9486</td>\n",
" <td>7.3</td>\n",
" <td>1.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>548220</th>\n",
" <td>32</td>\n",
" <td>143</td>\n",
" <td>1.747783e+09</td>\n",
" <td>-31.1900</td>\n",
" <td>-68.2500</td>\n",
" <td>81.0</td>\n",
" <td>3.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>548221</th>\n",
" <td>40</td>\n",
" <td>790</td>\n",
" <td>1.747783e+09</td>\n",
" <td>17.6300</td>\n",
" <td>-101.3340</td>\n",
" <td>41.2</td>\n",
" <td>4.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>548222 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" bin time_to_next_event time lat lon depth mag\n",
"0 27 24 1.577837e+09 19.2200 -67.1300 12.0 2.8\n",
"1 30 77 1.577837e+09 -2.7400 127.9000 20.0 3.0\n",
"2 25 425 1.577837e+09 19.0800 -67.0900 6.0 2.5\n",
"3 31 39 1.577837e+09 19.1900 -67.8400 28.0 3.1\n",
"4 35 69 1.577837e+09 -25.6400 -70.5200 53.0 3.5\n",
"... ... ... ... ... ... ... ...\n",
"548217 21 113 1.747781e+09 35.9836 28.1056 7.1 2.1\n",
"548218 18 941 1.747781e+09 36.2797 36.0253 7.0 1.9\n",
"548219 11 141 1.747782e+09 37.2222 36.9486 7.3 1.2\n",
"548220 32 143 1.747783e+09 -31.1900 -68.2500 81.0 3.3\n",
"548221 40 790 1.747783e+09 17.6300 -101.3340 41.2 4.0\n",
"\n",
"[548222 rows x 7 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "967ec29d",
"metadata": {},
"outputs": [],
"source": [
"# # Set threshold at the 95th percentile\n",
"# # threshold = np.percentile(data, 95)\n",
"# threshold = 6 ## the actual threshold value\n",
"# data = df['mag']\n",
"# extremes = data[data > threshold]\n",
"# print(f\"Threshold: {threshold}, Number of extremes: {len(extremes)}\")\n",
"\n",
"# # Visualize the threshold and extremes\n",
"# plt.hist(data, bins=30, edgecolor='k', alpha=0.7, label='Data')\n",
"# plt.axvline(threshold, color='red', linestyle='--', label='Threshold')\n",
"# plt.title('Threshold for POT')\n",
"# plt.legend()\n",
"# plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "77c900ad",
"metadata": {},
"outputs": [],
"source": [
"# from scipy.stats import genpareto\n",
"\n",
"# # Fit a GPD to the extremes\n",
"# gpd_params = genpareto.fit(extremes - threshold) # Subtract threshold for GPD fit\n",
"# print(f\"GPD Parameters: Shape={gpd_params[0]}, Location={gpd_params[1]}, Scale={gpd_params[2]}\")\n",
"\n",
"# # Visualize the GPD fit\n",
"# x = np.linspace(min(extremes), max(extremes), 100)\n",
"# pdf = genpareto.pdf(x - threshold, *gpd_params)\n",
"# plt.hist(extremes, bins=10, density=True, alpha=0.7, label='Data')\n",
"# plt.plot(x, pdf, label='GPD Fit', color='blue')\n",
"# plt.title('GPD Fit to Extremes')\n",
"# plt.legend()\n",
"# plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f0c623e1",
"metadata": {},
"outputs": [],
"source": [
"# inv_mag_frequency = 1/np.bincount(df['bin'].to_numpy())/\n",
"# print(inv_mag_frequency)\n",
"# inv_mag_frequency[inv_mag_frequency == np.inf] = 2\n",
"# print(inv_mag_frequency)\n",
"# print(sum(mag_frequency))"
]
},
{
"cell_type": "markdown",
"id": "5eea5b40",
"metadata": {},
"source": [
"### Create model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "00b89e53",
"metadata": {},
"outputs": [],
"source": [
"class MultiRNN(nn.Module):\n",
" def __init__(self, input_size, hidden_size, class_count, output_size):\n",
" super(MultiRNN, self).__init__()\n",
" self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)\n",
" self.classify = nn.Linear(hidden_size, class_count)\n",
" self.regress = nn.Linear(hidden_size, output_size)\n",
" \n",
" def forward(self, x):\n",
" h0 = torch.zeros(1, x.size(0), hidden_size).to(x.device)\n",
" out, _ = self.rnn(x, h0)\n",
" classes = self.classify(out[:,-1])\n",
" regresses = self.regress(out[:,-1])\n",
"\n",
" if self.training:\n",
" return classes, regresses\n",
" else:\n",
" return torch.argmax(classes, dim=-1), regresses\n",
"\n",
"\n",
"input_size = 5\n",
"hidden_size = 20\n",
"class_count = 120\n",
"regressor_output_size = 4\n",
"# model = MultiRNN(input_size, hidden_size, class_count, regressor_output_size)"
]
},
{
"cell_type": "markdown",
"id": "0e346047",
"metadata": {},
"source": [
"### Create dataset/dataloader"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2744a649",
"metadata": {},
"outputs": [],
"source": [
"class TimeseriesDataset(torch.utils.data.Dataset):\n",
" def __init__(self, x, class_y, regress_y, seq_len=100):\n",
" self.seq_len = seq_len\n",
" self.x = x\n",
" self.class_y = class_y\n",
" self.regress_y = regress_y\n",
" \n",
" def __len__(self):\n",
" return self.x.__len__() - (self.seq_len-1)\n",
" \n",
" def __getitem__(self, index):\n",
" return (self.x[index:index+self.seq_len], self.class_y[index+self.seq_len], self.regress_y[index+self.seq_len])\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "5714b17d",
"metadata": {},
"outputs": [],
"source": [
"x = torch.tensor(df[['time_to_next_event','time','lat','lon','depth','mag']].values).to(torch.float)\n",
"y1 = torch.tensor(df[['bin']].values)\n",
"y2 = torch.tensor(df[['time_to_next_event', 'time', 'lat', 'lon']].values).to(torch.float)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "72f7b0f1",
"metadata": {},
"outputs": [],
"source": [
"train_dataset = TimeseriesDataset(x, y1, y2, seq_len=1000)\n",
"train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee154b1b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 torch.Size([10, 1000, 6]) torch.Size([10, 1]) torch.Size([10, 4])\n",
"54723\n"
]
}
],
"source": [
"# for i, d in enumerate(train_loader):\n",
"# print(i, d[0].shape, d[1].shape, d[2].shape)\n",
"# break\n",
"# print(len(train_loader))"
]
},
{
"cell_type": "markdown",
"id": "2277a6ee",
"metadata": {},
"source": [
"### Create the Lightning setup for model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a35aa9f2",
"metadata": {},
"outputs": [],
"source": [
"class MutliRNNLightning(L.LightningModule):\n",
" def __init__(self, input_size, hidden_size, class_count, output_size):\n",
" super().__init__()\n",
" self.model = MultiRNN(input_size, hidden_size, class_count, regressor_output_size)\n",
" self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
" self.model.to(self.device)\n",
"\n",
" def forward(self, x):\n",
" return self.model.forward(x)\n",
" \n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)\n",
" scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
" optimizer, mode=\"min\", factor=0.1, patience=5\n",
" )\n",
" \n",
" return {\n",
" \"optimizer\": optimizer,\n",
" \"lr_scheduler\": {\n",
" \"scheduler\": scheduler,\n",
" \"monitor\": \"val_loss_class\",\n",
" },\n",
" }\n",
" \n",
" def training_step(self, batch, batch_idx):\n",
" x, y1, y2 = batch\n",
" \n",
" x.to(self.device)\n",
" y1.to(self.device)\n",
" y2.to(self.device)\n",
" \n",
" \n",
" class_y_hat, regress_y_hat = self(x)\n",
" loss1 = nn.functional.cross_entropy(class_y_hat, y1)\n",
" loss2 = nn.functional.mse_loss(regress_y_hat, y2)\n",
" # loss = F.cross_entropy(y_hat, y)\n",
" loss = loss1 + loss2\n",
" self.log('train_loss_class', loss1)\n",
" self.log('train_loss_regress', loss1) #, on_epoch=True)\n",
" return loss\n",
" \n",
" def validation_step(self, val_batch, batch_idx):\n",
" x, y1, y2 = val_batch\n",
" \n",
" x.to(self.device)\n",
" y1.to(self.device)\n",
" y2.to(self.device)\n",
" \n",
" \n",
" class_y_hat, regress_y_hat = self(x)\n",
" loss1 = nn.functional.cross_entropy(class_y_hat, y1)\n",
" loss2 = nn.functional.mse_loss(regress_y_hat, y2)\n",
" # loss = F.cross_entropy(y_hat, y)\n",
" loss = loss1 + loss2\n",
" self.log('val_loss_class', loss1)\n",
" self.log('val_loss_regress', loss1) #, on_epoch=True)\n",
" return loss\n",
"\n",
" # def training_step(self, batch, batch_idx):\n",
" # x, y = batch\n",
" # class_y_hat, regress_y_hat = self(x)\n",
" \n",
" # loss = F.cross_entropy(y_hat, y)\n",
" # acc = (y_hat.argmax(1) == y).float().mean()\n",
" \n",
" # self.log(\"train_loss\", loss)\n",
" # self.log(\"train_acc\", acc)\n",
" # return loss\n",
" \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97657a45",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "b49f9d92",
"metadata": {},
"outputs": [],
"source": [
"sequence_length = 1000\n",
"\n",
"# criterion = nn.MSELoss()\n",
"class_loss = nn.CrossEntropyLoss(weight=torch.from_numpy(inv_mag_frequency))\n",
"regressor_loss = nn.MSELoss()\n",
"# optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
"\n",
"X = torch.tensor(df.iloc[:sequence_length][['time', 'lat', 'lon', 'depth', 'mag']].values).to(torch.float)\n",
"X = torch.stack((torch.tensor(df.iloc[sequence_length:sequence_length*2][['time', 'lat', 'lon', 'depth', 'mag']].values).to(torch.float),X))\n",
"# print(X.shape)\n",
"model.train()\n",
"outputs = model(X)\n",
"# print(outputs[0].shape)\n",
"# print(outputs[1].shape)\n",
"\n",
"# model.eval()\n",
"# outputs = model(X)\n",
"# print(outputs[0].shape)\n",
"# print(outputs[1].shape)\n",
"\n",
"# num_epochs = 100\n",
"# for epoch in range(num_epochs):\n",
"# model.train()\n",
"# outputs = model(X.unsqueeze(2))\n",
"# loss = criterion(outputs, y.unsqueeze(2))\n",
" \n",
"# optimizer.zero_grad()\n",
"# loss.backward()\n",
"# optimizer.step()\n",
" \n",
"# if (epoch + 1) % 10 == 0:\n",
"# print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

193
rnn_regression_model.py Normal file
View File

@ -0,0 +1,193 @@
#%%
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
from scipy.stats import genextreme
from scipy.stats import genpareto
from argparse import ArgumentParser
import pathlib
import torch
import torch.nn as nn
import torch.optim as optim
import lightning as L
import os
#%%
class TimeseriesDataset(torch.utils.data.Dataset):
def __init__(self, x, y, seq_len=100):
self.seq_len = seq_len
self.x = x
self.y = y
def __len__(self):
return self.x.__len__() - self.seq_len
def __getitem__(self, index):
return (self.x[index:index+self.seq_len], self.y[index+self.seq_len])
#%%
## Copied from https://stackoverflow.com/questions/50817916/how-do-i-add-lstm-gru-or-other-recurrent-layers-to-a-sequential-in-pytorch
class SelectItem(nn.Module):
def __init__(self, item_index):
super(SelectItem, self).__init__()
self._name = 'selectitem'
self.item_index = item_index
def forward(self, inputs):
return inputs[self.item_index]
#%%
class MultiRNNLightning(L.LightningModule):
def __init__(self, input_size, output_size, args, conf):
super().__init__()
self.model = nn.Sequential(nn.LSTM(input_size, args.hidden_size, batch_first=True, bidirectional=args.b),
SelectItem(1),
SelectItem(0),
nn.ReLU(),
nn.Linear(args.hidden_size*(2 if args.b else 1), output_size))
self.save_hyperparameters(conf)
# self.model = nn.GRU(input_size, hidden_size, batch_first=True)
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.model.to(self.device)
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--hidden_size', type=int, default=100)
parser.add_argument('-b', action='store_true')
return parser
def forward(self, x):
# print("hi")
return torch.squeeze(self.model(x), 0)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.1, patience=5
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "val_loss",
},
}
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.functional.mse_loss(y_hat, y)
# print(batch)
self.log('train_loss', loss, prog_bar=True) #, on_epoch=True)
self.log('training_loss_per_unit', loss/y.shape[0])
return loss
def validation_step(self, val_batch, batch_idx):
x, y = val_batch
y_hat = self(x)
loss = nn.functional.mse_loss(y_hat, y)
self.log('val_loss', loss, prog_bar=True) #, on_epoch=True)
# print(val_batch)
self.log('val_loss_per_unit', loss/y.shape[0])
return loss
def test_step(self, test_batch, batch_idx):
x, y = test_batch
y_hat = self(x)
loss = nn.functional.mse_loss(y_hat, y)
self.log('test_loss', loss, prog_bar=True) #, on_epoch=True)
self.log('test_loss_per_unit', loss/y.shape[0])
return loss
#%%
parser = ArgumentParser()
#%%
parser.add_argument('--val_size', type=float, default=0.2)
parser.add_argument('--test_size', type=float, default=0.2)
parser.add_argument('--seq_size', type=int, default=1000)
parser.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path("./data"))
parser.add_argument('--batch_size', type=int, default=10)
parser.add_argument('-x', type=str, nargs='+', default=['time_to_next_event','time','lat','lon','depth','mag'])
parser.add_argument('-y', type=str, nargs='+', default=['time_to_next_event','lat','lon','depth','mag'])
parser = MultiRNNLightning.add_model_specific_args(parser)
# parser = L.Trainer.add_argparse_args(parser)
parser.add_argument('--max_epochs', type=int, default=1)
#%%
args = parser.parse_args()
#%%
datadir = args.data_dir / 'reduced_data.csv'
print(datadir.as_posix())
df = pd.read_csv(datadir.as_posix())
#%%
val_size = int(args.val_size*len(df))
test_size = int(args.test_size*len(df))
train_size = len(df)-val_size-test_size
assert val_size >= 0, "validation size is less than 0"
assert test_size >= 0, "test size is less than 0"
assert train_size > 0, "train size is non-positive"
#%%
x_columns = args.x
y_columns = args.y
x = torch.tensor(df[x_columns].values).to(torch.float)
y = torch.tensor(df[y_columns].values).to(torch.float)
mag_index = -1 if 'mag' not in y_columns else y_columns.find('mag')
train_x, val_x, test_x = torch.split(x, [train_size, val_size, test_size])
train_y, val_y, test_y = torch.split(y, [train_size, val_size, test_size])
train_dataset = TimeseriesDataset(train_x, train_y, seq_len=args.seq_size)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
val_dataset = TimeseriesDataset(val_x, val_y, seq_len=args.seq_size)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
test_dataset = TimeseriesDataset(test_x, test_y, seq_len=args.seq_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
#%%
input_size = len(x_columns)
output_size = len(y_columns)
#%%
model = MultiRNNLightning(input_size, output_size, args, {"input_dim":input_size, "output_dim":output_size, "hidden_dim": args.hidden_size, "bidirectional": args.b, "batch_size":args.batch_size, "seq_len":args.seq_size, "input_features": args.x, "output_features": args.y})
trainer = L.Trainer(max_epochs=args.max_epochs)
#%%
trainer.fit(model, train_loader, val_loader)
#%%
trainer.test(model, test_loader)