Compare commits
10 Commits
main
...
initial_ex
| Author | SHA1 | Date | |
|---|---|---|---|
| 28096a932a | |||
| 4596d15c0a | |||
| f527929ba1 | |||
| 552404d998 | |||
| 56d78d0973 | |||
| e52ee45261 | |||
| c3524eda21 | |||
| 6c2247974c | |||
| fc03e01629 | |||
| 6038fdef50 |
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
/data
|
||||
/lightning_logs
|
||||
15
README.md
15
README.md
@ -1,3 +1,18 @@
|
||||
# A repository for exploration into earthquake forecasting
|
||||
|
||||
|
||||
### Things to do
|
||||
1. Update the metric for the RNN to include an accuracy measure for high-magnitude events
|
||||
2. Try farther look-ahead models and/or models that predict a fixed amount of time ahead
|
||||
- Can consider something like an auto-transformer type build where it feeds its prediction back into itself until its next prediction is outside of a time window?
|
||||
|
||||
#### Frame the issue
|
||||
|
||||
Need to care more about the rare events, say above magnitude 8 or magnitude 6.
|
||||
|
||||
|
||||
##### Run command
|
||||
(current best performance)
|
||||
C:/Users/ewell/anaconda3/envs/test_space/python.exe d:/projects/earthquake_prediction_exploration/rnn_regression_model.py --max_epochs 50 -x time_to_next_event lat lon depth mag --hidden_size 1000 --seq_size 100 --batch_size 1000
|
||||
|
||||
|
||||
|
||||
338
data_acquisition/datapolling.ipynb
Normal file
338
data_acquisition/datapolling.ipynb
Normal file
@ -0,0 +1,338 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "e06a5cf6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import requests\n",
|
||||
"import obspy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "434a13eb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"url = \"www.seismicportal.eu/fdsnws/event/1/query?limit=10&start=2020-01-01&end=2022-01-01&format=json\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "e72fb8f7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def geturl(url):\n",
|
||||
" res = requests.get(\"https://\"+url, timeout=15)\n",
|
||||
" return {'status': res.status_code,\n",
|
||||
" 'content': res.text}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "19355da2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{\"type\":\"FeatureCollection\",\"metadata\":{\"count\":10},\"features\":[{\n",
|
||||
" \"geometry\": {\n",
|
||||
" \"type\": \"Point\",\n",
|
||||
" \"coordinates\": [\n",
|
||||
" 122.38,\n",
|
||||
" -8.11,\n",
|
||||
" -10.0\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"type\": \"Feature\",\n",
|
||||
" \"id\": \"20220101_0000155\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"lastupdate\": \"2022-01-02T06:11:00.0Z\",\n",
|
||||
" \"magtype\": \"m\",\n",
|
||||
" \"evtype\": \"ke\",\n",
|
||||
" \"lon\": 122.38,\n",
|
||||
" \"auth\": \"DJA\",\n",
|
||||
" \"lat\": -8.11,\n",
|
||||
" \"depth\": 10.0,\n",
|
||||
" \"unid\": \"20220101_0000155\",\n",
|
||||
" \"mag\": 2.9,\n",
|
||||
" \"time\": \"2022-01-01T23:50:06.0Z\",\n",
|
||||
" \"source_id\": \"1083157\",\n",
|
||||
" \"source_catalog\": \"EMSC-RTS\",\n",
|
||||
" \"flynn_region\": \"FLORES REGION, INDONESIA\"\n",
|
||||
" }\n",
|
||||
"},{\n",
|
||||
" \"geometry\": {\n",
|
||||
" \"type\": \"Point\",\n",
|
||||
" \"coordinates\": [\n",
|
||||
" 124.09,\n",
|
||||
" -8.84,\n",
|
||||
" -65.0\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"type\": \"Feature\",\n",
|
||||
" \"id\": \"20220101_0000153\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"lastupdate\": \"2022-01-02T06:11:00.0Z\",\n",
|
||||
" \"magtype\": \"m\",\n",
|
||||
" \"evtype\": \"ke\",\n",
|
||||
" \"lon\": 124.09,\n",
|
||||
" \"auth\": \"DJA\",\n",
|
||||
" \"lat\": -8.84,\n",
|
||||
" \"depth\": 65.0,\n",
|
||||
" \"unid\": \"20220101_0000153\",\n",
|
||||
" \"mag\": 3.3,\n",
|
||||
" \"time\": \"2022-01-01T23:48:18.0Z\",\n",
|
||||
" \"source_id\": \"1083154\",\n",
|
||||
" \"source_catalog\": \"EMSC-RTS\",\n",
|
||||
" \"flynn_region\": \"KEPULAUAN ALOR, INDONESIA\"\n",
|
||||
" }\n",
|
||||
"},{\n",
|
||||
" \"geometry\": {\n",
|
||||
" \"type\": \"Point\",\n",
|
||||
" \"coordinates\": [\n",
|
||||
" -173.92,\n",
|
||||
" -21.37,\n",
|
||||
" -10.0\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"type\": \"Feature\",\n",
|
||||
" \"id\": \"20220101_0000137\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"lastupdate\": \"2022-01-02T08:20:00.0Z\",\n",
|
||||
" \"magtype\": \"mb\",\n",
|
||||
" \"evtype\": \"ke\",\n",
|
||||
" \"lon\": -173.92,\n",
|
||||
" \"auth\": \"EMSC\",\n",
|
||||
" \"lat\": -21.37,\n",
|
||||
" \"depth\": 10.0,\n",
|
||||
" \"unid\": \"20220101_0000137\",\n",
|
||||
" \"mag\": 4.9,\n",
|
||||
" \"time\": \"2022-01-01T23:47:39.0Z\",\n",
|
||||
" \"source_id\": \"1083084\",\n",
|
||||
" \"source_catalog\": \"EMSC-RTS\",\n",
|
||||
" \"flynn_region\": \"TONGA\"\n",
|
||||
" }\n",
|
||||
"},{\n",
|
||||
" \"geometry\": {\n",
|
||||
" \"type\": \"Point\",\n",
|
||||
" \"coordinates\": [\n",
|
||||
" -66.64,\n",
|
||||
" -23.73,\n",
|
||||
" -233.0\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"type\": \"Feature\",\n",
|
||||
" \"id\": \"20220101_0000136\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"lastupdate\": \"2022-01-01T23:56:00.0Z\",\n",
|
||||
" \"magtype\": \"m\",\n",
|
||||
" \"evtype\": \"ke\",\n",
|
||||
" \"lon\": -66.64,\n",
|
||||
" \"auth\": \"NSNA\",\n",
|
||||
" \"lat\": -23.73,\n",
|
||||
" \"depth\": 233.0,\n",
|
||||
" \"unid\": \"20220101_0000136\",\n",
|
||||
" \"mag\": 3.1,\n",
|
||||
" \"time\": \"2022-01-01T23:45:36.0Z\",\n",
|
||||
" \"source_id\": \"1083085\",\n",
|
||||
" \"source_catalog\": \"EMSC-RTS\",\n",
|
||||
" \"flynn_region\": \"JUJUY, ARGENTINA\"\n",
|
||||
" }\n",
|
||||
"},{\n",
|
||||
" \"geometry\": {\n",
|
||||
" \"type\": \"Point\",\n",
|
||||
" \"coordinates\": [\n",
|
||||
" -155.4,\n",
|
||||
" 19.2,\n",
|
||||
" -32.0\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"type\": \"Feature\",\n",
|
||||
" \"id\": \"20220101_0000133\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"lastupdate\": \"2022-01-01T23:45:00.0Z\",\n",
|
||||
" \"magtype\": \"md\",\n",
|
||||
" \"evtype\": \"ke\",\n",
|
||||
" \"lon\": -155.4,\n",
|
||||
" \"auth\": \"NEIR\",\n",
|
||||
" \"lat\": 19.2,\n",
|
||||
" \"depth\": 32.0,\n",
|
||||
" \"unid\": \"20220101_0000133\",\n",
|
||||
" \"mag\": 2.2,\n",
|
||||
" \"time\": \"2022-01-01T23:42:34.8Z\",\n",
|
||||
" \"source_id\": \"1083079\",\n",
|
||||
" \"source_catalog\": \"EMSC-RTS\",\n",
|
||||
" \"flynn_region\": \"ISLAND OF HAWAII, HAWAII\"\n",
|
||||
" }\n",
|
||||
"},{\n",
|
||||
" \"geometry\": {\n",
|
||||
" \"type\": \"Point\",\n",
|
||||
" \"coordinates\": [\n",
|
||||
" -16.21,\n",
|
||||
" 28.09,\n",
|
||||
" -8.0\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"type\": \"Feature\",\n",
|
||||
" \"id\": \"20220101_0000202\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"lastupdate\": \"2022-01-02T20:45:00.0Z\",\n",
|
||||
" \"magtype\": \"ml\",\n",
|
||||
" \"evtype\": \"ke\",\n",
|
||||
" \"lon\": -16.21,\n",
|
||||
" \"auth\": \"MDD\",\n",
|
||||
" \"lat\": 28.09,\n",
|
||||
" \"depth\": 8.0,\n",
|
||||
" \"unid\": \"20220101_0000202\",\n",
|
||||
" \"mag\": 1.8,\n",
|
||||
" \"time\": \"2022-01-01T23:40:04.8Z\",\n",
|
||||
" \"source_id\": \"1083360\",\n",
|
||||
" \"source_catalog\": \"EMSC-RTS\",\n",
|
||||
" \"flynn_region\": \"CANARY ISLANDS, SPAIN REGION\"\n",
|
||||
" }\n",
|
||||
"},{\n",
|
||||
" \"geometry\": {\n",
|
||||
" \"type\": \"Point\",\n",
|
||||
" \"coordinates\": [\n",
|
||||
" -69.31,\n",
|
||||
" 18.08,\n",
|
||||
" -10.0\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"type\": \"Feature\",\n",
|
||||
" \"id\": \"20220101_0000134\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"lastupdate\": \"2022-01-01T23:51:00.0Z\",\n",
|
||||
" \"magtype\": \"m\",\n",
|
||||
" \"evtype\": \"ke\",\n",
|
||||
" \"lon\": -69.31,\n",
|
||||
" \"auth\": \"UASD\",\n",
|
||||
" \"lat\": 18.08,\n",
|
||||
" \"depth\": 10.0,\n",
|
||||
" \"unid\": \"20220101_0000134\",\n",
|
||||
" \"mag\": 3.1,\n",
|
||||
" \"time\": \"2022-01-01T23:25:21.0Z\",\n",
|
||||
" \"source_id\": \"1083082\",\n",
|
||||
" \"source_catalog\": \"EMSC-RTS\",\n",
|
||||
" \"flynn_region\": \"DOMINICAN REPUBLIC REGION\"\n",
|
||||
" }\n",
|
||||
"},{\n",
|
||||
" \"geometry\": {\n",
|
||||
" \"type\": \"Point\",\n",
|
||||
" \"coordinates\": [\n",
|
||||
" -74.06,\n",
|
||||
" 18.95,\n",
|
||||
" -10.0\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"type\": \"Feature\",\n",
|
||||
" \"id\": \"20220101_0000132\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"lastupdate\": \"2022-01-01T23:36:00.0Z\",\n",
|
||||
" \"magtype\": \"m\",\n",
|
||||
" \"evtype\": \"ke\",\n",
|
||||
" \"lon\": -74.06,\n",
|
||||
" \"auth\": \"UASD\",\n",
|
||||
" \"lat\": 18.95,\n",
|
||||
" \"depth\": 10.0,\n",
|
||||
" \"unid\": \"20220101_0000132\",\n",
|
||||
" \"mag\": 3.1,\n",
|
||||
" \"time\": \"2022-01-01T23:19:03.0Z\",\n",
|
||||
" \"source_id\": \"1083078\",\n",
|
||||
" \"source_catalog\": \"EMSC-RTS\",\n",
|
||||
" \"flynn_region\": \"HAITI REGION\"\n",
|
||||
" }\n",
|
||||
"},{\n",
|
||||
" \"geometry\": {\n",
|
||||
" \"type\": \"Point\",\n",
|
||||
" \"coordinates\": [\n",
|
||||
" -98.05,\n",
|
||||
" 16.31,\n",
|
||||
" -8.0\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"type\": \"Feature\",\n",
|
||||
" \"id\": \"20220101_0000141\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"lastupdate\": \"2022-01-02T02:01:00.0Z\",\n",
|
||||
" \"magtype\": \"m\",\n",
|
||||
" \"evtype\": \"ke\",\n",
|
||||
" \"lon\": -98.05,\n",
|
||||
" \"auth\": \"UNM\",\n",
|
||||
" \"lat\": 16.31,\n",
|
||||
" \"depth\": 8.0,\n",
|
||||
" \"unid\": \"20220101_0000141\",\n",
|
||||
" \"mag\": 3.3,\n",
|
||||
" \"time\": \"2022-01-01T23:00:27.0Z\",\n",
|
||||
" \"source_id\": \"1083098\",\n",
|
||||
" \"source_catalog\": \"EMSC-RTS\",\n",
|
||||
" \"flynn_region\": \"OAXACA, MEXICO\"\n",
|
||||
" }\n",
|
||||
"},{\n",
|
||||
" \"geometry\": {\n",
|
||||
" \"type\": \"Point\",\n",
|
||||
" \"coordinates\": [\n",
|
||||
" -68.86,\n",
|
||||
" -21.1,\n",
|
||||
" -109.0\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"type\": \"Feature\",\n",
|
||||
" \"id\": \"20220101_0000131\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"lastupdate\": \"2022-01-01T23:08:00.0Z\",\n",
|
||||
" \"magtype\": \"m\",\n",
|
||||
" \"evtype\": \"ke\",\n",
|
||||
" \"lon\": -68.86,\n",
|
||||
" \"auth\": \"GUC\",\n",
|
||||
" \"lat\": -21.1,\n",
|
||||
" \"depth\": 109.0,\n",
|
||||
" \"unid\": \"20220101_0000131\",\n",
|
||||
" \"mag\": 2.7,\n",
|
||||
" \"time\": \"2022-01-01T22:57:18.0Z\",\n",
|
||||
" \"source_id\": \"1083075\",\n",
|
||||
" \"source_catalog\": \"EMSC-RTS\",\n",
|
||||
" \"flynn_region\": \"TARAPACA, CHILE\"\n",
|
||||
" }\n",
|
||||
"}]}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"res = geturl(url)\n",
|
||||
"print(res['content'])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "sideprojects",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
54
data_acquisition/datawebsocket.py
Normal file
54
data_acquisition/datawebsocket.py
Normal file
@ -0,0 +1,54 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from tornado.websocket import websocket_connect
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado import gen
|
||||
|
||||
import logging
|
||||
import json
|
||||
import sys
|
||||
|
||||
echo_uri = 'wss://www.seismicportal.eu/standing_order/websocket'
|
||||
PING_INTERVAL = 15
|
||||
|
||||
#You can modify this function to run custom process on the message
|
||||
def myprocessing(message):
    """Decode one websocket message and log a one-line event summary.

    The message is expected to be a JSON document with an ``action`` field
    and the event attributes under ``data.properties``; ``auth``, ``unid``,
    ``time``, ``mag`` and ``flynn_region`` are used for the summary line.

    Failures are logged and never raised, so one bad message cannot stop the
    caller's receive loop.
    """
    try:
        data = json.loads(message)
    except Exception:
        logging.exception("Unable to parse json message")
        return

    print(data)
    try:
        info = data['data']['properties']
        info['action'] = data['action']
        logging.info('>>>> {action:7} event from {auth:7}, unid:{unid}, T0:{time}, Mag:{mag}, Region: {flynn_region}'.format(**info))
    except Exception:
        # The JSON was valid but did not have the expected event layout
        # (missing key, wrong type, unformattable value). Report it as such
        # instead of blaming the JSON parser.
        logging.exception("Unexpected message structure")
|
||||
|
||||
@gen.coroutine
def listen(ws):
    """Read messages from *ws* until the connection closes.

    Each received message is handed to ``myprocessing``. A ``None`` result
    from ``read_message()`` is treated as the connection having closed: it
    is logged and the loop ends.
    """
    while True:
        msg = yield ws.read_message()
        if msg is None:
            logging.info("close")
            break
        myprocessing(msg)
|
||||
|
||||
@gen.coroutine
def launch_client():
    """Connect to ``echo_uri`` and process messages until the socket closes.

    A failure to connect is logged and the coroutine returns without
    raising.
    """
    try:
        logging.info("Open WebSocket connection to %s", echo_uri)
        ws = yield websocket_connect(echo_uri, ping_interval=PING_INTERVAL)
    except Exception:
        logging.exception("connection error")
    else:
        logging.info("Waiting for messages...")
        # Yield the listener instead of discarding its Future, so this
        # coroutine finishes only when the listener does and any error
        # raised inside it is carried by this coroutine's result.
        yield listen(ws)
|
||||
|
||||
if __name__ == '__main__':
    # Send log records to stdout so event summaries appear on the console.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    # IOLoop.current() replaces the deprecated IOLoop.instance() alias.
    ioloop = IOLoop.current()
    # Start the client coroutine; it runs up to its first yield here and
    # continues once the loop is started below.
    launch_client()
    try:
        ioloop.start()
    except KeyboardInterrupt:
        logging.info("Close WebSocket")
        ioloop.stop()
|
||||
522
data_acquisition/evt_testing.ipynb
Normal file
522
data_acquisition/evt_testing.ipynb
Normal file
File diff suppressed because one or more lines are too long
402
data_acquisition/extreme_data.ipynb
Normal file
402
data_acquisition/extreme_data.ipynb
Normal file
@ -0,0 +1,402 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "bd523899",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import scipy.stats as stats\n",
|
||||
"from scipy.stats import genextreme\n",
|
||||
"from scipy.stats import genpareto\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"from datetime import date\n",
|
||||
"from datetime import datetime\n",
|
||||
"from datetime import timedelta\n",
|
||||
"import pytz\n",
|
||||
"\n",
|
||||
"import os.path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "c63a1b94",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"start_date = \"2024-01-01\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "903ed374",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def geturl(url):\n",
|
||||
" res = requests.get(\"https://\"+url, timeout=15)\n",
|
||||
" return {'status': res.status_code,\n",
|
||||
" 'content': res.text}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "89e1979c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def retrieveEvents(start, end, limit=20000, retry_limit=5, minmag=0):\n",
|
||||
" events = []\n",
|
||||
" moving_start = datetime.fromordinal(start.toordinal()).replace(tzinfo=pytz.utc)\n",
|
||||
" end = datetime.fromordinal(end.toordinal()).replace(tzinfo=pytz.utc)\n",
|
||||
" failures = 0\n",
|
||||
" while moving_start <= end and failures < retry_limit:\n",
|
||||
" # print(moving_start, end)\n",
|
||||
" url = \"www.seismicportal.eu/fdsnws/event/1/query?orderby=time-asc&limit={limit}&start={startdate}&end={enddate}&format=json&minmag={minmag}\".format(limit=limit, startdate=moving_start.isoformat(), enddate=end.isoformat(), minmag=minmag)\n",
|
||||
" # print(url)\n",
|
||||
" res = geturl(url)\n",
|
||||
" # print(res['status'])\n",
|
||||
" if res['status'] != 200:\n",
|
||||
" failures += 1\n",
|
||||
" continue\n",
|
||||
" content = res['content']\n",
|
||||
" json_parser = json.loads(content)\n",
|
||||
" temp_events = [event['properties'] for event in json_parser['features']]\n",
|
||||
"\n",
|
||||
" if len(temp_events) == 0:\n",
|
||||
" # print(\"ending\")\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" # temp_events = sorted(temp_events, key=lambda d: d['time'])\n",
|
||||
"\n",
|
||||
" if len(temp_events) == limit:\n",
|
||||
" moving_start = datetime.fromisoformat(temp_events[-1]['time'])\n",
|
||||
" else:\n",
|
||||
" moving_start = end + timedelta(hours=1)\n",
|
||||
" # print(\"ending here:\", moving_start)\n",
|
||||
" events.extend(temp_events)\n",
|
||||
" # print(\"hi\")\n",
|
||||
" # return pd.DataFrame(events)\n",
|
||||
" return events\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "8bccb27c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = retrieveEvents(date.fromisoformat(start_date), date.today(), minmag=2)\n",
|
||||
"df = pd.DataFrame(data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "902b6b1e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>source_id</th>\n",
|
||||
" <th>source_catalog</th>\n",
|
||||
" <th>lastupdate</th>\n",
|
||||
" <th>time</th>\n",
|
||||
" <th>flynn_region</th>\n",
|
||||
" <th>lat</th>\n",
|
||||
" <th>lon</th>\n",
|
||||
" <th>depth</th>\n",
|
||||
" <th>evtype</th>\n",
|
||||
" <th>auth</th>\n",
|
||||
" <th>mag</th>\n",
|
||||
" <th>magtype</th>\n",
|
||||
" <th>unid</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>1600054</td>\n",
|
||||
" <td>EMSC-RTS</td>\n",
|
||||
" <td>2024-01-01T00:02:51.439437Z</td>\n",
|
||||
" <td>2024-01-01T00:00:29.5Z</td>\n",
|
||||
" <td>CRETE, GREECE</td>\n",
|
||||
" <td>35.1400</td>\n",
|
||||
" <td>24.1200</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" <td>ke</td>\n",
|
||||
" <td>THE</td>\n",
|
||||
" <td>2.3</td>\n",
|
||||
" <td>ml</td>\n",
|
||||
" <td>20240101_0000001</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>1600055</td>\n",
|
||||
" <td>EMSC-RTS</td>\n",
|
||||
" <td>2024-01-01T00:14:14.3925Z</td>\n",
|
||||
" <td>2024-01-01T00:03:15.0Z</td>\n",
|
||||
" <td>SULAWESI, INDONESIA</td>\n",
|
||||
" <td>-1.3000</td>\n",
|
||||
" <td>120.5100</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" <td>ke</td>\n",
|
||||
" <td>BMKG</td>\n",
|
||||
" <td>3.1</td>\n",
|
||||
" <td>m</td>\n",
|
||||
" <td>20240101_0000002</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>1600058</td>\n",
|
||||
" <td>EMSC-RTS</td>\n",
|
||||
" <td>2024-01-01T00:24:28.774809Z</td>\n",
|
||||
" <td>2024-01-01T00:03:15.14Z</td>\n",
|
||||
" <td>PUERTO RICO</td>\n",
|
||||
" <td>18.4087</td>\n",
|
||||
" <td>-66.4270</td>\n",
|
||||
" <td>105.2</td>\n",
|
||||
" <td>ke</td>\n",
|
||||
" <td>PR</td>\n",
|
||||
" <td>3.2</td>\n",
|
||||
" <td>md</td>\n",
|
||||
" <td>20240101_0000004</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>1600056</td>\n",
|
||||
" <td>EMSC-RTS</td>\n",
|
||||
" <td>2024-01-01T00:14:58.984143Z</td>\n",
|
||||
" <td>2024-01-01T00:05:28.0Z</td>\n",
|
||||
" <td>COLOMBIA-ECUADOR BORDER REGION</td>\n",
|
||||
" <td>0.1100</td>\n",
|
||||
" <td>-78.9400</td>\n",
|
||||
" <td>54.0</td>\n",
|
||||
" <td>ke</td>\n",
|
||||
" <td>QUI</td>\n",
|
||||
" <td>3.5</td>\n",
|
||||
" <td>m</td>\n",
|
||||
" <td>20240101_0000003</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>1600057</td>\n",
|
||||
" <td>EMSC-RTS</td>\n",
|
||||
" <td>2024-01-02T08:04:45.107234Z</td>\n",
|
||||
" <td>2024-01-01T00:10:05.6Z</td>\n",
|
||||
" <td>NORWEGIAN SEA</td>\n",
|
||||
" <td>72.2450</td>\n",
|
||||
" <td>1.8470</td>\n",
|
||||
" <td>6.1</td>\n",
|
||||
" <td>ke</td>\n",
|
||||
" <td>BER</td>\n",
|
||||
" <td>3.7</td>\n",
|
||||
" <td>mw</td>\n",
|
||||
" <td>20240101_0000408</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>126136</th>\n",
|
||||
" <td>1809807</td>\n",
|
||||
" <td>EMSC-RTS</td>\n",
|
||||
" <td>2025-05-19T21:34:26.878673Z</td>\n",
|
||||
" <td>2025-05-18T23:24:13.0Z</td>\n",
|
||||
" <td>SAN JUAN, ARGENTINA</td>\n",
|
||||
" <td>-31.6300</td>\n",
|
||||
" <td>-70.3800</td>\n",
|
||||
" <td>138.0</td>\n",
|
||||
" <td>ke</td>\n",
|
||||
" <td>CSN</td>\n",
|
||||
" <td>3.4</td>\n",
|
||||
" <td>ml</td>\n",
|
||||
" <td>20250518_0000270</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>126137</th>\n",
|
||||
" <td>1809808</td>\n",
|
||||
" <td>EMSC-RTS</td>\n",
|
||||
" <td>2025-05-18T23:38:19.408959Z</td>\n",
|
||||
" <td>2025-05-18T23:27:24.13Z</td>\n",
|
||||
" <td>SOUTHERN ITALY</td>\n",
|
||||
" <td>39.0343</td>\n",
|
||||
" <td>16.4318</td>\n",
|
||||
" <td>9.8</td>\n",
|
||||
" <td>ke</td>\n",
|
||||
" <td>INGV</td>\n",
|
||||
" <td>2.4</td>\n",
|
||||
" <td>ml</td>\n",
|
||||
" <td>20250518_0000271</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>126138</th>\n",
|
||||
" <td>1809813</td>\n",
|
||||
" <td>EMSC-RTS</td>\n",
|
||||
" <td>2025-05-19T00:14:20.838747Z</td>\n",
|
||||
" <td>2025-05-18T23:52:33.0Z</td>\n",
|
||||
" <td>SOUTHWEST OF SUMATRA, INDONESIA</td>\n",
|
||||
" <td>-7.8200</td>\n",
|
||||
" <td>103.8600</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" <td>ke</td>\n",
|
||||
" <td>BMKG</td>\n",
|
||||
" <td>3.5</td>\n",
|
||||
" <td>m</td>\n",
|
||||
" <td>20250518_0000275</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>126139</th>\n",
|
||||
" <td>1809809</td>\n",
|
||||
" <td>EMSC-RTS</td>\n",
|
||||
" <td>2025-05-19T06:33:38.794931Z</td>\n",
|
||||
" <td>2025-05-18T23:58:01.56Z</td>\n",
|
||||
" <td>OFF COAST OF TARAPACA, CHILE</td>\n",
|
||||
" <td>-18.5159</td>\n",
|
||||
" <td>-71.3039</td>\n",
|
||||
" <td>25.0</td>\n",
|
||||
" <td>ke</td>\n",
|
||||
" <td>EMSC</td>\n",
|
||||
" <td>4.2</td>\n",
|
||||
" <td>mb</td>\n",
|
||||
" <td>20250518_0000273</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>126140</th>\n",
|
||||
" <td>1809814</td>\n",
|
||||
" <td>EMSC-RTS</td>\n",
|
||||
" <td>2025-05-19T06:33:56.681166Z</td>\n",
|
||||
" <td>2025-05-18T23:58:56.73Z</td>\n",
|
||||
" <td>WESTERN TURKEY</td>\n",
|
||||
" <td>37.8953</td>\n",
|
||||
" <td>27.6327</td>\n",
|
||||
" <td>11.4</td>\n",
|
||||
" <td>ke</td>\n",
|
||||
" <td>EMSC</td>\n",
|
||||
" <td>2.4</td>\n",
|
||||
" <td>ml</td>\n",
|
||||
" <td>20250518_0000287</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>126141 rows × 13 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" source_id source_catalog lastupdate \\\n",
|
||||
"0 1600054 EMSC-RTS 2024-01-01T00:02:51.439437Z \n",
|
||||
"1 1600055 EMSC-RTS 2024-01-01T00:14:14.3925Z \n",
|
||||
"2 1600058 EMSC-RTS 2024-01-01T00:24:28.774809Z \n",
|
||||
"3 1600056 EMSC-RTS 2024-01-01T00:14:58.984143Z \n",
|
||||
"4 1600057 EMSC-RTS 2024-01-02T08:04:45.107234Z \n",
|
||||
"... ... ... ... \n",
|
||||
"126136 1809807 EMSC-RTS 2025-05-19T21:34:26.878673Z \n",
|
||||
"126137 1809808 EMSC-RTS 2025-05-18T23:38:19.408959Z \n",
|
||||
"126138 1809813 EMSC-RTS 2025-05-19T00:14:20.838747Z \n",
|
||||
"126139 1809809 EMSC-RTS 2025-05-19T06:33:38.794931Z \n",
|
||||
"126140 1809814 EMSC-RTS 2025-05-19T06:33:56.681166Z \n",
|
||||
"\n",
|
||||
" time flynn_region lat \\\n",
|
||||
"0 2024-01-01T00:00:29.5Z CRETE, GREECE 35.1400 \n",
|
||||
"1 2024-01-01T00:03:15.0Z SULAWESI, INDONESIA -1.3000 \n",
|
||||
"2 2024-01-01T00:03:15.14Z PUERTO RICO 18.4087 \n",
|
||||
"3 2024-01-01T00:05:28.0Z COLOMBIA-ECUADOR BORDER REGION 0.1100 \n",
|
||||
"4 2024-01-01T00:10:05.6Z NORWEGIAN SEA 72.2450 \n",
|
||||
"... ... ... ... \n",
|
||||
"126136 2025-05-18T23:24:13.0Z SAN JUAN, ARGENTINA -31.6300 \n",
|
||||
"126137 2025-05-18T23:27:24.13Z SOUTHERN ITALY 39.0343 \n",
|
||||
"126138 2025-05-18T23:52:33.0Z SOUTHWEST OF SUMATRA, INDONESIA -7.8200 \n",
|
||||
"126139 2025-05-18T23:58:01.56Z OFF COAST OF TARAPACA, CHILE -18.5159 \n",
|
||||
"126140 2025-05-18T23:58:56.73Z WESTERN TURKEY 37.8953 \n",
|
||||
"\n",
|
||||
" lon depth evtype auth mag magtype unid \n",
|
||||
"0 24.1200 10.0 ke THE 2.3 ml 20240101_0000001 \n",
|
||||
"1 120.5100 10.0 ke BMKG 3.1 m 20240101_0000002 \n",
|
||||
"2 -66.4270 105.2 ke PR 3.2 md 20240101_0000004 \n",
|
||||
"3 -78.9400 54.0 ke QUI 3.5 m 20240101_0000003 \n",
|
||||
"4 1.8470 6.1 ke BER 3.7 mw 20240101_0000408 \n",
|
||||
"... ... ... ... ... ... ... ... \n",
|
||||
"126136 -70.3800 138.0 ke CSN 3.4 ml 20250518_0000270 \n",
|
||||
"126137 16.4318 9.8 ke INGV 2.4 ml 20250518_0000271 \n",
|
||||
"126138 103.8600 10.0 ke BMKG 3.5 m 20250518_0000275 \n",
|
||||
"126139 -71.3039 25.0 ke EMSC 4.2 mb 20250518_0000273 \n",
|
||||
"126140 27.6327 11.4 ke EMSC 2.4 ml 20250518_0000287 \n",
|
||||
"\n",
|
||||
"[126141 rows x 13 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "sideprojects",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
1117
data_acquisition/normal_data_handler.ipynb
Normal file
1117
data_acquisition/normal_data_handler.ipynb
Normal file
File diff suppressed because one or more lines are too long
734
full_regression_model.ipynb
Normal file
734
full_regression_model.ipynb
Normal file
@ -0,0 +1,734 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 160,
|
||||
"id": "98beda53",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import scipy.stats as stats\n",
|
||||
"from scipy.stats import genextreme\n",
|
||||
"from scipy.stats import genpareto\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"# from datetime import date\n",
|
||||
"# from datetime import datetime\n",
|
||||
"# from datetime import timedelta\n",
|
||||
"# import pytz\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"import lightning as L\n",
|
||||
"\n",
|
||||
"import os\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f15607d7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Import data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 161,
|
||||
"id": "57505637",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = pd.read_csv('reduced_data.csv')\n",
|
||||
"# df['time'] = pd.to_datetime(df['time'], format='ISO8601')\n",
|
||||
"\n",
|
||||
"# df['lastupdate'] = pd.to_datetime(df['lastupdate'], format='ISO8601')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 162,
|
||||
"id": "053cc158",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>bin</th>\n",
|
||||
" <th>time_to_next_event</th>\n",
|
||||
" <th>time</th>\n",
|
||||
" <th>lat</th>\n",
|
||||
" <th>lon</th>\n",
|
||||
" <th>depth</th>\n",
|
||||
" <th>mag</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>27</td>\n",
|
||||
" <td>3408</td>\n",
|
||||
" <td>1.262306e+09</td>\n",
|
||||
" <td>37.5835</td>\n",
|
||||
" <td>22.1704</td>\n",
|
||||
" <td>15.0</td>\n",
|
||||
" <td>2.8</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>21</td>\n",
|
||||
" <td>495</td>\n",
|
||||
" <td>1.262310e+09</td>\n",
|
||||
" <td>42.8370</td>\n",
|
||||
" <td>13.0010</td>\n",
|
||||
" <td>9.4</td>\n",
|
||||
" <td>2.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>22</td>\n",
|
||||
" <td>1377</td>\n",
|
||||
" <td>1.262310e+09</td>\n",
|
||||
" <td>46.1624</td>\n",
|
||||
" <td>12.2834</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" <td>2.3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>50</td>\n",
|
||||
" <td>274</td>\n",
|
||||
" <td>1.262312e+09</td>\n",
|
||||
" <td>26.4100</td>\n",
|
||||
" <td>99.9100</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>25</td>\n",
|
||||
" <td>11</td>\n",
|
||||
" <td>1.262312e+09</td>\n",
|
||||
" <td>39.0807</td>\n",
|
||||
" <td>31.0495</td>\n",
|
||||
" <td>7.0</td>\n",
|
||||
" <td>2.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1062937</th>\n",
|
||||
" <td>22</td>\n",
|
||||
" <td>146</td>\n",
|
||||
" <td>1.749481e+09</td>\n",
|
||||
" <td>39.2261</td>\n",
|
||||
" <td>28.9878</td>\n",
|
||||
" <td>10.8</td>\n",
|
||||
" <td>2.3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1062938</th>\n",
|
||||
" <td>40</td>\n",
|
||||
" <td>166</td>\n",
|
||||
" <td>1.749481e+09</td>\n",
|
||||
" <td>17.5500</td>\n",
|
||||
" <td>-94.3070</td>\n",
|
||||
" <td>186.8</td>\n",
|
||||
" <td>4.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1062939</th>\n",
|
||||
" <td>30</td>\n",
|
||||
" <td>36</td>\n",
|
||||
" <td>1.749481e+09</td>\n",
|
||||
" <td>35.3752</td>\n",
|
||||
" <td>-3.6249</td>\n",
|
||||
" <td>25.7</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1062940</th>\n",
|
||||
" <td>20</td>\n",
|
||||
" <td>25</td>\n",
|
||||
" <td>1.749481e+09</td>\n",
|
||||
" <td>39.2152</td>\n",
|
||||
" <td>29.0178</td>\n",
|
||||
" <td>5.4</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1062941</th>\n",
|
||||
" <td>32</td>\n",
|
||||
" <td>173</td>\n",
|
||||
" <td>1.749481e+09</td>\n",
|
||||
" <td>14.8400</td>\n",
|
||||
" <td>119.4300</td>\n",
|
||||
" <td>16.0</td>\n",
|
||||
" <td>3.2</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>1062942 rows × 7 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" bin time_to_next_event time lat lon depth mag\n",
|
||||
"0 27 3408 1.262306e+09 37.5835 22.1704 15.0 2.8\n",
|
||||
"1 21 495 1.262310e+09 42.8370 13.0010 9.4 2.1\n",
|
||||
"2 22 1377 1.262310e+09 46.1624 12.2834 10.0 2.3\n",
|
||||
"3 50 274 1.262312e+09 26.4100 99.9100 10.0 5.0\n",
|
||||
"4 25 11 1.262312e+09 39.0807 31.0495 7.0 2.5\n",
|
||||
"... ... ... ... ... ... ... ...\n",
|
||||
"1062937 22 146 1.749481e+09 39.2261 28.9878 10.8 2.3\n",
|
||||
"1062938 40 166 1.749481e+09 17.5500 -94.3070 186.8 4.1\n",
|
||||
"1062939 30 36 1.749481e+09 35.3752 -3.6249 25.7 3.0\n",
|
||||
"1062940 20 25 1.749481e+09 39.2152 29.0178 5.4 2.0\n",
|
||||
"1062941 32 173 1.749481e+09 14.8400 119.4300 16.0 3.2\n",
|
||||
"\n",
|
||||
"[1062942 rows x 7 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 162,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b17d72f3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Dataset loader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 163,
|
||||
"id": "e6dacf0d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
class TimeseriesDataset(torch.utils.data.Dataset):
    """Sliding-window dataset for next-step forecasting.

    Item ``i`` pairs the window ``x[i:i + seq_len]`` with the target
    ``y[i + seq_len]``, i.e. the row that immediately follows the window.

    Args:
        x: Feature sequence, sliced along its first axis.
        y: Target sequence, indexed along its first axis with the same
            row numbering as ``x``.
        seq_len: Number of consecutive rows in every input window.
    """

    def __init__(self, x, y, seq_len=100):
        self.seq_len = seq_len
        self.x = x
        self.y = y

    def __len__(self):
        # Clamp at zero: when the series is shorter than one window the raw
        # difference is negative and ``len()`` would raise ValueError.
        return max(0, len(self.x) - self.seq_len)

    def __getitem__(self, index):
        size = len(self)
        # Follow the usual sequence convention for negative indices instead of
        # silently slicing an empty or misaligned window.
        if index < 0:
            index += size
        if not 0 <= index < size:
            raise IndexError(f"index out of range for dataset of length {size}")
        stop = index + self.seq_len
        return (self.x[index:stop], self.y[stop])
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 164,
|
||||
"id": "f42c537e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"division_p1 = int(len(df)*0.6)\n",
|
||||
"division_p2 = int(len(df)*0.8)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 165,
|
||||
"id": "2181b1ad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x_columns = ['time_to_next_event','time','lat','lon','depth','mag']\n",
|
||||
"y_columns = ['time_to_next_event', 'lat', 'lon','mag']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 166,
|
||||
"id": "10a8ac71",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = torch.tensor(df[x_columns].values).to(torch.float)\n",
|
||||
"# y1 = torch.tensor(df[['bin']].values)\n",
|
||||
"y = torch.tensor(df[y_columns].values).to(torch.float)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 167,
|
||||
"id": "ee26bfb2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_x = x[:division_p1]\n",
|
||||
"val_x = x[division_p1:division_p2]\n",
|
||||
"test_x = x[division_p2:]\n",
|
||||
"\n",
|
||||
"train_y = y[:division_p1]\n",
|
||||
"val_y = y[division_p1:division_p2]\n",
|
||||
"test_y = y[division_p2:]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 168,
|
||||
"id": "4a739cb6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
# Training windows are shuffled so consecutive, heavily overlapping windows
# do not land in the same mini-batch.
train_dataset = TimeseriesDataset(train_x, train_y, seq_len=1000)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True)

# Validation is served in a fixed order: shuffling adds nothing to an
# evaluation pass and triggers Lightning's "val_dataloader has shuffling
# enabled" warning.
val_dataset = TimeseriesDataset(val_x, val_y, seq_len=1000)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=10, shuffle=False)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "119f29ee",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Helper classes\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 169,
|
||||
"id": "973b9c65",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"## Copied from https://stackoverflow.com/questions/50817916/how-do-i-add-lstm-gru-or-other-recurrent-layers-to-a-sequential-in-pytorch\n",
|
class SelectItem(nn.Module):
    """Module that returns one element of whatever it is given.

    Lets a layer whose output is a tuple (such as a recurrent layer) be
    chained inside ``nn.Sequential`` by forwarding only ``inputs[item_index]``.
    """

    def __init__(self, item_index):
        super().__init__()
        self.item_index = item_index
        self._name = 'selectitem'

    def forward(self, inputs):
        position = self.item_index
        return inputs[position]
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "971e8319",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Model design"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "92b3b995",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
class MultiRNNLightning(L.LightningModule):
    """LSTM regressor trained with mean-squared error.

    The network is an ``nn.Sequential`` of a batch-first LSTM, two
    ``SelectItem`` stages that pull the final hidden state out of the LSTM's
    tuple return value, a ReLU, and a linear projection to ``output_size``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        layers = [
            nn.LSTM(input_size, hidden_size, batch_first=True),
            SelectItem(1),  # keep element 1 of the LSTM's return tuple
            SelectItem(0),  # keep element 0 of that element
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        ]
        # Same layer order as before, so parameter names in the state dict
        # are unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x`` and drop the leading size-1 axis, if any."""
        prediction = self.model(x)
        return prediction.squeeze(0)

    def configure_optimizers(self):
        """Adam at 1e-3, with the learning rate cut 10x when ``val_loss`` plateaus."""
        adam = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            adam, mode="min", factor=0.1, patience=5
        )
        scheduler_config = {"scheduler": plateau, "monitor": "val_loss"}
        return {"optimizer": adam, "lr_scheduler": scheduler_config}

    def _mse(self, batch):
        """Return the MSE between the model's prediction and the batch target."""
        inputs, targets = batch
        return nn.functional.mse_loss(self(inputs), targets)

    def training_step(self, batch, batch_idx):
        """Compute and log ``train_loss`` for one training batch."""
        loss = self._mse(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log ``val_loss`` for one validation batch."""
        loss = self._mse(val_batch)
        self.log('val_loss', loss)
        return loss
|
||||
"\n",
|
||||
" # def training_step(self, batch, batch_idx):\n",
|
||||
" # x, y = batch\n",
|
||||
" # class_y_hat, regress_y_hat = self(x)\n",
|
||||
" \n",
|
||||
" # loss = F.cross_entropy(y_hat, y)\n",
|
||||
" # acc = (y_hat.argmax(1) == y).float().mean()\n",
|
||||
" \n",
|
||||
" # self.log(\"train_loss\", loss)\n",
|
||||
" # self.log(\"train_acc\", acc)\n",
|
||||
" # return loss\n",
|
||||
" \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "741ce0f7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Model Params"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "886809fd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"input_size = len(x_columns)\n",
|
||||
"hidden_size = 1000\n",
|
||||
"output_size = len(y_columns)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "25598f51",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Train model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 172,
|
||||
"id": "61fa862f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"GPU available: True (cuda), used: True\n",
|
||||
"TPU available: False, using: 0 TPU cores\n",
|
||||
"HPU available: False, using: 0 HPUs\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = MultiRNNLightning(input_size, hidden_size, output_size)\n",
|
||||
"\n",
|
||||
"trainer = L.Trainer(max_epochs=10)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e847d0c2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
|
||||
"\n",
|
||||
" | Name | Type | Params | Mode \n",
|
||||
"---------------------------------------------\n",
|
||||
"0 | model | Sequential | 167 K | train\n",
|
||||
"---------------------------------------------\n",
|
||||
"167 K Trainable params\n",
|
||||
"0 Non-trainable params\n",
|
||||
"167 K Total params\n",
|
||||
"0.669 Total estimated model params size (MB)\n",
|
||||
"6 Modules in train mode\n",
|
||||
"0 Modules in eval mode\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "adec569b050341aab3b2c94e22152976",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Sanity Checking: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"c:\\Users\\ewell\\anaconda3\\envs\\test_space\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.\n",
|
||||
"c:\\Users\\ewell\\anaconda3\\envs\\test_space\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n",
|
||||
"c:\\Users\\ewell\\anaconda3\\envs\\test_space\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "b2c2156c3db24dadbeff025679a1e7b4",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Training: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "80970987a81d4b059981dce44b21cdc8",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "d21a94f8276d4dea8ecd543ff7d5bdd2",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "5c1b8ebb789e4c80ad4ec452c9ccbbda",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "88b213664aaa4bc7bf57b4e5cfff3ba7",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "26637bbce8b2455bb72a6268f2e2196c",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c067cd1c645d49c8a11cae04e3d676dc",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "eb0dc33095f841c7a30376c0bdb0fce0",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "12667695d03e4baf9ab468a6b57349e0",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "e65958ac8567425d9e9722f6023f8660",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c6ff7598384d45c3bd192822cd0070b0",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"`Trainer.fit` stopped: `max_epochs=10` reached.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"trainer.fit(model, train_loader, val_loader)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 174,
|
||||
"id": "5d146588",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# %reload_ext tensorboard\n",
|
||||
"# %tensorboard --logdir=lightning_logs/"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
563
model_training.ipynb
Normal file
563
model_training.ipynb
Normal file
@ -0,0 +1,563 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "b4cc996f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import scipy.stats as stats\n",
|
||||
"from scipy.stats import genextreme\n",
|
||||
"from scipy.stats import genpareto\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"# from datetime import date\n",
|
||||
"# from datetime import datetime\n",
|
||||
"# from datetime import timedelta\n",
|
||||
"# import pytz\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"import lightning as L\n",
|
||||
"\n",
|
||||
"import os\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "11cc2375",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Import data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "b0fe3fe6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = pd.read_csv('reduced_data.csv')\n",
|
||||
"# df['time'] = pd.to_datetime(df['time'], format='ISO8601')\n",
|
||||
"\n",
|
||||
"# df['lastupdate'] = pd.to_datetime(df['lastupdate'], format='ISO8601')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "083c39ae",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>bin</th>\n",
|
||||
" <th>time_to_next_event</th>\n",
|
||||
" <th>time</th>\n",
|
||||
" <th>lat</th>\n",
|
||||
" <th>lon</th>\n",
|
||||
" <th>depth</th>\n",
|
||||
" <th>mag</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>27</td>\n",
|
||||
" <td>24</td>\n",
|
||||
" <td>1.577837e+09</td>\n",
|
||||
" <td>19.2200</td>\n",
|
||||
" <td>-67.1300</td>\n",
|
||||
" <td>12.0</td>\n",
|
||||
" <td>2.8</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>30</td>\n",
|
||||
" <td>77</td>\n",
|
||||
" <td>1.577837e+09</td>\n",
|
||||
" <td>-2.7400</td>\n",
|
||||
" <td>127.9000</td>\n",
|
||||
" <td>20.0</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>25</td>\n",
|
||||
" <td>425</td>\n",
|
||||
" <td>1.577837e+09</td>\n",
|
||||
" <td>19.0800</td>\n",
|
||||
" <td>-67.0900</td>\n",
|
||||
" <td>6.0</td>\n",
|
||||
" <td>2.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>31</td>\n",
|
||||
" <td>39</td>\n",
|
||||
" <td>1.577837e+09</td>\n",
|
||||
" <td>19.1900</td>\n",
|
||||
" <td>-67.8400</td>\n",
|
||||
" <td>28.0</td>\n",
|
||||
" <td>3.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>35</td>\n",
|
||||
" <td>69</td>\n",
|
||||
" <td>1.577837e+09</td>\n",
|
||||
" <td>-25.6400</td>\n",
|
||||
" <td>-70.5200</td>\n",
|
||||
" <td>53.0</td>\n",
|
||||
" <td>3.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548217</th>\n",
|
||||
" <td>21</td>\n",
|
||||
" <td>113</td>\n",
|
||||
" <td>1.747781e+09</td>\n",
|
||||
" <td>35.9836</td>\n",
|
||||
" <td>28.1056</td>\n",
|
||||
" <td>7.1</td>\n",
|
||||
" <td>2.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548218</th>\n",
|
||||
" <td>18</td>\n",
|
||||
" <td>941</td>\n",
|
||||
" <td>1.747781e+09</td>\n",
|
||||
" <td>36.2797</td>\n",
|
||||
" <td>36.0253</td>\n",
|
||||
" <td>7.0</td>\n",
|
||||
" <td>1.9</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548219</th>\n",
|
||||
" <td>11</td>\n",
|
||||
" <td>141</td>\n",
|
||||
" <td>1.747782e+09</td>\n",
|
||||
" <td>37.2222</td>\n",
|
||||
" <td>36.9486</td>\n",
|
||||
" <td>7.3</td>\n",
|
||||
" <td>1.2</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548220</th>\n",
|
||||
" <td>32</td>\n",
|
||||
" <td>143</td>\n",
|
||||
" <td>1.747783e+09</td>\n",
|
||||
" <td>-31.1900</td>\n",
|
||||
" <td>-68.2500</td>\n",
|
||||
" <td>81.0</td>\n",
|
||||
" <td>3.3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>548221</th>\n",
|
||||
" <td>40</td>\n",
|
||||
" <td>790</td>\n",
|
||||
" <td>1.747783e+09</td>\n",
|
||||
" <td>17.6300</td>\n",
|
||||
" <td>-101.3340</td>\n",
|
||||
" <td>41.2</td>\n",
|
||||
" <td>4.0</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>548222 rows × 7 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" bin time_to_next_event time lat lon depth mag\n",
|
||||
"0 27 24 1.577837e+09 19.2200 -67.1300 12.0 2.8\n",
|
||||
"1 30 77 1.577837e+09 -2.7400 127.9000 20.0 3.0\n",
|
||||
"2 25 425 1.577837e+09 19.0800 -67.0900 6.0 2.5\n",
|
||||
"3 31 39 1.577837e+09 19.1900 -67.8400 28.0 3.1\n",
|
||||
"4 35 69 1.577837e+09 -25.6400 -70.5200 53.0 3.5\n",
|
||||
"... ... ... ... ... ... ... ...\n",
|
||||
"548217 21 113 1.747781e+09 35.9836 28.1056 7.1 2.1\n",
|
||||
"548218 18 941 1.747781e+09 36.2797 36.0253 7.0 1.9\n",
|
||||
"548219 11 141 1.747782e+09 37.2222 36.9486 7.3 1.2\n",
|
||||
"548220 32 143 1.747783e+09 -31.1900 -68.2500 81.0 3.3\n",
|
||||
"548221 40 790 1.747783e+09 17.6300 -101.3340 41.2 4.0\n",
|
||||
"\n",
|
||||
"[548222 rows x 7 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "967ec29d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# # Set threshold at the 95th percentile\n",
|
||||
"# # threshold = np.percentile(data, 95)\n",
|
||||
"# threshold = 6 ## the actual threshold value\n",
|
||||
"# data = df['mag']\n",
|
||||
"# extremes = data[data > threshold]\n",
|
||||
"# print(f\"Threshold: {threshold}, Number of extremes: {len(extremes)}\")\n",
|
||||
"\n",
|
||||
"# # Visualize the threshold and extremes\n",
|
||||
"# plt.hist(data, bins=30, edgecolor='k', alpha=0.7, label='Data')\n",
|
||||
"# plt.axvline(threshold, color='red', linestyle='--', label='Threshold')\n",
|
||||
"# plt.title('Threshold for POT')\n",
|
||||
"# plt.legend()\n",
|
||||
"# plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "77c900ad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# from scipy.stats import genpareto\n",
|
||||
"\n",
|
||||
"# # Fit a GPD to the extremes\n",
|
||||
"# gpd_params = genpareto.fit(extremes - threshold) # Subtract threshold for GPD fit\n",
|
||||
"# print(f\"GPD Parameters: Shape={gpd_params[0]}, Location={gpd_params[1]}, Scale={gpd_params[2]}\")\n",
|
||||
"\n",
|
||||
"# # Visualize the GPD fit\n",
|
||||
"# x = np.linspace(min(extremes), max(extremes), 100)\n",
|
||||
"# pdf = genpareto.pdf(x - threshold, *gpd_params)\n",
|
||||
"# plt.hist(extremes, bins=10, density=True, alpha=0.7, label='Data')\n",
|
||||
"# plt.plot(x, pdf, label='GPD Fit', color='blue')\n",
|
||||
"# plt.title('GPD Fit to Extremes')\n",
|
||||
"# plt.legend()\n",
|
||||
"# plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "f0c623e1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# inv_mag_frequency = 1/np.bincount(df['bin'].to_numpy())/\n",
|
||||
"# print(inv_mag_frequency)\n",
|
||||
"# inv_mag_frequency[inv_mag_frequency == np.inf] = 2\n",
|
||||
"# print(inv_mag_frequency)\n",
|
||||
"# print(sum(mag_frequency))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5eea5b40",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "00b89e53",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
class MultiRNN(nn.Module):
    """GRU encoder with a classification head and a regression head.

    Both heads read the GRU output at the last time step of a batch-first
    input.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the GRU hidden state.
        class_count: Number of logits produced by the classification head.
        output_size: Number of values produced by the regression head.
    """

    def __init__(self, input_size, hidden_size, class_count, output_size):
        super().__init__()
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        self.classify = nn.Linear(hidden_size, class_count)
        self.regress = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return ``(classes, regresses)`` for a batch-first input ``x``.

        In training mode ``classes`` holds the raw logits; in eval mode it
        holds the argmax class index per sample.
        """
        # Size the initial state from the GRU itself rather than from a
        # module-level global, so the instance works with whatever
        # ``hidden_size`` it was constructed with, and match the input's
        # device and dtype.
        h0 = torch.zeros(
            self.rnn.num_layers,
            x.size(0),
            self.rnn.hidden_size,
            device=x.device,
            dtype=x.dtype,
        )
        out, _ = self.rnn(x, h0)
        last_step = out[:, -1]
        classes = self.classify(last_step)
        regresses = self.regress(last_step)

        if self.training:
            return classes, regresses
        return torch.argmax(classes, dim=-1), regresses
|
||||
"\n",
|
||||
"\n",
|
# Hyperparameters for MultiRNN.
# Six feature columns are fed to the network (time_to_next_event, time, lat,
# lon, depth, mag), so the recurrent layer must expect 6 inputs per step.
input_size = 6
hidden_size = 20
class_count = 120
regressor_output_size = 4
# model = MultiRNN(input_size, hidden_size, class_count, regressor_output_size)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0e346047",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create dataset/dataloader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "2744a649",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
class TimeseriesDataset(torch.utils.data.Dataset):
    """Sliding-window dataset with a class target and a regression target.

    Item ``i`` is ``(x[i:i + seq_len], class_y[i + seq_len],
    regress_y[i + seq_len])``: a window of rows plus both targets taken from
    the row that immediately follows it.

    Args:
        x: Feature sequence, sliced along its first axis.
        class_y: Classification targets, indexed like ``x``.
        regress_y: Regression targets, indexed like ``x``.
        seq_len: Number of consecutive rows in every input window.
    """

    def __init__(self, x, class_y, regress_y, seq_len=100):
        self.seq_len = seq_len
        self.x = x
        self.class_y = class_y
        self.regress_y = regress_y

    def __len__(self):
        # The target sits at ``index + seq_len``, so the last usable start is
        # ``len(x) - seq_len - 1``. Counting one more item (the previous
        # ``- (seq_len - 1)``) made the final index read past the end.
        # Clamp at zero for series shorter than one window.
        return max(0, len(self.x) - self.seq_len)

    def __getitem__(self, index):
        size = len(self)
        # Follow the usual sequence convention for negative indices instead of
        # silently slicing an empty or misaligned window.
        if index < 0:
            index += size
        if not 0 <= index < size:
            raise IndexError(f"index out of range for dataset of length {size}")
        stop = index + self.seq_len
        return (self.x[index:stop], self.class_y[stop], self.regress_y[stop])
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "5714b17d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = torch.tensor(df[['time_to_next_event','time','lat','lon','depth','mag']].values).to(torch.float)\n",
|
||||
"y1 = torch.tensor(df[['bin']].values)\n",
|
||||
"y2 = torch.tensor(df[['time_to_next_event', 'time', 'lat', 'lon']].values).to(torch.float)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "72f7b0f1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset = TimeseriesDataset(x, y1, y2, seq_len=1000)\n",
|
||||
"train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ee154b1b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0 torch.Size([10, 1000, 6]) torch.Size([10, 1]) torch.Size([10, 4])\n",
|
||||
"54723\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# for i, d in enumerate(train_loader):\n",
|
||||
"# print(i, d[0].shape, d[1].shape, d[2].shape)\n",
|
||||
"# break\n",
|
||||
"# print(len(train_loader))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2277a6ee",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create the Lightning setup for model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a35aa9f2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class MutliRNNLightning(L.LightningModule):\n",
|
||||
" def __init__(self, input_size, hidden_size, class_count, output_size):\n",
|
||||
" super().__init__()\n",
|
||||
" self.model = MultiRNN(input_size, hidden_size, class_count, regressor_output_size)\n",
|
||||
" self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
" self.model.to(self.device)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" return self.model.forward(x)\n",
|
||||
" \n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)\n",
|
||||
" scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
|
||||
" optimizer, mode=\"min\", factor=0.1, patience=5\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" return {\n",
|
||||
" \"optimizer\": optimizer,\n",
|
||||
" \"lr_scheduler\": {\n",
|
||||
" \"scheduler\": scheduler,\n",
|
||||
" \"monitor\": \"val_loss_class\",\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" x, y1, y2 = batch\n",
|
||||
" \n",
|
||||
" x.to(self.device)\n",
|
||||
" y1.to(self.device)\n",
|
||||
" y2.to(self.device)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" class_y_hat, regress_y_hat = self(x)\n",
|
||||
" loss1 = nn.functional.cross_entropy(class_y_hat, y1)\n",
|
||||
" loss2 = nn.functional.mse_loss(regress_y_hat, y2)\n",
|
||||
" # loss = F.cross_entropy(y_hat, y)\n",
|
||||
" loss = loss1 + loss2\n",
|
||||
" self.log('train_loss_class', loss1)\n",
|
||||
"        self.log('train_loss_regress', loss2) #, on_epoch=True)\n",
|
||||
" return loss\n",
|
||||
" \n",
|
||||
" def validation_step(self, val_batch, batch_idx):\n",
|
||||
" x, y1, y2 = val_batch\n",
|
||||
" \n",
|
||||
" x.to(self.device)\n",
|
||||
" y1.to(self.device)\n",
|
||||
" y2.to(self.device)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" class_y_hat, regress_y_hat = self(x)\n",
|
||||
" loss1 = nn.functional.cross_entropy(class_y_hat, y1)\n",
|
||||
" loss2 = nn.functional.mse_loss(regress_y_hat, y2)\n",
|
||||
" # loss = F.cross_entropy(y_hat, y)\n",
|
||||
" loss = loss1 + loss2\n",
|
||||
" self.log('val_loss_class', loss1)\n",
|
||||
"        self.log('val_loss_regress', loss2) #, on_epoch=True)\n",
|
||||
" return loss\n",
|
||||
"\n",
|
||||
" # def training_step(self, batch, batch_idx):\n",
|
||||
" # x, y = batch\n",
|
||||
" # class_y_hat, regress_y_hat = self(x)\n",
|
||||
" \n",
|
||||
" # loss = F.cross_entropy(y_hat, y)\n",
|
||||
" # acc = (y_hat.argmax(1) == y).float().mean()\n",
|
||||
" \n",
|
||||
" # self.log(\"train_loss\", loss)\n",
|
||||
" # self.log(\"train_acc\", acc)\n",
|
||||
" # return loss\n",
|
||||
" \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "97657a45",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b49f9d92",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sequence_length = 1000\n",
|
||||
"\n",
|
||||
"# criterion = nn.MSELoss()\n",
|
||||
"class_loss = nn.CrossEntropyLoss(weight=torch.from_numpy(inv_mag_frequency))\n",
|
||||
"regressor_loss = nn.MSELoss()\n",
|
||||
"# optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
|
||||
"\n",
|
||||
"X = torch.tensor(df.iloc[:sequence_length][['time', 'lat', 'lon', 'depth', 'mag']].values).to(torch.float)\n",
|
||||
"X = torch.stack((torch.tensor(df.iloc[sequence_length:sequence_length*2][['time', 'lat', 'lon', 'depth', 'mag']].values).to(torch.float),X))\n",
|
||||
"# print(X.shape)\n",
|
||||
"model.train()\n",
|
||||
"outputs = model(X)\n",
|
||||
"# print(outputs[0].shape)\n",
|
||||
"# print(outputs[1].shape)\n",
|
||||
"\n",
|
||||
"# model.eval()\n",
|
||||
"# outputs = model(X)\n",
|
||||
"# print(outputs[0].shape)\n",
|
||||
"# print(outputs[1].shape)\n",
|
||||
"\n",
|
||||
"# num_epochs = 100\n",
|
||||
"# for epoch in range(num_epochs):\n",
|
||||
"# model.train()\n",
|
||||
"# outputs = model(X.unsqueeze(2))\n",
|
||||
"# loss = criterion(outputs, y.unsqueeze(2))\n",
|
||||
" \n",
|
||||
"# optimizer.zero_grad()\n",
|
||||
"# loss.backward()\n",
|
||||
"# optimizer.step()\n",
|
||||
" \n",
|
||||
"# if (epoch + 1) % 10 == 0:\n",
|
||||
"# print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
193
rnn_regression_model.py
Normal file
193
rnn_regression_model.py
Normal file
@ -0,0 +1,193 @@
|
||||
#%%
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import scipy.stats as stats
|
||||
from scipy.stats import genextreme
|
||||
from scipy.stats import genpareto
|
||||
|
||||
from argparse import ArgumentParser
|
||||
import pathlib
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import lightning as L
|
||||
|
||||
import os
|
||||
#%%
|
||||
class TimeseriesDataset(torch.utils.data.Dataset):
    """Sliding-window dataset pairing a window of inputs with the next target.

    Item ``i`` is ``(x[i:i + seq_len], y[i + seq_len])``: ``seq_len``
    consecutive input rows and the target row that immediately follows them.

    Args:
        x: Sliceable sequence of input rows (e.g. a tensor).
        y: Indexable sequence of target rows, aligned row-for-row with ``x``.
        seq_len: Number of consecutive rows in each input window.
    """

    def __init__(self, x, y, seq_len=100):
        self.seq_len = seq_len
        self.x = x
        self.y = y

    def __len__(self):
        # Clamp at zero: a series no longer than one window yields no samples,
        # and a negative value would make len() raise ValueError.
        return max(0, len(self.x) - self.seq_len)

    def __getitem__(self, index):
        """Return ``(window, target)`` for ``index``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if index < 0:
            index += size
        if not 0 <= index < size:
            # Without this check a negative index would silently slice a
            # misaligned (usually empty) window instead of failing.
            raise IndexError("TimeseriesDataset index out of range")
        stop = index + self.seq_len
        return (self.x[index:stop], self.y[stop])
|
||||
|
||||
|
||||
#%%
|
||||
## Copied from https://stackoverflow.com/questions/50817916/how-do-i-add-lstm-gru-or-other-recurrent-layers-to-a-sequential-in-pytorch
|
||||
class SelectItem(nn.Module):
    """Module whose forward pass returns one element of its indexable input.

    Useful inside ``nn.Sequential`` after a layer that returns a tuple, so the
    next layer receives only the selected element.

    Args:
        item_index: Index (or key) applied to the input in ``forward``.
    """

    def __init__(self, item_index):
        super().__init__()
        self.item_index = item_index
        self._name = 'selectitem'

    def forward(self, inputs):
        """Return ``inputs[item_index]``."""
        chosen = inputs[self.item_index]
        return chosen
|
||||
|
||||
#%%
|
||||
class _MergeDirections(nn.Module):
    """Flatten an LSTM state tuple's final hidden state to one vector per sample."""

    def forward(self, state):
        h_n = state[0]  # state is (h_n, c_n); only the hidden state is used
        if h_n.dim() == 2:
            # Unbatched input: (directions, hidden) -> (directions * hidden,)
            return h_n.reshape(-1)
        # (directions, batch, hidden) -> (batch, directions * hidden), so the
        # forward and backward states sit side by side for the linear head.
        return h_n.transpose(0, 1).reshape(h_n.size(1), -1)


class MultiRNNLightning(L.LightningModule):
    """Lightning module: single-layer LSTM followed by ReLU and a linear head.

    The final hidden state of the LSTM (both directions when ``args.b`` is
    set) is passed through ReLU and projected to ``output_size`` values. All
    steps use mean-squared-error loss.

    Args:
        input_size: Number of input features per time step.
        output_size: Number of regression targets.
        args: Parsed arguments providing ``hidden_size`` (int) and ``b``
            (bool, bidirectional LSTM).
        conf: Hyperparameters handed to ``save_hyperparameters``.
    """

    def __init__(self, input_size, output_size, args, conf):
        super().__init__()
        directions = 2 if args.b else 1
        # Layer positions (0 = LSTM, 4 = Linear) are kept as in the original
        # Sequential so parameter names in saved checkpoints do not change.
        self.model = nn.Sequential(
            nn.LSTM(input_size, args.hidden_size, batch_first=True, bidirectional=args.b),
            SelectItem(1),  # (output, (h_n, c_n)) -> (h_n, c_n)
            # Previously SelectItem(0), which left h_n as (directions, batch,
            # hidden) and broke the bidirectional case at the linear layer.
            _MergeDirections(),
            nn.ReLU(),
            nn.Linear(args.hidden_size * directions, output_size),
        )

        self.save_hyperparameters(conf)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with model options.

        Adds ``--hidden_size`` (int, default 100) and ``-b`` (flag enabling a
        bidirectional LSTM).
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_size', type=int, default=100)
        parser.add_argument('-b', action='store_true')
        return parser

    def forward(self, x):
        """Return predictions of shape ``(batch, output_size)`` for ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Adam (lr=1e-3) with ReduceLROnPlateau driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=5
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

    def _shared_step(self, batch, loss_name, per_unit_name):
        """Compute the MSE loss for ``batch`` and log it under the given names."""
        x, y = batch
        loss = nn.functional.mse_loss(self(x), y)
        self.log(loss_name, loss, prog_bar=True)
        # NOTE(review): mse_loss uses mean reduction by default, so this divides
        # an already-averaged loss by the batch size; kept unchanged so existing
        # logged metrics stay comparable.
        self.log(per_unit_name, loss / y.shape[0])
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss; logs ``train_loss`` and ``training_loss_per_unit``."""
        return self._shared_step(batch, 'train_loss', 'training_loss_per_unit')

    def validation_step(self, val_batch, batch_idx):
        """Return the validation loss; logs ``val_loss`` and ``val_loss_per_unit``."""
        return self._shared_step(val_batch, 'val_loss', 'val_loss_per_unit')

    def test_step(self, test_batch, batch_idx):
        """Return the test loss; logs ``test_loss`` and ``test_loss_per_unit``."""
        return self._shared_step(test_batch, 'test_loss', 'test_loss_per_unit')
|
||||
|
||||
|
||||
|
||||
|
||||
#%%
|
||||
parser = ArgumentParser()

#%%
# Fractions of the data used for validation and testing.
parser.add_argument('--val_size', type=float, default=0.2)
parser.add_argument('--test_size', type=float, default=0.2)

# Number of consecutive events in each input window.
parser.add_argument('--seq_size', type=int, default=1000)

# Directory that contains reduced_data.csv.
parser.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path("./data"))

parser.add_argument('--batch_size', type=int, default=10)

# Column names used as model inputs (-x) and regression targets (-y).
parser.add_argument('-x', type=str, nargs='+', default=['time_to_next_event','time','lat','lon','depth','mag'])
parser.add_argument('-y', type=str, nargs='+', default=['time_to_next_event','lat','lon','depth','mag'])

parser = MultiRNNLightning.add_model_specific_args(parser)
parser.add_argument('--max_epochs', type=int, default=1)

#%%
args = parser.parse_args()

#%%
datadir = args.data_dir / 'reduced_data.csv'
print(datadir.as_posix())
df = pd.read_csv(datadir.as_posix())

#%%
# Split sizes in rows; whatever is left after validation and test is training.
val_size = int(args.val_size*len(df))
test_size = int(args.test_size*len(df))
train_size = len(df)-val_size-test_size

assert val_size >= 0, "validation size is less than 0"
assert test_size >= 0, "test size is less than 0"
assert train_size > 0, "train size is non-positive"

#%%
x_columns = args.x
y_columns = args.y

x = torch.tensor(df[x_columns].values).to(torch.float)
y = torch.tensor(df[y_columns].values).to(torch.float)

# Position of the magnitude column among the targets, or -1 when it is not a
# target. y_columns is a list, so .index() is required (lists have no .find()).
mag_index = y_columns.index('mag') if 'mag' in y_columns else -1

# Contiguous split in file order: train rows first, then validation, then test.
train_x, val_x, test_x = torch.split(x, [train_size, val_size, test_size])
train_y, val_y, test_y = torch.split(y, [train_size, val_size, test_size])

train_dataset = TimeseriesDataset(train_x, train_y, seq_len=args.seq_size)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

val_dataset = TimeseriesDataset(val_x, val_y, seq_len=args.seq_size)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

test_dataset = TimeseriesDataset(test_x, test_y, seq_len=args.seq_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

#%%
input_size = len(x_columns)
output_size = len(y_columns)

#%%
# Hyperparameters recorded with the run via save_hyperparameters.
model = MultiRNNLightning(
    input_size,
    output_size,
    args,
    {
        "input_dim": input_size,
        "output_dim": output_size,
        "hidden_dim": args.hidden_size,
        "bidirectional": args.b,
        "batch_size": args.batch_size,
        "seq_len": args.seq_size,
        "input_features": args.x,
        "output_features": args.y,
    },
)
trainer = L.Trainer(max_epochs=args.max_epochs)

#%%
trainer.fit(model, train_loader, val_loader)

#%%
trainer.test(model, test_loader)
|
||||
Loading…
Reference in New Issue
Block a user