Merge pull request 'WIP backend' (#27) from backend into main
Reviewed-on: #27
This commit is contained in:
commit
971914442c
52
backend/.air.toml
Normal file
52
backend/.air.toml
Normal file
@ -0,0 +1,52 @@
|
||||
root = "."
|
||||
testdata_dir = "testdata"
|
||||
tmp_dir = "bin"
|
||||
|
||||
[build]
|
||||
args_bin = []
|
||||
bin = "./bin/main.exe"
|
||||
cmd = "go build -o ./bin/main.exe ./cmd/api"
|
||||
delay = 1000
|
||||
exclude_dir = ["assets", "bin", "vendor", "testdata", "docs", "scripts"]
|
||||
exclude_file = []
|
||||
exclude_regex = ["_test.go"]
|
||||
exclude_unchanged = false
|
||||
follow_symlink = false
|
||||
full_bin = ""
|
||||
include_dir = []
|
||||
include_ext = ["go", "tpl", "tmpl", "html"]
|
||||
include_file = []
|
||||
kill_delay = "0s"
|
||||
log = "build-errors.log"
|
||||
poll = false
|
||||
poll_interval = 0
|
||||
post_cmd = []
|
||||
pre_cmd = []
|
||||
rerun = false
|
||||
rerun_delay = 500
|
||||
send_interrupt = false
|
||||
stop_on_error = false
|
||||
|
||||
[color]
|
||||
app = ""
|
||||
build = "yellow"
|
||||
main = "magenta"
|
||||
runner = "green"
|
||||
watcher = "cyan"
|
||||
|
||||
[log]
|
||||
main_only = false
|
||||
silent = false
|
||||
time = false
|
||||
|
||||
[misc]
|
||||
clean_on_exit = false
|
||||
|
||||
[proxy]
|
||||
app_port = 0
|
||||
enabled = false
|
||||
proxy_port = 0
|
||||
|
||||
[screen]
|
||||
clear_on_rebuild = false
|
||||
keep_scroll = true
|
||||
1
backend/bin/build-errors.log
Normal file
1
backend/bin/build-errors.log
Normal file
@ -0,0 +1 @@
|
||||
exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1
|
||||
BIN
backend/bin/main.exe
Normal file
BIN
backend/bin/main.exe
Normal file
Binary file not shown.
14
backend/bucket_lambdas/minio_image_preview/main.go
Normal file
14
backend/bucket_lambdas/minio_image_preview/main.go
Normal file
@ -0,0 +1,14 @@
|
||||
package main
|
||||
|
||||
// init performs no package-level setup yet.
func init() {}

// transformImage is a placeholder for the image-resizing step; it currently
// does nothing.
func transformImage() {
	// imgtransform.ResizeImage(,10, 10)
}

// main is the (still empty) entry point of this program.
func main() {}
|
||||
1
backend/bucket_lambdas/plan.txt
Normal file
1
backend/bucket_lambdas/plan.txt
Normal file
@ -0,0 +1 @@
|
||||
Will use MinIO or S3 Object Lambda transform functions. These can be used to serve thumbnails or other resized images so that the full-size object doesn't have to be fetched from S3.
|
||||
2
backend/bucket_lambdas/s3_image_preview/main.go
Normal file
2
backend/bucket_lambdas/s3_image_preview/main.go
Normal file
@ -0,0 +1,2 @@
|
||||
package main
|
||||
|
||||
266
backend/cmd/api/api.go
Normal file
266
backend/cmd/api/api.go
Normal file
@ -0,0 +1,266 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/httprate"
|
||||
|
||||
// "git.ewellenr.ca/receipt_indexer/backend/internal/auth"
|
||||
"git.ewellenr.ca/receipt_indexer/backend/internal/env"
|
||||
"git.ewellenr.ca/receipt_indexer/backend/internal/logger"
|
||||
"git.ewellenr.ca/receipt_indexer/backend/internal/ratelimiter"
|
||||
"git.ewellenr.ca/receipt_indexer/backend/internal/storage"
|
||||
auth_storage "git.ewellenr.ca/receipt_indexer/backend/internal/storage/auth"
|
||||
"git.ewellenr.ca/receipt_indexer/backend/internal/storage/cache"
|
||||
)
|
||||
|
||||
type application struct {
|
||||
// set up the configs stuff here
|
||||
config config
|
||||
auth auth_storage.AuthStorage
|
||||
store storage.Storage
|
||||
logger logger.Logger
|
||||
cacheStorage cache.Storage
|
||||
rateLimiter ratelimiter.Limiter
|
||||
environment env.Environment
|
||||
}
|
||||
|
||||
type config struct {
|
||||
addr string
|
||||
rateLimiter ratelimiter.Config
|
||||
redisCfg redisConfig
|
||||
|
||||
// holds the different stuff like rate limiter, store, authenticator
|
||||
}
|
||||
|
||||
// redisConfig holds the connection settings for the Redis cache.
type redisConfig struct {
	addr string
	pw   string
	db   int

	// enabled switches the cache-aside lookups (getUser, getReceipt,
	// getImage) on; when false those helpers go straight to the store.
	enabled bool
}
|
||||
|
||||
func (app *application) mount() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Use(middleware.Logger)
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(middleware.CleanPath)
|
||||
r.Use(middleware.Recoverer)
|
||||
|
||||
r.Use(middleware.Throttle(100)) // temporary or removable. throttles the whole thing to 1000 concurrent requests
|
||||
// r.Use(cors.Handler(cors.Options{
|
||||
// AllowedOrigins: []string{env.GetString("CORS_ALLOWED_ORIGIN", "http://localhost:5174")},
|
||||
// AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
// AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
||||
// ExposedHeaders: []string{"Link"},
|
||||
// AllowCredentials: false,
|
||||
// MaxAge: 300, // Maximum value not ignored by any of major browsers
|
||||
// }))
|
||||
|
||||
// use either the chi built rate limiter or the custom built one
|
||||
if app.config.rateLimiter.Enabled {
|
||||
// r.Use(app.RateLimiterMiddleware)
|
||||
r.Use(httprate.LimitByRealIP(app.config.rateLimiter.RequestsPerTimeFrame, app.config.rateLimiter.TimeFrame))
|
||||
}
|
||||
// Set a timeout value on the request context (ctx), that will signal
|
||||
// through ctx.Done() that the request has timed out and further
|
||||
// processing should be stopped.
|
||||
r.Use(middleware.Timeout(60 * time.Second))
|
||||
|
||||
r.Use(middleware.Heartbeat("/ping"))
|
||||
|
||||
// v1 of api
|
||||
r.Route("/v1", func(r chi.Router) {
|
||||
|
||||
// SEPERATE STUFF FOR THE LOGIN RELATED STUFF. CONSIDER A GROUP
|
||||
|
||||
// FIX API. NEED TO ALSO CONSIDER GROUPS AND STUFF
|
||||
// Operations
|
||||
// r.Get("/health", app.healthCheckHandler)
|
||||
// r.With(app.BasicAuthMiddleware()).Get("/debug/vars", expvar.Handler().ServeHTTP)
|
||||
|
||||
// docsURL := fmt.Sprintf("%s/swagger/doc.json", app.config.addr)
|
||||
// r.Get("/swagger/*", httpSwagger.Handler(httpSwagger.URL(docsURL)))
|
||||
|
||||
// Need to sign in as a user. Then, you can see the groups you're in, your role in the groups,
|
||||
|
||||
r.Route("/user", func(r chi.Router) {
|
||||
r.Route("/{userID}", func(r chi.Router) {
|
||||
r.Use(app.AuthSessionMiddleware, app.CSRFCheckMiddleware, app.CheckUserMatchingMiddleware)
|
||||
|
||||
r.Get("/", app.getUserHandler)
|
||||
|
||||
r.Route("/groups", func(r chi.Router) {
|
||||
r.Get("/", app.getUsersGroupsHandler)
|
||||
|
||||
r.Route("/{groupID}", func(r chi.Router) {
|
||||
r.Get("/", app.getUsersGroupHandler)
|
||||
r.Delete("/", app.removeUserGroupHandler) // maybe this should expect authentication headers to reverify the password when deleting a group you own.
|
||||
|
||||
r.Put("/moderator", app.addGroupModeratorHandler)
|
||||
r.Delete("/moderator/{secondaryuserID}", app.removeModeratorPriviligesHandler)
|
||||
|
||||
r.Get("/users", app.getGroupUsersHandler)
|
||||
r.Delete("/users/{secondaryuserID}", app.removeUserFromGroupHandler)
|
||||
|
||||
r.Put("/owner", app.setGroupOwnerHandler)
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/receipts", func(r chi.Router) {
|
||||
r.Get("/", app.getReceiptsHandler)
|
||||
r.Route("/{receiptID}", func(r chi.Router) {
|
||||
r.Get("/", app.getReceiptHandler)
|
||||
r.Delete("/", app.deleteReceiptHandler)
|
||||
|
||||
r.Route("/images", func(r chi.Router) {
|
||||
r.Get("/", app.getReceiptImagesHandler)
|
||||
r.Put("/", app.addReceiptImageHandler)
|
||||
r.Route("/{imageID}", func(r chi.Router) {
|
||||
r.Get("/", app.getReceiptImageHandler)
|
||||
r.Put("/", app.changeReceiptImageHandler)
|
||||
r.Delete("/", app.deleteReceiptImageHandler)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
r.Use(app.CSRFCheckMiddleware)
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(app.AuthSessionMiddleware)
|
||||
r.Use(app.CSRFCheckMiddleware)
|
||||
|
||||
r.Route("/groups", func(r chi.Router) {
|
||||
r.Get("/", app.getGroupsHandler)
|
||||
r.Route("/{groupID}", func(r chi.Router) {
|
||||
r.Get("/", app.getGroupHandler)
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.With(app.CheckRoleMiddleware("admin")).Get("/", app.getUsersHandler)
|
||||
|
||||
r.Route("/{userID}", func(r chi.Router) {
|
||||
|
||||
r.With(app.CheckRoleMiddleware("admin")).Delete("/", app.getUserHandler)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// r.Route("/users", func(r chi.Router) {
|
||||
// // r.Put("/activate/{token}", app.activateUserHandler)
|
||||
|
||||
// // r.Get("/", app.check)
|
||||
|
||||
// r.Route("/{userID}", func(r chi.Router) {
|
||||
// r.Use(app.AuthSessionMiddleware)
|
||||
// r.Use(app.CSRFCheckMiddleware)
|
||||
// // r.Use(app.CheckUserMatchingMiddleware)
|
||||
|
||||
// r.Get("/", app.getUserHandler)
|
||||
|
||||
// r.Route("/receipts", func(r chi.Router) {
|
||||
// r.With(app.Paginate).Get("/", app.getReceiptsHandler)
|
||||
|
||||
// r.Post("/", app.createReceiptHandler)
|
||||
|
||||
// r.Route("/{receiptID}", func(r chi.Router) {
|
||||
// r.Use(app.receiptsContextMiddleware)
|
||||
|
||||
// r.Get("/", app.getReceiptHandler)
|
||||
// r.Patch("/", app.updateReceiptHandler)
|
||||
// r.Delete("/", app.checkReceiptOwnership("admin", app.deleteReceiptHandler))
|
||||
|
||||
// r.Route("/images", func(r chi.Router) {
|
||||
// r.Post("/", app.addImageHandler)
|
||||
// r.Delete("/{imageID}", app.deleteImageHandler)
|
||||
// })
|
||||
// })
|
||||
// })
|
||||
// })
|
||||
|
||||
// })
|
||||
|
||||
// // Admin page routes
|
||||
// r.Route("/admin", func(r chi.Router) {
|
||||
// r.Use(app.AuthSessionMiddleware)
|
||||
// r.Use(app.CheckRoleMiddleware("admin"))
|
||||
|
||||
// r.Route("/users", func(r chi.Router) {
|
||||
// r.Get("/", app.getUsersHandler)
|
||||
// r.Delete("/{userID}", app.deleteUserHandler)
|
||||
// })
|
||||
|
||||
// })
|
||||
|
||||
// Public routes
|
||||
r.Route("/auth", func(r chi.Router) {
|
||||
r.Post("/login", app.loginHandler)
|
||||
r.Post("/newuser", app.registerUserHandler)
|
||||
r.Post("/refreshtoken", app.refreshTokenHandler)
|
||||
r.Post("/logout", app.logoutHandler)
|
||||
})
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (app *application) run(mux http.Handler) error {
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: app.config.addr,
|
||||
Handler: mux,
|
||||
WriteTimeout: time.Second * 30,
|
||||
ReadTimeout: time.Second * 10,
|
||||
IdleTimeout: time.Minute,
|
||||
}
|
||||
|
||||
shutdown := make(chan error)
|
||||
|
||||
go func() {
|
||||
quit := make(chan os.Signal, 1)
|
||||
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
s := <-quit
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
app.logger.Info("signal caught", "signal", s.String())
|
||||
|
||||
shutdown <- srv.Shutdown(ctx)
|
||||
}()
|
||||
|
||||
app.logger.Info("server has started", "addr", app.config.addr) //, "env", app.config.env)
|
||||
|
||||
err := srv.ListenAndServe()
|
||||
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = <-shutdown
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
app.logger.Info("server has stopped", "addr", app.config.addr) //, "env", app.config.env)
|
||||
|
||||
return nil
|
||||
}
|
||||
91
backend/cmd/api/auth.go
Normal file
91
backend/cmd/api/auth.go
Normal file
@ -0,0 +1,91 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (app *application) loginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// should give them a cookie in the response
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
app.unauthorizedBasicErrorResponse(w, r, fmt.Errorf("authorization header is missing"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
// parse it -> get the base64
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Basic" {
|
||||
app.unauthorizedBasicErrorResponse(w, r, fmt.Errorf("authorization header is malformed"))
|
||||
return
|
||||
}
|
||||
|
||||
// decode it
|
||||
decoded, err := base64.StdEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
app.unauthorizedBasicErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// check the credentials
|
||||
creds := strings.SplitN(string(decoded), ":", 2)
|
||||
if len(creds) != 2 {
|
||||
app.unauthorizedBasicErrorResponse(w, r, fmt.Errorf("invalid credentials"))
|
||||
return
|
||||
}
|
||||
username, pass := creds[0], creds[1]
|
||||
|
||||
valid, user, err := app.auth.Users.SigninUser(ctx, username, pass)
|
||||
if !valid || err != nil {
|
||||
app.unauthorizedBasicErrorResponse(w, r, fmt.Errorf("invalid credentials"))
|
||||
return
|
||||
}
|
||||
|
||||
token, err := app.auth.Sessions.AddSession(ctx, user.ID)
|
||||
if err != nil {
|
||||
app.unauthorizedBasicErrorResponse(w, r, fmt.Errorf("failed to add session"))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Add("Vary", "Cookie")
|
||||
w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`)
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "Session_token",
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
// Set an expiry and path
|
||||
})
|
||||
// also need to set the csrf or cors or refresh stuff here. Probably just the refresh token stuff (actually, with sessions, there will be no refresh token. Just updating the expiry)
|
||||
|
||||
// cache and ignore if the cache fails?
|
||||
if !app.config.redisCfg.enabled {
|
||||
if err := app.cacheStorage.Users.Set(ctx, user); err != nil {
|
||||
app.internalServerError(w, r, fmt.Errorf("Failed to add user to cache"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// since it could be a different storage system, idk if it makes sense to do any caching or to just deal with the overhead of logging in. It doesn't happen that often anyways so caching shouldn't make much of a difference. Maybe load it into cache after
|
||||
|
||||
// consider using the cached user stuff instead. that user will have the hashed password which is fine to move around. And then it can just do the compare stuff using whatever auth. Maybe instead, the checkpass can take in a user
|
||||
}
|
||||
|
||||
func (app *application) logoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// should give them a cookie in the response
|
||||
}
|
||||
|
||||
func (app *application) registerUserHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
}
|
||||
|
||||
func (app *application) refreshTokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
}
|
||||
57
backend/cmd/api/errors.go
Normal file
57
backend/cmd/api/errors.go
Normal file
@ -0,0 +1,57 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (app *application) internalServerError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
app.logger.Error("internal error", "method", r.Method, "path", r.URL.Path, "error", err.Error())
|
||||
|
||||
writeJSONError(w, http.StatusInternalServerError, "the server encountered a problem")
|
||||
}
|
||||
|
||||
func (app *application) forbiddenResponse(w http.ResponseWriter, r *http.Request) {
|
||||
app.logger.Warn("forbidden", "method", r.Method, "path", r.URL.Path, "error")
|
||||
|
||||
writeJSONError(w, http.StatusForbidden, "forbidden")
|
||||
}
|
||||
|
||||
func (app *application) badRequestResponse(w http.ResponseWriter, r *http.Request, err error) {
|
||||
app.logger.Warn("bad request", "method", r.Method, "path", r.URL.Path, "error", err.Error())
|
||||
|
||||
writeJSONError(w, http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
func (app *application) conflictResponse(w http.ResponseWriter, r *http.Request, err error) {
|
||||
app.logger.Error("conflict response", "method", r.Method, "path", r.URL.Path, "error", err.Error())
|
||||
|
||||
writeJSONError(w, http.StatusConflict, err.Error())
|
||||
}
|
||||
|
||||
func (app *application) notFoundResponse(w http.ResponseWriter, r *http.Request, err error) {
|
||||
app.logger.Warn("not found error", "method", r.Method, "path", r.URL.Path, "error", err.Error())
|
||||
|
||||
writeJSONError(w, http.StatusNotFound, "not found")
|
||||
}
|
||||
|
||||
func (app *application) unauthorizedErrorResponse(w http.ResponseWriter, r *http.Request, err error) {
|
||||
app.logger.Warn("unauthorized error", "method", r.Method, "path", r.URL.Path, "error", err.Error())
|
||||
|
||||
writeJSONError(w, http.StatusUnauthorized, "unauthorized")
|
||||
}
|
||||
|
||||
func (app *application) unauthorizedBasicErrorResponse(w http.ResponseWriter, r *http.Request, err error) {
|
||||
app.logger.Warn("unauthorized basic error", "method", r.Method, "path", r.URL.Path, "error", err.Error())
|
||||
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
||||
|
||||
writeJSONError(w, http.StatusUnauthorized, "unauthorized")
|
||||
}
|
||||
|
||||
func (app *application) rateLimitExceededResponse(w http.ResponseWriter, r *http.Request, retryAfter string) {
|
||||
app.logger.Warn("rate limit exceeded", "method", r.Method, "path", r.URL.Path)
|
||||
|
||||
w.Header().Set("Retry-After", retryAfter)
|
||||
|
||||
writeJSONError(w, http.StatusTooManyRequests, "rate limit exceeded, retry after: "+retryAfter)
|
||||
}
|
||||
8
backend/cmd/api/health.go
Normal file
8
backend/cmd/api/health.go
Normal file
@ -0,0 +1,8 @@
|
||||
package main
|
||||
|
||||
import "net/http"
|
||||
|
||||
func (app *application) healthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("ok"))
|
||||
|
||||
}
|
||||
44
backend/cmd/api/images.go
Normal file
44
backend/cmd/api/images.go
Normal file
@ -0,0 +1,44 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"git.ewellenr.ca/receipt_indexer/backend/internal/storage"
|
||||
)
|
||||
|
||||
func (app *application) getImage(ctx context.Context, imageID int64) (*storage.Image, error) {
|
||||
if !app.config.redisCfg.enabled {
|
||||
return app.store.Images.GetByID(ctx, imageID)
|
||||
}
|
||||
|
||||
image, err := app.cacheStorage.ReceiptImage.Get(ctx, imageID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if image == nil {
|
||||
image, err = app.store.Images.GetByID(ctx, imageID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := app.cacheStorage.ReceiptImage.Set(ctx, image); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return image, nil
|
||||
}
|
||||
|
||||
func (app *application) addImageHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// create a new image, add it to the receipt. this should be a function because it should be one whole transaction
|
||||
|
||||
// it should be do the database transaction, attempt the upload, and then abort/commit the transaction depending on the results of the upload. While uploading directly to the s3/minio would be good, it doesn't provide the transactionality that is required
|
||||
}
|
||||
|
||||
func (app *application) deleteImageHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// delete image and remove from cache
|
||||
|
||||
// this also needs to be transactional first the database transaction stuff, then attempt the delete, and then abort/commit the transaction depending on the results of the upload.
|
||||
}
|
||||
20
backend/cmd/api/json.go
Normal file
20
backend/cmd/api/json.go
Normal file
@ -0,0 +1,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, data any) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
return json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
func writeJSONError(w http.ResponseWriter, status int, message string) error {
|
||||
type envelope struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
return writeJSON(w, status, &envelope{Error: message})
|
||||
}
|
||||
21
backend/cmd/api/main.go
Normal file
21
backend/cmd/api/main.go
Normal file
@ -0,0 +1,21 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// set up the application config here
|
||||
|
||||
cfg := config{
|
||||
addr: ":8080",
|
||||
}
|
||||
|
||||
app := application{
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
// fmt.Println(app)
|
||||
mux := app.mount()
|
||||
log.Fatal(app.run(mux))
|
||||
}
|
||||
197
backend/cmd/api/middleware.go
Normal file
197
backend/cmd/api/middleware.go
Normal file
@ -0,0 +1,197 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
auth_storage "git.ewellenr.ca/receipt_indexer/backend/internal/storage/auth"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
const sessionCtx string = "session_id"
|
||||
|
||||
func (app *application) AuthSessionMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("Session")
|
||||
if err != nil {
|
||||
app.unauthorizedErrorResponse(w, r, fmt.Errorf("Session cookie is missing"))
|
||||
return
|
||||
}
|
||||
// parse cookie
|
||||
token := cookie.Value
|
||||
if token == "" {
|
||||
app.unauthorizedErrorResponse(w, r, fmt.Errorf("Empty session token"))
|
||||
return
|
||||
}
|
||||
|
||||
valid, userID, err := app.auth.Sessions.CheckSession(r.Context(), token) // should have a different function for this
|
||||
if !valid {
|
||||
app.unauthorizedErrorResponse(w, r, fmt.Errorf("Invalid session token"))
|
||||
return
|
||||
}
|
||||
// userID comes from the session check
|
||||
|
||||
// need to select the user from the token validation stuff
|
||||
ctx := r.Context()
|
||||
|
||||
user, err := app.getUser(ctx, userID)
|
||||
if err != nil {
|
||||
app.unauthorizedErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, userCtx, user)
|
||||
|
||||
ctx = context.WithValue(ctx, sessionCtx, token)
|
||||
|
||||
// make sure to add user and role into the context here
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func (app *application) CSRFCheckMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func (app *application) Paginate(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// add the page and size to the context
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
// Also include the stuff like getting the receipt and stuff
|
||||
|
||||
func (app *application) checkRole(next http.Handler, roleName string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
user := getUserFromContext(r)
|
||||
|
||||
allowed, err := app.checkRolePrecedence(r.Context(), user, roleName)
|
||||
if err != nil {
|
||||
app.internalServerError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
app.forbiddenResponse(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// essentially a role checking middleware factory
|
||||
func (app *application) CheckRoleMiddleware(roleName string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return app.checkRole(next, roleName)
|
||||
}
|
||||
}
|
||||
|
||||
func (app *application) CheckUserMatchingMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
urlUserID, err := strconv.ParseInt(chi.URLParam(r, "userID"), 10, 64)
|
||||
if err != nil {
|
||||
app.badRequestResponse(w, r, fmt.Errorf("Invalid url user ID - Not an integer"))
|
||||
return
|
||||
}
|
||||
|
||||
user := getUserFromContext(r)
|
||||
if int64(urlUserID) != user.ID {
|
||||
app.forbiddenResponse(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (app *application) RateLimiterMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if app.config.rateLimiter.Enabled {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
app.internalServerError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if allow, retryAfter := app.rateLimiter.Allow(ip); !allow {
|
||||
app.rateLimitExceededResponse(w, r, retryAfter.String())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (app *application) receiptsContextMiddleware(next http.Handler) http.Handler {
|
||||
// add the receipt id to the context? or the receipt class to the context
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
receiptID, err := strconv.ParseInt(chi.URLParam(r, "receiptID"), 10, 64)
|
||||
if err != nil {
|
||||
app.badRequestResponse(w, r, fmt.Errorf("Invalid url receipt ID - Not an integer"))
|
||||
return
|
||||
}
|
||||
|
||||
// need to select the receipt
|
||||
ctx := r.Context()
|
||||
|
||||
receipt, err := app.getReceipt(ctx, receiptID)
|
||||
if err != nil {
|
||||
app.internalServerError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if receipt.OwnerID != getUserFromContext(r).ID {
|
||||
app.forbiddenResponse(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, receiptCtx, receipt)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (app *application) checkRolePrecedence(ctx context.Context, user *auth_storage.User, roleName string) (bool, error) {
|
||||
role, err := app.auth.Roles.GetByName(ctx, roleName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return user.Role.Level >= role.Level, nil
|
||||
}
|
||||
|
||||
func (app *application) checkReceiptOwnership(requiredRole string, next http.HandlerFunc) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := getUserFromContext(r)
|
||||
receipt := getReceiptFromContext(r)
|
||||
|
||||
if receipt.OwnerID == user.ID {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
allowed, err := app.checkRolePrecedence(r.Context(), user, requiredRole)
|
||||
if err != nil {
|
||||
app.internalServerError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
app.forbiddenResponse(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
70
backend/cmd/api/receipts.go
Normal file
70
backend/cmd/api/receipts.go
Normal file
@ -0,0 +1,70 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"git.ewellenr.ca/receipt_indexer/backend/internal/storage"
|
||||
)
|
||||
|
||||
// receiptKey is the private type of the context key used for receipts.
type receiptKey string

// receiptCtx is the request-context key under which a loaded receipt is
// stored.
const receiptCtx = receiptKey("receipt")
|
||||
|
||||
func (app *application) getReceiptsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// get the page and size from context
|
||||
// default them to something
|
||||
|
||||
}
|
||||
|
||||
func (app *application) getReceiptHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
}
|
||||
|
||||
func (app *application) createReceiptHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// handle receipt creation logic here
|
||||
|
||||
}
|
||||
|
||||
func (app *application) updateReceiptHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// handle receipt update logic here
|
||||
// Not too sure what to do here. need to break it down into what can actually be updated via the api
|
||||
|
||||
}
|
||||
|
||||
func (app *application) deleteReceiptHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// delete the receipt
|
||||
// should be as simple as getting the receipt and then calling the delete function.
|
||||
// Should also delete from cache if cache is enabled
|
||||
// the delete function should be like a transaction. It should do the transaction for the receipt and images in the db.
|
||||
// Then it should try and delete all the images in the S3/minio bucket. Abort/commit otherwise
|
||||
}
|
||||
|
||||
func (app *application) getReceipt(ctx context.Context, receiptID int64) (*storage.Receipt, error) {
|
||||
if !app.config.redisCfg.enabled {
|
||||
return app.store.Receipts.GetByID(ctx, receiptID)
|
||||
}
|
||||
|
||||
receipt, err := app.cacheStorage.Receipts.Get(ctx, receiptID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if receipt == nil {
|
||||
receipt, err = app.store.Receipts.GetByID(ctx, receiptID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := app.cacheStorage.Receipts.Set(ctx, receipt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return receipt, nil
|
||||
}
|
||||
|
||||
func getReceiptFromContext(r *http.Request) *storage.Receipt {
|
||||
receipt, _ := r.Context().Value(receiptCtx).(*storage.Receipt)
|
||||
return receipt
|
||||
}
|
||||
50
backend/cmd/api/users.go
Normal file
50
backend/cmd/api/users.go
Normal file
@ -0,0 +1,50 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
auth_storage "git.ewellenr.ca/receipt_indexer/backend/internal/storage/auth"
|
||||
)
|
||||
|
||||
// userKey is the private type of the context key used for users.
type userKey string

// userCtx is the request-context key under which the authenticated user is
// stored.
const userCtx = userKey("user")
|
||||
|
||||
func (app *application) getUserHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
}
|
||||
|
||||
func (app *application) getUsersHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
func (app *application) deleteUserHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (app *application) getUser(ctx context.Context, userID int64) (*auth_storage.User, error) {
|
||||
if !app.config.redisCfg.enabled {
|
||||
return app.auth.Users.GetByID(ctx, userID)
|
||||
}
|
||||
|
||||
user, err := app.cacheStorage.Users.Get(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
user, err = app.auth.Users.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := app.cacheStorage.Users.Set(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func getUserFromContext(r *http.Request) *auth_storage.User {
|
||||
user, _ := r.Context().Value(userCtx).(*auth_storage.User)
|
||||
return user
|
||||
}
|
||||
22
backend/go.mod
Normal file
22
backend/go.mod
Normal file
@ -0,0 +1,22 @@
|
||||
module git.ewellenr.ca/receipt_indexer/backend
|
||||
|
||||
go 1.24.2
|
||||
|
||||
// Direct dependencies: imported by packages of this module (cmd/api imports
// chi and httprate), so they must not carry the "// indirect" marker.
require (
	github.com/go-chi/chi/v5 v5.2.1
	github.com/go-chi/httprate v0.15.0
)

// NOTE(review): run `go mod tidy` to confirm which of the remaining modules
// are imported directly by the internal packages.
require (
	github.com/cespare/xxhash/v2 v2.2.0 // indirect
	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
	github.com/klauspost/cpuid/v2 v2.2.10 // indirect
	github.com/mattn/go-colorable v0.1.13 // indirect
	github.com/mattn/go-isatty v0.0.19 // indirect
	github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
	github.com/redis/go-redis/v9 v9.7.3 // indirect
	github.com/rs/zerolog v1.34.0 // indirect
	github.com/zeebo/xxh3 v1.0.2 // indirect
	go.uber.org/multierr v1.10.0 // indirect
	go.uber.org/zap v1.27.0 // indirect
	golang.org/x/crypto v0.37.0 // indirect
	golang.org/x/sys v0.32.0 // indirect
	golang.org/x/time v0.11.0 // indirect
)
|
||||
40
backend/go.sum
Normal file
40
backend/go.sum
Normal file
@ -0,0 +1,40 @@
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
|
||||
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
||||
github.com/go-chi/httprate v0.15.0 h1:j54xcWV9KGmPf/X4H32/aTH+wBlrvxL7P+SdnRqxh5g=
|
||||
github.com/go-chi/httprate v0.15.0/go.mod h1:rzGHhVrsBn3IMLYDOZQsSU4fJNWcjui4fWKJcCId1R4=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
|
||||
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
|
||||
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
|
||||
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
|
||||
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
|
||||
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
76
backend/internal/db/temp.txt
Normal file
76
backend/internal/db/temp.txt
Normal file
@ -0,0 +1,76 @@
|
||||
// Use DBML to define your database structure
|
||||
// Docs: https://dbml.dbdiagram.io/docs
|
||||
|
||||
Table users {
|
||||
id integer [primary key]
|
||||
username varchar [unique]
|
||||
email varchar [unique]
|
||||
password varchar
|
||||
created_at timestamp
|
||||
is_active bool
|
||||
role integer
|
||||
personalgroup integer
|
||||
}
|
||||
|
||||
Table groups {
|
||||
id integer [primary key]
|
||||
name varchar
|
||||
owner integer
|
||||
}
|
||||
|
||||
Table groupMembership {
|
||||
membershipid integer [primary key]
|
||||
groupid integer
|
||||
userid integer
|
||||
moderator bool
|
||||
}
|
||||
|
||||
|
||||
Table roles {
|
||||
id integer [primary key]
|
||||
name varchar
|
||||
description varchar
|
||||
level integer
|
||||
}
|
||||
|
||||
Table reciepts {
|
||||
id integer [primary key]
|
||||
groupid integer
|
||||
data nvarchar
|
||||
created_at timestamp
|
||||
updated_at timestamp
|
||||
}
|
||||
|
||||
// Table imageOwnership {
|
||||
// ownershipid integer [primary key]
|
||||
// receiptid integer
|
||||
// imageid integer
|
||||
// }
|
||||
|
||||
Table images {
|
||||
id integer [primary key]
|
||||
receiptid integer
|
||||
created_at timestamp
|
||||
path varchar
|
||||
added bool
|
||||
}
|
||||
|
||||
|
||||
|
||||
Ref: "users"."personalgroup" > "groups"."id"
|
||||
|
||||
Ref: "groups"."owner" > "users"."id"
|
||||
|
||||
Ref: "groups"."id" < "groupMembership"."groupid"
|
||||
|
||||
Ref: "groupMembership"."userid" > "users"."id"
|
||||
|
||||
Ref: "roles"."id" < "users"."role"
|
||||
|
||||
// Ref: "reciepts"."id" < "imageOwnership"."receiptid"
|
||||
|
||||
// Ref: "images"."id" < "imageOwnership"."imageid"
|
||||
|
||||
Ref: "reciepts"."id" < "images"."receiptid"
|
||||
|
||||
Ref: "groups"."id" < "reciepts"."groupid"
|
||||
18
backend/internal/env/env.go
vendored
Normal file
18
backend/internal/env/env.go
vendored
Normal file
@ -0,0 +1,18 @@
|
||||
package env
|
||||
|
||||
// Environment is the contract for a configuration source that hands out
// typed values by key.
type Environment interface {
	// Initialize prepares the source from the given path.
	// NOTE(review): what path points at is up to the implementation — confirm.
	Initialize(path string)

	// The getters take a key and a default value.
	// NOTE(review): presumably defaultValue is returned when key is absent —
	// confirm against the implementations.
	GetString(key, defaultValue string) string
	GetBool(key string, defaultValue bool) bool
	GetInt(key string, defaultValue int) int
}
|
||||
|
||||
// func GetString(key, fallback string) string {
|
||||
// value, ok := os.LookupEnv(key)
|
||||
// if !ok {
|
||||
// return fallback
|
||||
// }
|
||||
// return value
|
||||
// }
|
||||
|
||||
// repeat for bool, int, and so on
|
||||
7
backend/internal/env/environmentvars.go
vendored
Normal file
7
backend/internal/env/environmentvars.go
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
package env
|
||||
|
||||
// EnvironmentVariables is a placeholder type that currently carries no
// fields.
type EnvironmentVariables struct{}
|
||||
|
||||
// also add one for a file
|
||||
1
backend/internal/env/jsonvars.go
vendored
Normal file
1
backend/internal/env/jsonvars.go
vendored
Normal file
@ -0,0 +1 @@
|
||||
package env
|
||||
1
backend/internal/env/yamlvars.go
vendored
Normal file
1
backend/internal/env/yamlvars.go
vendored
Normal file
@ -0,0 +1 @@
|
||||
package env
|
||||
125
backend/internal/lcrypto/argon2.go
Normal file
125
backend/internal/lcrypto/argon2.go
Normal file
@ -0,0 +1,125 @@
|
||||
package lcrypto
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
// ErrInvalidHash is returned when an encoded hash string does not have the
// expected field layout.
var ErrInvalidHash = errors.New("the encoded hash is not in the correct format")

// ErrIncompatibleVersion is returned when the version recorded in an encoded
// hash differs from the argon2 version in use.
var ErrIncompatibleVersion = errors.New("incompatible argon version")

// ErrWrongHashAlgo is returned when the algorithm name recorded in an encoded
// hash is not algorithmName.
var ErrWrongHashAlgo = errors.New("wrong hash algorithm")

// algorithmName is the identifier written into, and expected in, encoded
// hash strings.
var algorithmName = "argon2id"
|
||||
|
||||
// Argon2id holds the cost parameters handed to argon2.IDKey and written
// into the "m=,t=,p=" section of an encoded hash.
type Argon2id struct {
	// Memory is passed as the memory argument of argon2.IDKey ("m=").
	Memory uint32
	// Iterations is passed as the time argument of argon2.IDKey ("t=").
	Iterations uint32
	// Parallelism is passed as the threads argument of argon2.IDKey ("p=").
	Parallelism uint8
}
|
||||
|
||||
func (a *Argon2id) getAlgoString() string {
|
||||
return fmt.Sprintf("$%s$v=%d$m=%d,t=%d,p=%d", algorithmName, argon2.Version, a.Memory, a.Iterations, a.Parallelism)
|
||||
}
|
||||
|
||||
func (a *Argon2id) decodeAlgoParams(encodedString string) (algo Argon2id, leftovers []string, err error) {
|
||||
algo = Argon2id{}
|
||||
|
||||
vals := strings.Split(encodedString, "$")
|
||||
if len(vals) > 4 {
|
||||
return algo, nil, ErrInvalidHash
|
||||
}
|
||||
|
||||
var name string
|
||||
if _, err = fmt.Sscanf(vals[1], "%s", &name); err != nil {
|
||||
return algo, nil, err
|
||||
}
|
||||
if name != algorithmName {
|
||||
return algo, nil, ErrWrongHashAlgo
|
||||
}
|
||||
|
||||
var version int
|
||||
_, err = fmt.Sscanf(vals[2], "v=%d", &version)
|
||||
if err != nil {
|
||||
return algo, nil, err
|
||||
}
|
||||
if version != argon2.Version {
|
||||
return algo, nil, ErrIncompatibleVersion
|
||||
}
|
||||
|
||||
_, err = fmt.Sscanf(vals[3], "m=%d,t=%d,p=%d", &algo.Memory, &algo.Iterations, &algo.Parallelism)
|
||||
if err != nil {
|
||||
return algo, nil, err
|
||||
}
|
||||
|
||||
return algo, vals[4:], nil
|
||||
}
|
||||
|
||||
func (a *Argon2id) EncodeHashAndSalt(hash []byte, salt []byte) (string, error) {
|
||||
// Base64 encode the salt and hashed password.
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||
|
||||
// Return a string using the standard encoded hash representation.
|
||||
encoding := fmt.Sprintf("%s$%s$%s", a.getAlgoString(), b64Hash, b64Salt)
|
||||
return encoding, nil
|
||||
}
|
||||
|
||||
func (a *Argon2id) EncodeHash(hash []byte) (string, error) {
|
||||
// Base64 encode the salt and hashed password.
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
|
||||
// Return a string using the standard encoded hash representation.
|
||||
encoding := fmt.Sprintf("%s$%s", a.getAlgoString(), b64Hash)
|
||||
return encoding, nil
|
||||
}
|
||||
|
||||
func (a *Argon2id) DecodeHashAndSalt(encodedString string) (algo Argon2id, hash []byte, salt []byte, err error) {
|
||||
|
||||
algo, values, err := a.decodeAlgoParams(encodedString)
|
||||
if err != nil {
|
||||
return algo, nil, nil, err
|
||||
}
|
||||
|
||||
if len(values) != 2 {
|
||||
return algo, nil, nil, ErrInvalidHash
|
||||
}
|
||||
|
||||
hash, err = base64.RawStdEncoding.Strict().DecodeString(values[0])
|
||||
if err != nil {
|
||||
return algo, nil, nil, err
|
||||
}
|
||||
|
||||
salt, err = base64.RawStdEncoding.Strict().DecodeString(values[1])
|
||||
if err != nil {
|
||||
return algo, nil, nil, err
|
||||
}
|
||||
|
||||
return algo, hash, salt, nil
|
||||
}
|
||||
|
||||
func (a *Argon2id) DecodeHash(encodedString string) (algo Argon2id, hash []byte, err error) {
|
||||
algo, values, err := a.decodeAlgoParams(encodedString)
|
||||
if err != nil {
|
||||
return algo, nil, err
|
||||
}
|
||||
|
||||
if len(values) != 1 {
|
||||
return algo, nil, ErrInvalidHash
|
||||
}
|
||||
|
||||
hash, err = base64.RawStdEncoding.Strict().DecodeString(values[0])
|
||||
if err != nil {
|
||||
return algo, nil, err
|
||||
}
|
||||
return algo, hash, nil
|
||||
}
|
||||
|
||||
func (a *Argon2id) HashString(in string, salt []byte, keyLen uint) ([]byte, error) {
|
||||
hash := argon2.IDKey([]byte(in), salt, a.Iterations, a.Memory, a.Parallelism, uint32(keyLen))
|
||||
return hash, nil
|
||||
}
|
||||
103
backend/internal/lcrypto/lcrypto.go
Normal file
103
backend/internal/lcrypto/lcrypto.go
Normal file
@ -0,0 +1,103 @@
|
||||
package lcrypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
)
|
||||
|
||||
type HashAlgo interface {
|
||||
EncodeHashAndSalt(hash []byte, salt []byte) (string, error)
|
||||
EncodeHash(hash []byte) (string, error)
|
||||
DecodeHashAndSalt(encodedString string) (algo Argon2id, hash []byte, salt []byte, err error)
|
||||
DecodeHash(encodedString string) (algo Argon2id, hash []byte, err error)
|
||||
HashString(in string, salt []byte, keyLen uint) ([]byte, error)
|
||||
}
|
||||
|
||||
type Hasher struct {
|
||||
Algo HashAlgo
|
||||
KeyLength uint
|
||||
SaltLength uint
|
||||
}
|
||||
|
||||
func (h *Hasher) CheckString(pass string, encoded string) (bool, error) {
|
||||
algo, truehash, salt, err := h.Algo.DecodeHashAndSalt(encoded)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
newhash, err := algo.HashString(pass, salt, h.KeyLength)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
equal, err := CompareHash(truehash, newhash)
|
||||
|
||||
return equal, err
|
||||
}
|
||||
|
||||
func (h *Hasher) CheckStringWithSalt(in string, salt []byte, encoded string) (bool, error) {
|
||||
algo, truehash, err := h.Algo.DecodeHash(encoded)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
newhash, err := algo.HashString(in, salt, h.KeyLength)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
equal, err := CompareHash(truehash, newhash)
|
||||
|
||||
return equal, err
|
||||
}
|
||||
|
||||
func (h *Hasher) HashString(in string) (string, error) {
|
||||
|
||||
salt, err := GenerateRandomBytes(h.SaltLength)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
hash, err := h.Algo.HashString(in, salt, h.KeyLength)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return h.Algo.EncodeHashAndSalt(hash, salt)
|
||||
}
|
||||
|
||||
func (h *Hasher) HashStringWithSalt(in string, salt []byte) (string, error) {
|
||||
hash, err := h.Algo.HashString(in, salt, h.KeyLength)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return h.Algo.EncodeHash(hash)
|
||||
}
|
||||
|
||||
func (h *Hasher) Hash(in string) ([]byte, error) {
|
||||
salt, err := GenerateRandomBytes(h.SaltLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h.Algo.HashString(in, salt, h.KeyLength)
|
||||
}
|
||||
|
||||
// CompareHash reports whether h1 and h2 have identical contents, using
// subtle.ConstantTimeCompare. The returned error is always nil.
func CompareHash(h1 []byte, h2 []byte) (bool, error) {
	same := subtle.ConstantTimeCompare(h1, h2) == 1
	return same, nil
}
|
||||
|
||||
// GenerateRandomBytes returns saltLen bytes read from crypto/rand, or the
// error reported by rand.Read.
func GenerateRandomBytes(saltLen uint) ([]byte, error) {
	buf := make([]byte, saltLen)
	if _, err := rand.Read(buf); err != nil {
		return nil, err
	}
	return buf, nil
}
|
||||
9
backend/internal/logger/logger.go
Normal file
9
backend/internal/logger/logger.go
Normal file
@ -0,0 +1,9 @@
|
||||
package logger
|
||||
|
||||
// Logger is the logging contract shared by the slog, zap, and zerolog
// adapters in this package. Every method takes a message followed by
// additional key/value arguments.
type Logger interface {
	Debug(msg string, keysAndValues ...interface{})
	Info(msg string, keysAndValues ...interface{})
	Warn(msg string, keysAndValues ...interface{})
	Error(msg string, keysAndValues ...interface{})
	// Fatal logs the message; the adapters in this package use calls that
	// terminate the process afterwards.
	Fatal(msg string, keysAndValues ...interface{})
}
|
||||
1
backend/internal/logger/note.txt
Normal file
1
backend/internal/logger/note.txt
Normal file
@ -0,0 +1 @@
|
||||
logger taken from https://dwarvesf.hashnode.dev/go-1-21-release-slog-with-benchmarks-zerolog-and-zap#heading-implementing-the-logger-with-zerolog-and-zap
|
||||
38
backend/internal/logger/slog.go
Normal file
38
backend/internal/logger/slog.go
Normal file
@ -0,0 +1,38 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
)
|
||||
|
||||
// SlogLogger adapts a *slog.Logger to the Logger interface.
type SlogLogger struct {
	log *slog.Logger // wrapped standard-library logger
}
|
||||
|
||||
func NewSlogLogger() Logger {
|
||||
sl := slog.New(slog.NewJSONHandler(os.Stdout, nil))
|
||||
|
||||
return &SlogLogger{log: sl}
|
||||
}
|
||||
|
||||
func (zl *SlogLogger) Debug(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Debug(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
func (zl *SlogLogger) Info(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Info(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
func (zl *SlogLogger) Warn(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Warn(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
func (zl *SlogLogger) Error(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Error(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
func (zl *SlogLogger) Fatal(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Error(msg, keysAndValues...)
|
||||
log.Fatal(msg)
|
||||
}
|
||||
35
backend/internal/logger/zap.go
Normal file
35
backend/internal/logger/zap.go
Normal file
@ -0,0 +1,35 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ZapLogger struct {
|
||||
log *zap.SugaredLogger
|
||||
}
|
||||
|
||||
func NewZapLogger() Logger {
|
||||
logger, _ := zap.NewProduction()
|
||||
sugar := logger.Sugar()
|
||||
return &ZapLogger{log: sugar}
|
||||
}
|
||||
|
||||
func (zl *ZapLogger) Debug(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Debugw(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
func (zl *ZapLogger) Info(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Infow(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
func (zl *ZapLogger) Warn(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Warnw(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
func (zl *ZapLogger) Error(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Errorw(msg, keysAndValues...)
|
||||
}
|
||||
|
||||
func (zl *ZapLogger) Fatal(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Fatalw(msg, keysAndValues...)
|
||||
}
|
||||
36
backend/internal/logger/zerolog.go
Normal file
36
backend/internal/logger/zerolog.go
Normal file
@ -0,0 +1,36 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type ZeroLogger struct {
|
||||
log zerolog.Logger
|
||||
}
|
||||
|
||||
func NewZeroLogger() Logger {
|
||||
zl := zerolog.New(os.Stdout).With().Timestamp().Logger()
|
||||
return &ZeroLogger{log: zl}
|
||||
}
|
||||
|
||||
func (zl *ZeroLogger) Debug(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Debug().Fields(keysAndValues).Msg(msg)
|
||||
}
|
||||
|
||||
func (zl *ZeroLogger) Info(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Info().Fields(keysAndValues).Msg(msg)
|
||||
}
|
||||
|
||||
func (zl *ZeroLogger) Warn(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Warn().Fields(keysAndValues).Msg(msg)
|
||||
}
|
||||
|
||||
func (zl *ZeroLogger) Error(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Error().Fields(keysAndValues).Msg(msg)
|
||||
}
|
||||
|
||||
func (zl *ZeroLogger) Fatal(msg string, keysAndValues ...interface{}) {
|
||||
zl.log.Fatal().Fields(keysAndValues).Msg(msg)
|
||||
}
|
||||
49
backend/internal/ratelimiter/fixed-window.go
Normal file
49
backend/internal/ratelimiter/fixed-window.go
Normal file
@ -0,0 +1,49 @@
|
||||
package ratelimiter
|
||||
|
||||
// Copied from https://www.youtube.com/watch?v=m5oyY9fgZPs
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FixedWindowRateLimiter struct {
|
||||
sync.RWMutex
|
||||
clients map[string]int
|
||||
limit int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
func NewFixedWindowLimiter(limit int, window time.Duration) *FixedWindowRateLimiter {
|
||||
return &FixedWindowRateLimiter{
|
||||
clients: make(map[string]int),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *FixedWindowRateLimiter) Allow(ip string) (bool, time.Duration) {
|
||||
rl.RLock()
|
||||
count, exists := rl.clients[ip]
|
||||
rl.RUnlock()
|
||||
|
||||
if !exists || count < rl.limit {
|
||||
// begin the reset count for a new window
|
||||
rl.Lock()
|
||||
if !exists {
|
||||
go rl.resetCount(ip)
|
||||
}
|
||||
rl.clients[ip]++
|
||||
rl.Unlock()
|
||||
return true, time.Duration(0)
|
||||
}
|
||||
return false, rl.window
|
||||
|
||||
}
|
||||
|
||||
func (rl *FixedWindowRateLimiter) resetCount(ip string) {
|
||||
time.Sleep(rl.window)
|
||||
rl.Lock()
|
||||
delete(rl.clients, ip)
|
||||
rl.Unlock()
|
||||
}
|
||||
13
backend/internal/ratelimiter/ratelimiter.go
Normal file
13
backend/internal/ratelimiter/ratelimiter.go
Normal file
@ -0,0 +1,13 @@
|
||||
package ratelimiter
|
||||
|
||||
import "time"
|
||||
|
||||
// Limiter decides whether a request from a client key may proceed.
type Limiter interface {
	// Allow returns whether the request is permitted and a duration.
	// NOTE(review): the duration looks like a retry-after hint (the window
	// limiters return 0 when the request is allowed) — confirm with callers.
	Allow(ip string) (bool, time.Duration)
}
|
||||
|
||||
type Config struct {
|
||||
RequestsPerTimeFrame int
|
||||
TimeFrame time.Duration
|
||||
Enabled bool
|
||||
}
|
||||
42
backend/internal/ratelimiter/sliding-window.go
Normal file
42
backend/internal/ratelimiter/sliding-window.go
Normal file
@ -0,0 +1,42 @@
|
||||
package ratelimiter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SlidingWindowRateLimiter struct {
|
||||
sync.RWMutex
|
||||
clients map[string][]time.Time
|
||||
limit int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
func NewSlidingWindowLimiter(limit int, window time.Duration) *SlidingWindowRateLimiter {
|
||||
return &SlidingWindowRateLimiter{
|
||||
clients: make(map[string][]time.Time),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *SlidingWindowRateLimiter) Allow(ip string) (bool, time.Duration) {
|
||||
rl.RLock()
|
||||
defer rl.Unlock()
|
||||
|
||||
// add new request attempt
|
||||
rl.clients[ip] = append(rl.clients[ip], time.Now())
|
||||
|
||||
// remove ones outside the window
|
||||
for len(rl.clients[ip]) > 0 && time.Since(rl.clients[ip][0]) > rl.window {
|
||||
rl.clients[ip] = rl.clients[ip][1:]
|
||||
}
|
||||
|
||||
// do actual check now
|
||||
if len(rl.clients[ip]) > rl.limit {
|
||||
// calc retry wait time
|
||||
retryAfter := rl.window - time.Since(rl.clients[ip][0])
|
||||
return false, time.Duration(retryAfter)
|
||||
}
|
||||
return true, time.Duration(0)
|
||||
}
|
||||
75
backend/internal/ratelimiter/token-bucket.go
Normal file
75
backend/internal/ratelimiter/token-bucket.go
Normal file
@ -0,0 +1,75 @@
|
||||
package ratelimiter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type TokenBucketRateLimiter struct {
|
||||
sync.RWMutex
|
||||
clients map[string]*Client
|
||||
config Config
|
||||
}
|
||||
|
||||
func NewTokenBucketRateLimiter(config Config) *TokenBucketRateLimiter {
|
||||
rl := &TokenBucketRateLimiter{
|
||||
clients: make(map[string]*Client),
|
||||
config: config,
|
||||
}
|
||||
|
||||
go rl.cleanupClients()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
Limiter *rate.Limiter
|
||||
LastSeen time.Time
|
||||
}
|
||||
|
||||
func (rl *TokenBucketRateLimiter) getClient(ip string) *Client {
|
||||
rl.RLock()
|
||||
_, exists := rl.clients[ip]
|
||||
rl.RUnlock()
|
||||
|
||||
if !exists {
|
||||
limiter := rate.NewLimiter(rate.Every(rl.config.TimeFrame), rl.config.RequestsPerTimeFrame)
|
||||
rl.Lock()
|
||||
rl.clients[ip] = &Client{
|
||||
Limiter: limiter,
|
||||
LastSeen: time.Now()}
|
||||
|
||||
rl.Unlock()
|
||||
}
|
||||
|
||||
client := rl.clients[ip]
|
||||
client.LastSeen = time.Now()
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func (rl *TokenBucketRateLimiter) Allow(ip string) (bool, float64) {
|
||||
client := rl.getClient(ip)
|
||||
|
||||
allowed := client.Limiter.Allow()
|
||||
tokens := client.Limiter.Tokens()
|
||||
|
||||
return allowed, tokens
|
||||
}
|
||||
|
||||
func (rl *TokenBucketRateLimiter) cleanupClients() {
|
||||
for {
|
||||
time.Sleep(time.Minute)
|
||||
//log cleaning up clients
|
||||
|
||||
rl.Lock()
|
||||
for ip, client := range rl.clients {
|
||||
if time.Since(client.LastSeen) > 3*time.Minute { // timeout period
|
||||
delete(rl.clients, ip)
|
||||
}
|
||||
}
|
||||
rl.Unlock()
|
||||
}
|
||||
}
|
||||
26
backend/internal/storage/cache/cache.go
vendored
Normal file
26
backend/internal/storage/cache/cache.go
vendored
Normal file
@ -0,0 +1,26 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.ewellenr.ca/receipt_indexer/backend/internal/storage"
|
||||
// auth_storage "git.ewellenr.ca/receipt_indexer/backend/internal/storage/auth"
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
Users interface {
|
||||
Get(ctx context.Context, id int64) (*storage.User, error)
|
||||
Set(ctx context.Context, user *storage.User) error
|
||||
Delete(ctx context.Context, userID int64)
|
||||
}
|
||||
Receipts interface {
|
||||
Get(ctx context.Context, id int64) (*storage.Receipt, error)
|
||||
Set(ctx context.Context, receipt *storage.Receipt) error
|
||||
Delete(ctx context.Context, id int64)
|
||||
}
|
||||
ReceiptImage interface {
|
||||
Get(ctx context.Context, id int64) (*storage.Image, error)
|
||||
Set(ctx context.Context, image *storage.Image) error
|
||||
Delete(ctx context.Context, id int64)
|
||||
}
|
||||
}
|
||||
55
backend/internal/storage/cache/redis_user.go
vendored
Normal file
55
backend/internal/storage/cache/redis_user.go
vendored
Normal file
@ -0,0 +1,55 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
// auth_storage "git.ewellenr.ca/receipt_indexer/backend/internal/storage/auth"
|
||||
"git.ewellenr.ca/receipt_indexer/backend/internal/storage"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type UserStore struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
const UserExpTime = time.Minute
|
||||
|
||||
func (s *UserStore) Get(ctx context.Context, userID int64) (*storage.User, error) {
|
||||
cacheKey := fmt.Sprintf("user-%d", userID)
|
||||
|
||||
data, err := s.rdb.Get(ctx, cacheKey).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var user storage.User
|
||||
if data != "" {
|
||||
err := json.Unmarshal([]byte(data), &user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *UserStore) Set(ctx context.Context, user *storage.User) error {
|
||||
cacheKey := fmt.Sprintf("user-%d", user.ID)
|
||||
|
||||
json, err := json.Marshal(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.rdb.Set(ctx, cacheKey, json, UserExpTime).Err()
|
||||
}
|
||||
|
||||
func (s *UserStore) Delete(ctx context.Context, userID int64) {
|
||||
cacheKey := fmt.Sprintf("user-%d", userID)
|
||||
s.rdb.Del(ctx, cacheKey)
|
||||
}
|
||||
9
backend/internal/storage/groups.go
Normal file
9
backend/internal/storage/groups.go
Normal file
@ -0,0 +1,9 @@
|
||||
package storage
|
||||
|
||||
type Group struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Users []int64 `json:"users"`
|
||||
Owner int64 `json:"owner"`
|
||||
Moderators []int64 `json:"moderators"`
|
||||
}
|
||||
8
backend/internal/storage/images.go
Normal file
8
backend/internal/storage/images.go
Normal file
@ -0,0 +1,8 @@
|
||||
package storage
|
||||
|
||||
type Image struct {
|
||||
ID int64 `json:"id"`
|
||||
ReceiptID string `json:"receipt_id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
22
backend/internal/storage/receipts.go
Normal file
22
backend/internal/storage/receipts.go
Normal file
@ -0,0 +1,22 @@
|
||||
package storage
|
||||
|
||||
import "time"
|
||||
|
||||
type Receipt struct {
|
||||
ID int64 `json:"id"`
|
||||
// Owner string `json:"username"`
|
||||
OwnerID int64 `json:"user_id"`
|
||||
ImageIDs []int64 `json:"image_ids"`
|
||||
Data ReceiptData `json:"receipt_data"`
|
||||
CreatedAt time.Time `json:"created_at`
|
||||
UpdatedAt time.Time `json:"updated_at`
|
||||
}
|
||||
|
||||
type ReceiptData struct {
|
||||
Date time.Time `json:"date"`
|
||||
Subtotal float64 `json:"subtotal"`
|
||||
Total float64 `json:"total"`
|
||||
Tax float64 `json:"tax"`
|
||||
Items map[string]float64 `json:"items"`
|
||||
// Currency string `json:"currency"`
|
||||
}
|
||||
57
backend/internal/storage/redis-csrf.go
Normal file
57
backend/internal/storage/redis-csrf.go
Normal file
@ -0,0 +1,57 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"git.ewellenr.ca/receipt_indexer/backend/internal/lcrypto"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// game plan. store the session token/id and a salt. then, hash them to create the csrf token. give this csrf token out to the user. when the user ends a session, it ends the session but also deletes it from the csrf store
|
||||
|
||||
// const csrfSaltLength = 32
|
||||
// const csrfTokenLength = 128
|
||||
|
||||
// should be set by the hasher
|
||||
|
||||
type RedisCSRFStore struct {
|
||||
rdb *redis.Client
|
||||
hasher lcrypto.Hasher
|
||||
expirationTime uint
|
||||
}
|
||||
|
||||
func (r *RedisCSRFStore) AddCSRF(ctx context.Context, sessionToken string) (csrftoken string, err error) {
|
||||
csrf, err := r.hasher.Hash(sessionToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
csrftoken = string(csrf)
|
||||
|
||||
if err = r.rdb.Set(ctx, sessionToken, csrftoken, time.Duration(r.expirationTime)).Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return csrftoken, nil
|
||||
}
|
||||
|
||||
func (r *RedisCSRFStore) RemoveCSRF(ctx context.Context, sessionToken string) error {
|
||||
r.rdb.Del(ctx, sessionToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisCSRFStore) ValidCSRF(ctx context.Context, sessionToken string, csrfToken string) (bool, error) {
|
||||
storedcsrfToken, err := r.rdb.Get(ctx, sessionToken).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err = r.rdb.ExpireXX(ctx, sessionToken, time.Duration(r.expirationTime)).Err(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
valid, err := lcrypto.CompareHash([]byte(csrfToken), []byte(storedcsrfToken))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return valid, nil
|
||||
}
|
||||
59
backend/internal/storage/redis-sessions.go
Normal file
59
backend/internal/storage/redis-sessions.go
Normal file
@ -0,0 +1,59 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"git.ewellenr.ca/receipt_indexer/backend/internal/lcrypto"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// game plan. store the session token/id and a salt. then, hash them to create the csrf token. give this csrf token out to the user. when the user ends a session, it ends the session but also deletes it from the csrf store
|
||||
|
||||
// const csrfSaltLength = 32
|
||||
// const csrfTokenLength = 128
|
||||
|
||||
// should be set by the hasher
|
||||
|
||||
type RedisSessionStore struct {
|
||||
rdb *redis.Client
|
||||
sessionTokenLength uint
|
||||
expirationTime uint
|
||||
}
|
||||
|
||||
func (r *RedisSessionStore) AddSession(ctx context.Context, userid int64) (token string, err error) {
|
||||
temp, err := lcrypto.GenerateRandomBytes(r.sessionTokenLength)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
token = string(temp)
|
||||
err = r.rdb.Set(ctx, token, userid, time.Duration(r.expirationTime)).Err()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token, err
|
||||
}
|
||||
|
||||
func (r *RedisSessionStore) GetSession(ctx context.Context, token string) (valid bool, userid int64, err error) { // should also extend it by the lifespan if near the end of the time. maybe a 5 min window at the end?
|
||||
userid, err = r.rdb.Get(ctx, token).Int64()
|
||||
if err == redis.Nil {
|
||||
valid = false
|
||||
userid = -1
|
||||
} else if err != nil {
|
||||
return false, -1, err
|
||||
} else {
|
||||
valid = true
|
||||
}
|
||||
|
||||
err = r.rdb.ExpireXX(ctx, token, time.Duration(r.expirationTime)).Err()
|
||||
if err != nil {
|
||||
return false, -1, err
|
||||
}
|
||||
|
||||
return valid, userid, err
|
||||
}
|
||||
|
||||
func (r *RedisSessionStore) RemoveSession(ctx context.Context, token string) error {
|
||||
return r.rdb.Del(ctx, token).Err()
|
||||
}
|
||||
8
backend/internal/storage/roles.go
Normal file
8
backend/internal/storage/roles.go
Normal file
@ -0,0 +1,8 @@
|
||||
package storage
|
||||
|
||||
type Role struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Level int `json:"level"`
|
||||
}
|
||||
37
backend/internal/storage/sql-groups.go
Normal file
37
backend/internal/storage/sql-groups.go
Normal file
@ -0,0 +1,37 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// SQLGroupsStore reads and writes groups through a database/sql handle.
type SQLGroupsStore struct {
|
||||
// db is the connection pool all group queries run on.
db *sql.DB
|
||||
}
|
||||
|
||||
func (s *SQLGroupsStore) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
query := `SELECT id, name, owner FROM groups WHERE id = $1`
|
||||
|
||||
group := &Group{}
|
||||
err := s.db.QueryRowContext(ctx, query, id).Scan(&group.ID, &group.Name, &group.Owner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *SQLGroupsStore) GetUserGroups(ctx context.Context, userID int64) ([]*Group, error) {
|
||||
// Implement logic to retrieve user's groups from the database
|
||||
}
|
||||
func (s *SQLGroupsStore) GetUsersInGroup(ctx context.Context, groupId int64) ([]*User, error) {
|
||||
// Implement logic to retrieve users in a group from the database
|
||||
}
|
||||
|
||||
func (s *SQLGroupsStore) Create(ctx context.Context, group *Group) error {
|
||||
// Implement logic to create a new group in the database
|
||||
}
|
||||
|
||||
func (s *SQLGroupsStore) Delete(ctx context.Context, id int64) error {
|
||||
// Implement logic to delete a group from the database
|
||||
}
|
||||
79
backend/internal/storage/sql-images.go
Normal file
79
backend/internal/storage/sql-images.go
Normal file
@ -0,0 +1,79 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// SQLImagesStore reads and writes image records through a database/sql
// handle. Its methods use value receivers.
type SQLImagesStore struct {
|
||||
// db is the connection pool all image queries run on.
db *sql.DB
|
||||
}
|
||||
|
||||
func (s SQLImagesStore) GetByID(ctx context.Context, id int64) (*Image, error) {
|
||||
query := `SELECT id, receiptid, created_at, path FROM images WHERE id = $1`
|
||||
|
||||
image := &Image{}
|
||||
err := s.db.QueryRowContext(ctx, query, id).Scan(&image.ID, &image.ReceiptID, &image.CreatedAt, &image.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return image, nil
|
||||
}
|
||||
|
||||
func (s SQLImagesStore) Create(ctx context.Context, img *Image) error {
|
||||
query := `
|
||||
INSERT INTO images (receiptid, path)
|
||||
VALUES ($1, $2) RETURNING id, created_at`
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
|
||||
defer cancel()
|
||||
|
||||
err := s.db.QueryRowContext(
|
||||
ctx,
|
||||
query,
|
||||
img.ReceiptID,
|
||||
img.Path, // might need to marshal or serialize this
|
||||
).Scan(
|
||||
&img.ID,
|
||||
&img.CreatedAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s SQLImagesStore) Delete(ctx context.Context, id int64) error {
|
||||
query := `DELETE FROM images WHERE id = $1`
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
|
||||
defer cancel()
|
||||
|
||||
res, err := s.db.ExecContext(ctx, query, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s SQLImagesStore) ActivateImage(ctx context.Context, id int64) error {
|
||||
query := `UPDATE images SET added = $1`
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
|
||||
defer cancel()
|
||||
|
||||
_, err := s.db.ExecContext(ctx, query, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
66
backend/internal/storage/sql-receipts.go
Normal file
66
backend/internal/storage/sql-receipts.go
Normal file
@ -0,0 +1,66 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// SQLReceiptsStore reads and writes receipts through a database/sql handle.
type SQLReceiptsStore struct {
|
||||
// db is the connection pool all receipt queries run on.
db *sql.DB
|
||||
}
|
||||
|
||||
func (s *SQLReceiptsStore) GetByID(ctx context.Context, id int64) (*Receipt, error) {
|
||||
query := `SELECT id, groupid, data FROM receipts WHERE id = $1`
|
||||
|
||||
receipt := &Receipt{}
|
||||
err := s.db.QueryRowContext(ctx, query, id).Scan(&receipt.ID, &receipt.OwnerID, &receipt.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return receipt, nil
|
||||
}
|
||||
|
||||
func (s *SQLReceiptsStore) Create(ctx context.Context, receipt *Receipt) error {
|
||||
query := `
|
||||
INSERT INTO receipts (groupid, data)
|
||||
VALUES ($1, $2) RETURNING id, created_at, updated_at`
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
|
||||
defer cancel()
|
||||
|
||||
err := s.db.QueryRowContext(
|
||||
ctx,
|
||||
query,
|
||||
receipt.OwnerID,
|
||||
receipt.Data, // might need to marshal or serialize this
|
||||
).Scan(
|
||||
&receipt.ID,
|
||||
&receipt.CreatedAt,
|
||||
&receipt.UpdatedAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLReceiptsStore) Delete(ctx context.Context, id int64) error {
|
||||
query := `DELETE FROM receipts WHERE id = $1`
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
|
||||
defer cancel()
|
||||
|
||||
res, err := s.db.ExecContext(ctx, query, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
39
backend/internal/storage/sql-roles.go
Normal file
39
backend/internal/storage/sql-roles.go
Normal file
@ -0,0 +1,39 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// SQLRolesStore reads roles through a database/sql handle.
type SQLRolesStore struct {
|
||||
// db is the connection pool all role queries run on.
db *sql.DB
|
||||
}
|
||||
|
||||
func (s *SQLRolesStore) GetByName(ctx context.Context, name string) (*Role, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
|
||||
defer cancel()
|
||||
|
||||
query := `SELECT id, name, description, level FROM roles WHERE name = $1`
|
||||
|
||||
role := &Role{}
|
||||
err := s.db.QueryRowContext(ctx, query, name).Scan(&role.ID, &role.Name, &role.Description, &role.Level)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return role, nil
|
||||
}
|
||||
func (s *SQLRolesStore) GetById(ctx context.Context, id int64) (*Role, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
|
||||
defer cancel()
|
||||
|
||||
query := `SELECT id, name, description, level FROM roles WHERE id = $1`
|
||||
|
||||
role := &Role{}
|
||||
err := s.db.QueryRowContext(ctx, query, id).Scan(&role.ID, &role.Name, &role.Description, &role.Level)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return role, nil
|
||||
}
|
||||
231
backend/internal/storage/sql-users.go
Normal file
231
backend/internal/storage/sql-users.go
Normal file
@ -0,0 +1,231 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SQLUsersStore reads and writes users through a database/sql handle.
type SQLUsersStore struct {
|
||||
// db is the connection pool user queries and transactions run on.
db *sql.DB
|
||||
}
|
||||
|
||||
func (s *SQLUsersStore) GetByID(ctx context.Context, id int64) (*User, error) {
|
||||
query := `SELECT users.id, users.username, users.email, users.password, users.created_at, roles.*
|
||||
FROM users
|
||||
JOIN roles ON (users.role_id = roles.id)
|
||||
WHERE users.id = $1 AND is_active = true`
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
|
||||
defer cancel()
|
||||
|
||||
user := &User{}
|
||||
err := s.db.QueryRowContext(
|
||||
ctx,
|
||||
query,
|
||||
id,
|
||||
).Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Email,
|
||||
&user.Password.hash,
|
||||
&user.CreatedAt,
|
||||
&user.Role.ID,
|
||||
&user.Role.Name,
|
||||
&user.Role.Level,
|
||||
&user.Role.Description,
|
||||
)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case sql.ErrNoRows:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *SQLUsersStore) GetByEmail(ctx context.Context, email string) (*User, error) {
|
||||
query := `SELECT users.id, users.username, users.email, users.password, users.created_at, roles.*
|
||||
FROM users
|
||||
JOIN roles ON (users.role_id = roles.id)
|
||||
WHERE users.email = $1 AND is_active = true`
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
|
||||
defer cancel()
|
||||
|
||||
user := &User{}
|
||||
err := s.db.QueryRowContext(
|
||||
ctx,
|
||||
query,
|
||||
email,
|
||||
).Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Email,
|
||||
&user.Password.hash,
|
||||
&user.CreatedAt,
|
||||
&user.Role.ID,
|
||||
&user.Role.Name,
|
||||
&user.Role.Level,
|
||||
&user.Role.Description,
|
||||
)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case sql.ErrNoRows:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *SQLUsersStore) GetByUsername(ctx context.Context, username string) (*User, error) {
|
||||
query := `SELECT users.id, users.username, users.email, users.password, users.created_at, roles.*
|
||||
FROM users
|
||||
JOIN roles ON (users.role_id = roles.id)
|
||||
WHERE users.username = $1 AND is_active = true`
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
|
||||
defer cancel()
|
||||
|
||||
user := &User{}
|
||||
err := s.db.QueryRowContext(
|
||||
ctx,
|
||||
query,
|
||||
username,
|
||||
).Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Email,
|
||||
&user.Password.hash,
|
||||
&user.CreatedAt,
|
||||
&user.Role.ID,
|
||||
&user.Role.Name,
|
||||
&user.Role.Level,
|
||||
&user.Role.Description,
|
||||
)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case sql.ErrNoRows:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *SQLUsersStore) create(ctx context.Context, user *User, tx *sql.Tx) error { // creates the personal group and binds them
|
||||
ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
|
||||
defer cancel()
|
||||
|
||||
query := `INSERT INTO groups (name, owner) VALUES ($1, $2) RETURNING id`
|
||||
|
||||
err := tx.QueryRowContext(ctx, query, fmt.Sprintf("User-%d-Personal-Group", user.ID), user.ID).Scan(&user.PersonalGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query = `INSERT INTO users (username, password, email, role_id, personalgroup) VALUES
|
||||
($1, $2, $3, (SELECT id FROM roles WHERE name = $4), $5)
|
||||
RETURNING id, created_at`
|
||||
|
||||
role := user.Role.Name
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
|
||||
err = tx.QueryRowContext(
|
||||
ctx,
|
||||
query,
|
||||
user.Username,
|
||||
user.Password,
|
||||
user.Email,
|
||||
role,
|
||||
user.PersonalGroup,
|
||||
).Scan(
|
||||
&user.ID,
|
||||
&user.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
switch {
|
||||
case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`:
|
||||
return ErrDuplicateEmail
|
||||
case err.Error() == `pq: duplicate key value violates unique constraint "users_username_key"`:
|
||||
return ErrDuplicateUsername
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
query = `INSERT INTO groupMembership (groupid, userid, moderator) VALUES ($1, $2, $3)`
|
||||
|
||||
_, err = tx.ExecContext(ctx, query, user.PersonalGroup, user.ID, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLUsersStore) Create(ctx context.Context, user *User) error { // create a non-exported create function which does take in the tx
|
||||
return withTx(s.db, ctx, func(tx *sql.Tx) error {
|
||||
return s.create(ctx, user, tx)
|
||||
})
|
||||
}
|
||||
|
||||
// CreateAndInvite is an unimplemented placeholder: it ignores all of its
// arguments, touches nothing and always returns nil.
// NOTE(review): callers currently get a success result even though no user
// or invitation is created.
func (s *SQLUsersStore) CreateAndInvite(ctx context.Context, user *User, token string, exp time.Duration) error { // figure this out
|
||||
return nil
|
||||
}
|
||||
|
||||
// Activate is an unimplemented placeholder: it ignores both arguments and
// always returns nil.
// NOTE(review): the meaning of the string argument is not established
// anywhere in this file — confirm before implementing.
func (s *SQLUsersStore) Activate(context.Context, string) error { // what does this do?
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLUsersStore) delete(ctx context.Context, id int64, tx *sql.Tx) error {
|
||||
query := `DELETE FROM users WHERE id = $1`
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
|
||||
defer cancel()
|
||||
|
||||
res, err := s.db.ExecContext(ctx, query, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLUsersStore) Delete(ctx context.Context, id int64) error {
|
||||
return withTx(s.db, ctx, func(tx *sql.Tx) error {
|
||||
return s.delete(ctx, id, tx)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUserPass is an unimplemented placeholder: it ignores all of its
// arguments, changes nothing and always returns nil.
// NOTE(review): callers currently get a success result even though no
// password is updated.
func (s *SQLUsersStore) UpdateUserPass(ctx context.Context, user string, oldPassword string, newPass string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPass is an unimplemented placeholder: it ignores its arguments and
// always returns (false, nil), i.e. every password check is reported as a
// mismatch.
func (s *SQLUsersStore) CheckPass(ctx context.Context, name string, pass string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// SigninUser is an unimplemented placeholder: it ignores its arguments and
// always returns (false, nil, nil), i.e. every sign-in attempt is reported as
// unsuccessful with no user and no error.
func (s *SQLUsersStore) SigninUser(ctx context.Context, name string, pass string) (bool, *User, error) {
|
||||
return false, nil, nil
|
||||
}
|
||||
25
backend/internal/storage/sql.go
Normal file
25
backend/internal/storage/sql.go
Normal file
@ -0,0 +1,25 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
QueryTimeoutDuration = time.Second * 5
|
||||
)
|
||||
|
||||
// withTx runs fn inside a database transaction.
//
// A transaction is begun on db with ctx and default options. If fn returns an
// error the transaction is rolled back and that error is returned; otherwise
// the transaction is committed and the commit error (if any) is returned.
//
// The rollback is deferred so the transaction is also released when fn
// panics. After a successful Commit the deferred Rollback is a harmless
// no-op, so its result is deliberately discarded.
func withTx(db *sql.DB, ctx context.Context, fn func(*sql.Tx) error) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer func() { _ = tx.Rollback() }()

	if err := fn(tx); err != nil {
		return err
	}

	return tx.Commit()
}
|
||||
80
backend/internal/storage/storage.go
Normal file
80
backend/internal/storage/storage.go
Normal file
@ -0,0 +1,80 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotFound = errors.New("resource not found")
|
||||
ErrConflict = errors.New("resource already exists")
|
||||
ErrDuplicateEmail = errors.New("a user with that email already exists")
|
||||
ErrDuplicateUsername = errors.New("a user with that username already exists")
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
Users interface { // store user id, username, password(hashed+salted), role?
|
||||
GetByID(ctx context.Context, id int64) (*User, error)
|
||||
GetByEmail(context.Context, string) (*User, error)
|
||||
GetByUsername(context.Context, string) (*User, error)
|
||||
Create(context.Context, *User) error // create a non-exported create function which does take in the tx
|
||||
CreateAndInvite(ctx context.Context, user *User, token string, exp time.Duration) error // figure this out
|
||||
Activate(context.Context, string) error // what does this do?
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
UpdateUserPass(ctx context.Context, user string, oldPassword string, newPass string) error
|
||||
|
||||
// CheckPass(ctx context.Context, name string, pass string) (bool, error)
|
||||
// SigninUser(ctx context.Context, name string, pass string) (bool, *User, error)
|
||||
// ValidCredentials(ctx context.Context, user *User, pass string) (bool, error)
|
||||
}
|
||||
|
||||
Sessions interface { // store just session tokens, and their corresponding user id
|
||||
AddSession(ctx context.Context, userid int64) (token string, err error)
|
||||
GetSession(ctx context.Context, token string) (valid bool, userid int64, err error) // extends it's expiry
|
||||
RemoveSession(ctx context.Context, token string) error
|
||||
// SetLifespan(ctx context.Context, token string, lf time.Time) error
|
||||
}
|
||||
|
||||
CSRF interface {
|
||||
AddCSRF(ctx context.Context, sessionToken string) (csrftoken string, err error)
|
||||
CheckCSRF(ctx context.Context, sessionToken string, csrfToken string) (bool, error)
|
||||
RemoveCSRF(ctx context.Context, sessionToken string) error
|
||||
// CleanupCSRF()
|
||||
}
|
||||
|
||||
Roles interface {
|
||||
GetByName(context.Context, string) (*Role, error)
|
||||
GetById(context.Context, int64) (*Role, error)
|
||||
}
|
||||
|
||||
Receipts interface {
|
||||
GetByID(ctx context.Context, id int64) (*Receipt, error)
|
||||
Create(ctx context.Context, receipt *Receipt) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
}
|
||||
Images interface {
|
||||
GetByID(context.Context, int64) (*Image, error)
|
||||
Create(context.Context, *Image) error
|
||||
Delete(context.Context, int64) error
|
||||
ActivateImage(context.Context, int64) error // may need to change this. Consider finishing https://www.youtube.com/watch?v=pmEmQcd9_KA
|
||||
|
||||
// Can have it so that the client gets a presigned url
|
||||
// For creation, do a delayed check in 10 minutes to clean up the image from AWS or check
|
||||
}
|
||||
Groups interface {
|
||||
GetByID(context.Context, int64) (*Group, error)
|
||||
GetUserGroups(context.Context, int64) ([]*Group, error)
|
||||
GetUsersInGroup(context.Context, int64) ([]*User, error)
|
||||
Create(context.Context, *Group) error
|
||||
Delete(context.Context, int64) error
|
||||
}
|
||||
}
|
||||
|
||||
func NewSQLRedisMinIOStorage(db *sql.DB) Storage {
|
||||
return Storage{
|
||||
Users: &SQLUsersStore{db},
|
||||
}
|
||||
}
|
||||
29
backend/internal/storage/users.go
Normal file
29
backend/internal/storage/users.go
Normal file
@ -0,0 +1,29 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrExistingUser = errors.New("user already exists")
|
||||
ErrPasswordIncorrect = errors.New("password incorrect")
|
||||
)
|
||||
|
||||
// User is an application user as stored in the users table and exposed
// through the JSON API.
type User struct {
|
||||
// ID is the user's primary key.
ID int64 `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
// Password is the stored password value; the "-" tag keeps it out of JSON.
Password string `json:"-"`
|
||||
// CreatedAt holds the created_at column as a string.
CreatedAt string `json:"created_at"`
|
||||
IsActive bool `json:"is_active"`
|
||||
// Role is the role joined in from the roles table.
Role Role `json:"role"`
|
||||
// PersonalGroup is the id of the user's own group; serialized as "user_group".
PersonalGroup int64 `json:"user_group"`
|
||||
// Groups lists group ids.
// NOTE(review): presumably the groups the user is a member of — confirm.
Groups []int64 `json:"groups"`
|
||||
}
|
||||
|
||||
// type Password struct {
|
||||
// text *string
|
||||
// hash []byte
|
||||
// encoded *string
|
||||
// }
|
||||
17
backend/shared/image_transform/imgtransform.go
Normal file
17
backend/shared/image_transform/imgtransform.go
Normal file
@ -0,0 +1,17 @@
|
||||
package imgtransform
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
|
||||
"github.com/nfnt/resize"
|
||||
)
|
||||
|
||||
// BytesToImage decodes an encoded image held in memory.
//
// It returns whatever image.Decode returns for the data; the detected format
// name is discarded. Decoding only succeeds for formats whose decoders have
// been registered with the image package by the importing program.
func BytesToImage(data []byte) (image.Image, error) {
	reader := bytes.NewReader(data)
	decoded, _, decodeErr := image.Decode(reader)
	return decoded, decodeErr
}
|
||||
|
||||
// ResizeImage returns img resized to width x height using the Lanczos3
// interpolation of github.com/nfnt/resize.
// NOTE(review): in nfnt/resize a width or height of 0 is documented as
// "preserve aspect ratio" — confirm callers rely on that.
func ResizeImage(img image.Image, width uint, height uint) image.Image {
|
||||
return resize.Resize(width, height, img, resize.Lanczos3)
|
||||
}
|
||||
17
code/autocropper/.vscode/c_cpp_properties.json
vendored
17
code/autocropper/.vscode/c_cpp_properties.json
vendored
@ -1,17 +0,0 @@
|
||||
{
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Linux",
|
||||
"includePath": [
|
||||
"${workspaceFolder}/**"
|
||||
],
|
||||
"defines": [],
|
||||
"compilerPath": "/usr/bin/gcc",
|
||||
// "cStandard": "c17",
|
||||
// "cppStandard": "gnu++17",
|
||||
"intelliSenseMode": "linux-gcc-x64",
|
||||
"configurationProvider": "ms-vscode.cmake-tools"
|
||||
}
|
||||
],
|
||||
"version": 4
|
||||
}
|
||||
65
code/autocropper/.vscode/settings.json
vendored
65
code/autocropper/.vscode/settings.json
vendored
@ -1,65 +0,0 @@
|
||||
{
|
||||
"C_Cpp.errorSquiggles": "enabled",
|
||||
"files.associations": {
|
||||
"array": "cpp",
|
||||
"atomic": "cpp",
|
||||
"bit": "cpp",
|
||||
"*.tcc": "cpp",
|
||||
"cctype": "cpp",
|
||||
"chrono": "cpp",
|
||||
"clocale": "cpp",
|
||||
"cmath": "cpp",
|
||||
"compare": "cpp",
|
||||
"complex": "cpp",
|
||||
"concepts": "cpp",
|
||||
"condition_variable": "cpp",
|
||||
"cstdarg": "cpp",
|
||||
"cstddef": "cpp",
|
||||
"cstdint": "cpp",
|
||||
"cstdio": "cpp",
|
||||
"cstdlib": "cpp",
|
||||
"cstring": "cpp",
|
||||
"ctime": "cpp",
|
||||
"cwchar": "cpp",
|
||||
"cwctype": "cpp",
|
||||
"deque": "cpp",
|
||||
"list": "cpp",
|
||||
"map": "cpp",
|
||||
"set": "cpp",
|
||||
"string": "cpp",
|
||||
"unordered_map": "cpp",
|
||||
"vector": "cpp",
|
||||
"exception": "cpp",
|
||||
"algorithm": "cpp",
|
||||
"functional": "cpp",
|
||||
"iterator": "cpp",
|
||||
"memory": "cpp",
|
||||
"memory_resource": "cpp",
|
||||
"numeric": "cpp",
|
||||
"random": "cpp",
|
||||
"ratio": "cpp",
|
||||
"string_view": "cpp",
|
||||
"system_error": "cpp",
|
||||
"tuple": "cpp",
|
||||
"type_traits": "cpp",
|
||||
"utility": "cpp",
|
||||
"fstream": "cpp",
|
||||
"initializer_list": "cpp",
|
||||
"iomanip": "cpp",
|
||||
"iosfwd": "cpp",
|
||||
"iostream": "cpp",
|
||||
"istream": "cpp",
|
||||
"limits": "cpp",
|
||||
"mutex": "cpp",
|
||||
"new": "cpp",
|
||||
"numbers": "cpp",
|
||||
"ostream": "cpp",
|
||||
"semaphore": "cpp",
|
||||
"sstream": "cpp",
|
||||
"stdexcept": "cpp",
|
||||
"stop_token": "cpp",
|
||||
"streambuf": "cpp",
|
||||
"thread": "cpp",
|
||||
"typeinfo": "cpp"
|
||||
},
|
||||
}
|
||||
@ -1,24 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.22)

project(autocropper
  VERSION 0.1
  DESCRIPTION "Autocrops Receipt Pictures"
  LANGUAGES CXX)

# Locate OpenCV up front so its variables are defined before any target
# consumes them.
find_package(OpenCV REQUIRED)

# Sources under src/ are collected by glob. CONFIGURE_DEPENDS makes the build
# re-check the glob, so newly added files are picked up without a manual
# re-configure.
file(GLOB_RECURSE SOURCE_FILES CONFIGURE_DEPENDS
  "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")

add_executable(CropperEx main.cpp ${SOURCE_FILES})

target_compile_features(CropperEx PRIVATE cxx_std_20)

target_include_directories(CropperEx
  PRIVATE
    "${CMAKE_CURRENT_SOURCE_DIR}/include"
    "${CMAKE_CURRENT_SOURCE_DIR}/../externallibraries/stbimagehelpers"
    ${OpenCV_INCLUDE_DIRS})

# Explicit PRIVATE: CropperEx is an executable, nothing links against it.
target_link_libraries(CropperEx PRIVATE ${OpenCV_LIBS})
|
||||
@ -1,721 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
|
||||
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n",
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
|
||||
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import cv2\n",
|
||||
"import myfunctions as mf\n",
|
||||
"import numpy as np\n",
|
||||
"import math\n",
|
||||
"import scipy.stats as st"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import pathlib\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"def removeextensionandnumeric(filename):\n",
|
||||
" suffix = pathlib.Path(filename).suffix\n",
|
||||
" num = filename[:-len(suffix)]\n",
|
||||
" numint = int(num)\n",
|
||||
" return numint\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"def testondataset(pathtodataset, function):\n",
|
||||
" imagefileextensions = [\".jpg\", \".png\"]\n",
|
||||
" filenames = next(os.walk(pathtodataset), (None, None, []))[2]\n",
|
||||
" \n",
|
||||
" filenames.sort(key=removeextensionandnumeric)\n",
|
||||
" # print(filenames)\n",
|
||||
" outs = []\n",
|
||||
" tdiffs = []\n",
|
||||
" for filename in filenames:\n",
|
||||
" suffix = pathlib.Path(filename).suffix\n",
|
||||
" if (suffix not in imagefileextensions):\n",
|
||||
" print(\"Not a valid image \"+filename)\n",
|
||||
" continue\n",
|
||||
" img = cv2.imread(pathtodataset+filename)\n",
|
||||
" t1 = time.time()\n",
|
||||
" outs.append(function(img))\n",
|
||||
" tdiffs.append(time.time() - t1)\n",
|
||||
" tdiffs = np.array(tdiffs)\n",
|
||||
" print(\"average time: \" + str(np.mean(tdiffs))+\"(s)\")\n",
|
||||
" return outs\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def showimgs(imgs):\n",
|
||||
" if (isinstance(imgs, np.ndarray)):\n",
|
||||
" if (imgs.shape[0] > imgs.shape[1]):\n",
|
||||
" cv2.imshow(\"test\", mf.ResizeWithAspectRatio(imgs, height=1350))\n",
|
||||
" else:\n",
|
||||
" cv2.imshow(\"test\", mf.ResizeWithAspectRatio(imgs, width=1000))\n",
|
||||
" else:\n",
|
||||
" for i, out in enumerate(imgs):\n",
|
||||
" if (out.shape[0] > out.shape[1]):\n",
|
||||
" cv2.imshow(\"test\"+str(i), mf.ResizeWithAspectRatio(out, height=1350))\n",
|
||||
" else:\n",
|
||||
" cv2.imshow(\"test\"+str(i), mf.ResizeWithAspectRatio(out, width=1000))\n",
|
||||
" cv2.waitKey(0)\n",
|
||||
" cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def writeimgs(directorypath, imgs):\n",
|
||||
" if (isinstance(imgs, np.ndarray)):\n",
|
||||
" cv2.imwrite(directorypath+\"test.png\", imgs)\n",
|
||||
" else:\n",
|
||||
" for i, out in enumerate(imgs):\n",
|
||||
" cv2.imwrite(directorypath+\"test\"+str(i)+\".png\", out)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"img = cv2.imread('/mnt/dataset/baseimages/12.jpg')\n",
|
||||
"# img = cv2.imread('/mnt/code/autocropper/test_images/IMG_7605.jpg')\n",
|
||||
"testall = False"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"## NEED TO FIX THE EARLIER PARTS SO THAT IT DOESN'T HAVE THOSE BLACK SECTIONS AFTER THE ROTATION\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def whiteoutbackground(image):\n",
|
||||
" ogshape = image.shape\n",
|
||||
" shrunkdim=1000\n",
|
||||
" if (image.shape[1] > image.shape[0]):\n",
|
||||
" shrunkimg, scaler = mf.ResizeWithAspectRatio(image, width=shrunkdim, retscale=True)\n",
|
||||
" else:\n",
|
||||
" shrunkimg, scaler = mf.ResizeWithAspectRatio(image, height=shrunkdim, retscale=True)\n",
|
||||
" \n",
|
||||
" mainimage = shrunkimg\n",
|
||||
" \n",
|
||||
" sdim = int(min(mainimage.shape[0], mainimage.shape[1])/5)\n",
|
||||
" srkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (sdim, sdim))\n",
|
||||
" skernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (sdim, sdim))\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" lab = cv2.cvtColor(mainimage, cv2.COLOR_BGR2LAB)\n",
|
||||
" \n",
|
||||
" imglist = []\n",
|
||||
" # imglist.append(mainimage)\n",
|
||||
" \n",
|
||||
" labl = lab[:,:,0]\n",
|
||||
" # imglist.append(labl)\n",
|
||||
" # imglist.append(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))\n",
|
||||
" laba = lab[:,:,1]\n",
|
||||
" # imglist.append(laba)\n",
|
||||
" labb = lab[:,:,2]\n",
|
||||
" # imglist.append(labb)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # canny = cv2.Canny(labl, 0, 500)\n",
|
||||
" threshl = cv2.threshold(labl, 0, 255, cv2.THRESH_OTSU)[1]\n",
|
||||
" # return threshl\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" dim = int(min(mainimage.shape[0], mainimage.shape[1])/100)\n",
|
||||
" # dim = 2\n",
|
||||
" # dim = dotsize\n",
|
||||
" kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (dim, dim))\n",
|
||||
" kernelell = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dim, dim))\n",
|
||||
" \n",
|
||||
" paddedl = mf.padWithColour(threshl, sdim*2, sdim*2, fill=0)\n",
|
||||
" # return paddedl\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # morphedl = 255-cv2.morphologyEx(255-threshl, cv2.MORPH_OPEN, kernel, iterations=3)\n",
|
||||
" morphedl = paddedl\n",
|
||||
" # morphedl = cv2.morphologyEx(morphedl, cv2.MORPH_ERODE, kernel, iterations=1)\n",
|
||||
" morphed1l = cv2.morphologyEx(morphedl, cv2.MORPH_ERODE, kernelell, iterations=1)\n",
|
||||
"\n",
|
||||
" # return morphedl\n",
|
||||
" \n",
|
||||
" contours, heirarchy = cv2.findContours(morphed1l, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
|
||||
" biggestcontour = max(contours, key=cv2.contourArea)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" blank = np.full(labl.shape, 255, dtype=np.uint8)\n",
|
||||
" mask1 = blank.copy()\n",
|
||||
" mask1 = mf.padWithColour(mask1, sdim*2, sdim*2, fill=255)\n",
|
||||
" mask1 = cv2.drawContours(mask1, [biggestcontour], -1, 0, thickness=cv2.FILLED)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" mask1 = cv2.morphologyEx(mask1, cv2.MORPH_DILATE, kernelell, iterations=2)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # mask1 = mask1[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
|
||||
" # return mask1\n",
|
||||
" \n",
|
||||
" # morphed2l = mf.padWithColour(morphedl, sdim*2, sdim*2, fill=255)\n",
|
||||
" morphed2l = cv2.morphologyEx(morphedl, cv2.MORPH_OPEN, kernel, iterations=1)\n",
|
||||
" # morphed2l = morphed2l[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
|
||||
" \n",
|
||||
" # return morphed2l\n",
|
||||
" # print(mask1.shape)\n",
|
||||
" # print(morphed2l.shape)\n",
|
||||
" morphed2l = cv2.bitwise_or(morphed2l, 255-mask1)\n",
|
||||
" # return morphed2l\n",
|
||||
" \n",
|
||||
" morphed2l = morphed2l[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
|
||||
" temp_final = cv2.bitwise_or(threshl, 255-morphed2l)\n",
|
||||
" return temp_final\n",
|
||||
" \n",
|
||||
" canny = cv2.Canny(morphed2l, 0, 500)\n",
|
||||
" # return canny\n",
|
||||
"\n",
|
||||
" vminlength = mainimage.shape[0]//10\n",
|
||||
" vmaxgap = mainimage.shape[0]//50\n",
|
||||
" vlinesP = cv2.HoughLinesP(canny, 1, np.pi / 180, 10, None, vminlength, vmaxgap)\n",
|
||||
" \n",
|
||||
" hminlength = mainimage.shape[1]//15\n",
|
||||
" hmaxgap = mainimage.shape[1]//40\n",
|
||||
" hlinesP = cv2.HoughLinesP(canny, 1, np.pi / 180, 10, None, hminlength, hmaxgap)\n",
|
||||
" # print(linesP)\n",
|
||||
" \n",
|
||||
" vmarginlines = mf.WithinXDegrees(vlinesP, 15)\n",
|
||||
" hmarginlines = mf.WithinXDegrees(hlinesP, 15, baseangle=90)\n",
|
||||
" \n",
|
||||
" marginlines = np.append(vmarginlines, hmarginlines, axis=0)\n",
|
||||
" # marginlines = marginlines.astype(int)\n",
|
||||
" # # print(marginlines)\n",
|
||||
" # reshaped = np.reshape(marginlines, (-1,1, 2))\n",
|
||||
" # # reshaped = cv2.convexHull(reshaped)\n",
|
||||
" # # print(reshaped)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" \n",
|
||||
" colourdst = cv2.cvtColor(morphedl, cv2.COLOR_GRAY2BGR)\n",
|
||||
" # out = cv2.drawContours(colourdst, [reshaped], -1, (0,255,0), thickness=3)\n",
|
||||
" # return out\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" #### NEW IDEA: MERGE THE WHITEOUT BACKGROUND AND TEXT CLARIFICATION STEP BECAUSE DOING THE OTSU THRESHOLD SEEMS TO WORK PRETTY WELL AND IF I JUST WHITE OUT THE OUTER AREA (ACTUALLY WHITE)\n",
|
||||
" # THEN I HAVE JUST THE TEXT\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" if marginlines is not None:\n",
|
||||
" for l in marginlines:\n",
|
||||
" cv2.line(colourdst, (int(l[0]), int(l[1])), (int(l[2]), int(l[3])), (0,0,255), 3, cv2.LINE_AA)\n",
|
||||
" return colourdst\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" ## IDEA:\n",
|
||||
" # MASK OUT THE WORDS USING THE MASKS BUILT BY THE CODE BELOW. THEN, WHEN CANNY IS RUN ON THE RESULT, THERE SHOULDN'T BE A LOT OF EDGE CLUTTER IN THE CENTER. STILL NEED TO FIGURE OUT HOW TO LINK THE HOUGH LINES AROUND THE RECEIPT.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # morphedl = 255-cv2.morphologyEx(255-threshl, cv2.MORPH_OPEN, kernel, iterations=3)\n",
|
||||
" morphedl = paddedl\n",
|
||||
" morphedl = cv2.morphologyEx(morphedl, cv2.MORPH_ERODE, kernel, iterations=1)\n",
|
||||
" morphedl = cv2.morphologyEx(morphedl, cv2.MORPH_ERODE, kernelell, iterations=1)\n",
|
||||
"\n",
|
||||
" # return morphedl\n",
|
||||
" \n",
|
||||
" contours, heirarchy = cv2.findContours(morphedl, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
|
||||
" # print(contours[0].shape)\n",
|
||||
" print(contours[0])\n",
|
||||
" biggestcontour = max(contours, key=cv2.contourArea)\n",
|
||||
" return canny\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" blank = np.full(labl.shape, 255, dtype=np.uint8)\n",
|
||||
" mask1 = blank.copy()\n",
|
||||
" mask1 = mf.padWithColour(mask1, sdim*2, sdim*2, fill=255)\n",
|
||||
" mask1 = cv2.drawContours(mask1, [biggestcontour], -1, 0, thickness=cv2.FILLED)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" mask1 = mask1[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # resizemask = cv2.resize(mask1, (ogshape[1], ogshape[0]))\n",
|
||||
" # return resizemask\n",
|
||||
" maskc = cv2.cvtColor(mask1, cv2.COLOR_GRAY2BGR)\n",
|
||||
" # print(maskc.shape)\n",
|
||||
" # print(image.shape)\n",
|
||||
" whitedbackground = cv2.bitwise_or(mainimage, maskc)\n",
|
||||
" # return whitedbackground\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" lab2 = cv2.cvtColor(whitedbackground, cv2.COLOR_BGR2LAB)\n",
|
||||
" \n",
|
||||
" lab2l = lab2[:,:,0]\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" otsu2 = cv2.threshold(lab2l, 0, 255, cv2.THRESH_OTSU)[1]\n",
|
||||
" \n",
|
||||
" expandedmask1 = cv2.morphologyEx(mask1, cv2.MORPH_DILATE, kernel, iterations=1)\n",
|
||||
" expandedmask1 = cv2.morphologyEx(expandedmask1, cv2.MORPH_DILATE, kernelell, iterations=1)\n",
|
||||
" # return expandedmask1\n",
|
||||
" \n",
|
||||
" maskmerge = cv2.bitwise_and(otsu2, 255-expandedmask1)\n",
|
||||
" return mask1\n",
|
||||
" return maskmerge\n",
|
||||
" \n",
|
||||
" # return otsu2\n",
|
||||
" \n",
|
||||
" mpad = mf.padWithColour(maskmerge, sdim*2, sdim*2, fill=0)\n",
|
||||
" return mpad\n",
|
||||
" \n",
|
||||
" #MORPHOLOGIES \n",
|
||||
" morphed2 = cv2.morphologyEx(mpad, cv2.MORPH_ERODE, kernel, iterations=1)\n",
|
||||
" morphed2 = cv2.morphologyEx(morphed2, cv2.MORPH_ERODE, kernelell, iterations=1)\n",
|
||||
" return morphed2\n",
|
||||
" \n",
|
||||
" contours, heirarchy = cv2.findContours(morphed2, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
|
||||
" biggestcontour = max(contours, key=cv2.contourArea)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" mask2 = blank.copy()\n",
|
||||
" mask2 = mf.padWithColour(mask2, sdim*2, sdim*2, fill=255)\n",
|
||||
" mask2 = cv2.drawContours(mask2, [biggestcontour], -1, 0, thickness=cv2.FILLED)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" mask2 = mask2[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
|
||||
" \n",
|
||||
" return mask2\n",
|
||||
" \n",
|
||||
" test = cv2.inpaint(whitedbackground, resizemask, 3, cv2.INPAINT_TELEA)\n",
|
||||
" \n",
|
||||
" return test\n",
|
||||
" \n",
|
||||
" contours, heirarchy = cv2.findContours(255-labl, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)\n",
|
||||
" \n",
|
||||
" imgout = cv2.drawContours(mainimage, contours, -1, (0,255,0), thickness=3)\n",
|
||||
" return imgout\n",
|
||||
" \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def textleaver(image):\n",
|
||||
" ogshape = image.shape\n",
|
||||
" shrunkdim=1000\n",
|
||||
" if (image.shape[1] > image.shape[0]):\n",
|
||||
" shrunkimg, scaler = mf.ResizeWithAspectRatio(image, width=shrunkdim, retscale=True)\n",
|
||||
" else:\n",
|
||||
" shrunkimg, scaler = mf.ResizeWithAspectRatio(image, height=shrunkdim, retscale=True)\n",
|
||||
" \n",
|
||||
" mainimage = shrunkimg\n",
|
||||
" \n",
|
||||
" sdim = int(min(mainimage.shape[0], mainimage.shape[1])/5)\n",
|
||||
" srkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (sdim, sdim))\n",
|
||||
" skernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (sdim, sdim))\n",
|
||||
" \n",
|
||||
" oglab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)\n",
|
||||
" lab = cv2.cvtColor(mainimage, cv2.COLOR_BGR2LAB)\n",
|
||||
" \n",
|
||||
" imglist = []\n",
|
||||
" # imglist.append(mainimage)\n",
|
||||
" \n",
|
||||
" labl = lab[:,:,0]\n",
|
||||
" oglabl = oglab[:,:,0]\n",
|
||||
" # # imglist.append(labl)\n",
|
||||
" # # imglist.append(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))\n",
|
||||
" # laba = lab[:,:,1]\n",
|
||||
" # # imglist.append(laba)\n",
|
||||
" # labb = lab[:,:,2]\n",
|
||||
" # # imglist.append(labb)\n",
|
||||
" \n",
|
||||
" divisor = 1.5\n",
|
||||
" window = int(min(labl.shape)/divisor)\n",
|
||||
" window = window if window%2 == 1 else window + 1\n",
|
||||
" # canny = cv2.Canny(labl, 0, 500)\n",
|
||||
" ethreshl = cv2.threshold(labl, 0, 255, cv2.THRESH_OTSU)[1]\n",
|
||||
" threshl = cv2.threshold(labl, 0, 255, cv2.THRESH_OTSU)[1]\n",
|
||||
" # threshl = cv2.adaptiveThreshold(labl, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, window, 35)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" ogwindow = int(min(oglabl.shape)/divisor)\n",
|
||||
" ogwindow = window if window%2 == 1 else window + 1\n",
|
||||
" print(ogwindow)\n",
|
||||
" ogthreshl = cv2.threshold(oglabl, 0, 255, cv2.THRESH_TRIANGLE)[1]\n",
|
||||
" return ogthreshl\n",
|
||||
" # ogthreshl = cv2.adaptiveThreshold(oglabl, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, ogwindow, 35)\n",
|
||||
" # return threshl\n",
|
||||
" \n",
|
||||
" colourthresh = cv2.cvtColor(threshl, cv2.COLOR_GRAY2BGR)\n",
|
||||
" \n",
|
||||
" dim = int(min(mainimage.shape[0], mainimage.shape[1])/100)\n",
|
||||
" # dim = 2\n",
|
||||
" # dim = dotsize\n",
|
||||
" dim = max(3,dim)\n",
|
||||
" kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (dim, dim))\n",
|
||||
" kernelell = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dim, dim))\n",
|
||||
" \n",
|
||||
" # paddedl = mf.padWithColour(threshl, sdim*2, sdim*2, fill=0)\n",
|
||||
" paddedl = threshl\n",
|
||||
" # return paddedl\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # morphedl = 255-cv2.morphologyEx(255-threshl, cv2.MORPH_OPEN, kernel, iterations=3)\n",
|
||||
" morphedl = paddedl\n",
|
||||
" morphed1l = cv2.morphologyEx(morphedl, cv2.MORPH_ERODE, kernel, iterations=1)\n",
|
||||
" # morphed1l = cv2.morphologyEx(morphed1l, cv2.MORPH_OPEN, kernel, iterations=1)\n",
|
||||
" # morphed1l = cv2.morphologyEx(morphed1l, cv2.MORPH_OPEN, kernel, iterations=1)\n",
|
||||
" # morphed1l = cv2.morphologyEx(morphedl, cv2.MORPH_ERODE, kernelell, iterations=2)\n",
|
||||
" \n",
|
||||
" emorphed1l = cv2.morphologyEx(ethreshl, cv2.MORPH_ERODE, kernel, iterations=1)\n",
|
||||
"\n",
|
||||
" # return morphedl\n",
|
||||
" \n",
|
||||
" contours, heirarchy = cv2.findContours(morphed1l, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
|
||||
" biggestcontour = max(contours, key=cv2.contourArea)\n",
|
||||
" \n",
|
||||
" # temp = cv2.drawContours(colourthresh, [biggestcontour], -1, (0,255,0), thickness=1)\n",
|
||||
" # return temp\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" blank = np.full(labl.shape, 255, dtype=np.uint8)\n",
|
||||
" mask1 = blank.copy()\n",
|
||||
" # mask1 = mf.padWithColour(mask1, sdim*2, sdim*2, fill=255)\n",
|
||||
" mask1 = cv2.drawContours(mask1, [biggestcontour], -1, 0, thickness=cv2.FILLED)\n",
|
||||
" ## TODO: change the erosion so that where the paper reaches the image edge it is not eroded inward (the paper runs right up to the edge there, so writing may be close to it)\n",
|
||||
" \n",
|
||||
" contours, heirarchy = cv2.findContours(morphed1l, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
|
||||
" biggestcontour = max(contours, key=cv2.contourArea)\n",
|
||||
" \n",
|
||||
" emask1 = blank.copy()\n",
|
||||
" emask1 = cv2.drawContours(emask1, [biggestcontour], -1, 0, thickness=cv2.FILLED)\n",
|
||||
" \n",
|
||||
" mask1 = 255-cv2.morphologyEx(255-mask1, cv2.MORPH_ERODE, kernel, iterations=2)\n",
|
||||
" \n",
|
||||
" emask1 = 255-cv2.morphologyEx(255-emask1, cv2.MORPH_ERODE, kernel, iterations=2)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # mask1 = mask1[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
|
||||
" # return mask1\n",
|
||||
" \n",
|
||||
" # morphed2l = mf.padWithColour(morphedl, sdim*2, sdim*2, fill=255)\n",
|
||||
" morphed2l = cv2.morphologyEx(morphedl, cv2.MORPH_OPEN, kernel, iterations=1)\n",
|
||||
" morphed2l = cv2.morphologyEx(morphedl, cv2.MORPH_ERODE, kernel, iterations=1)\n",
|
||||
" # morphed2l = morphed2l[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
|
||||
" \n",
|
||||
" # return morphed2l\n",
|
||||
" # print(mask1.shape)\n",
|
||||
" # print(morphed2l.shape)\n",
|
||||
" morphed2l = cv2.bitwise_or(morphed2l, 255-mask1)\n",
|
||||
" # return morphed2l\n",
|
||||
"\n",
|
||||
" # paddedthreshl = mf.padWithColour(morphed2l, sdim*2, sdim*2, fill=255)\n",
|
||||
" # temp = cv2.drawContours(colourthresh, [biggestcontour], -1, (0,255,0), thickness=1)\n",
|
||||
" # return temp\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" morphed2l = cv2.morphologyEx(morphed2l, cv2.MORPH_ERODE, kernel, iterations=1)\n",
|
||||
" morphed2l = cv2.morphologyEx(morphed2l, cv2.MORPH_ERODE, kernelell, iterations=1)\n",
|
||||
" # return morphed2l\n",
|
||||
" # morphed2l = cv2.bitwise_or(morphed2l, 255-emask1)\n",
|
||||
" \n",
|
||||
" # morphed2l = morphed2l[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
|
||||
" \n",
|
||||
" resizedmask = cv2.resize(255-morphed2l, (ogshape[1], ogshape[0]))\n",
|
||||
" temp_final = cv2.bitwise_or(ogthreshl, resizedmask)\n",
|
||||
" \n",
|
||||
" dim=3\n",
|
||||
" kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dim, dim))\n",
|
||||
" temp_final = cv2.morphologyEx(temp_final, cv2.MORPH_OPEN, kernel)\n",
|
||||
" temp_final = cv2.morphologyEx(temp_final, cv2.MORPH_OPEN, kernel)\n",
|
||||
" # temp_final = cv2.morphologyEx(temp_final, cv2.MORPH_CLOSE, kernel)\n",
|
||||
" # temp_final = cv2.morphologyEx(temp_final, cv2.MORPH_OPEN, kernel)\n",
|
||||
" return temp_final"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def cropclarifying(image):\n",
|
||||
" # whitedbackground = whiteoutbackground(image)\n",
|
||||
" # return whitedbackground\n",
|
||||
"\n",
|
||||
" # textrefined = mf.textClarifying(whitedbackground)\n",
|
||||
" textrefined = textleaver(image)\n",
|
||||
" return textrefined\n",
|
||||
" #maybe now is when I put in the line removing function\n",
|
||||
"\n",
|
||||
" lineout = mf.removeLinesFromText(textrefined)\n",
|
||||
"\n",
|
||||
" return lineout\n",
|
||||
" # implement a function that's called refine text"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def houghlineprocessing(image):\n",
|
||||
" croppedanddeskewed, angle = mf.houghlinedeskewandcrop(image)\n",
|
||||
" # return croppedanddeskewed\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # postprocessed = cropclarifying(croppedanddeskewed)\n",
|
||||
" postprocessed = croppedanddeskewed\n",
|
||||
" # return postprocessed\n",
|
||||
" # postprocessed = mf.croptoblack(postprocessed)\n",
|
||||
" \n",
|
||||
" # postprocessed = cv2.cvtColor(postprocessed, cv2.COLOR_GRAY2BGR)\n",
|
||||
" # return postprocessed\n",
|
||||
" \n",
|
||||
" # final = mf.externaldeskew(postprocessed, fill=(255,255,255))\n",
|
||||
" # rotangle = mf.receipttextdeskew(postprocessed, fill=(255,255,255), returnangle=True)\n",
|
||||
" final = postprocessed\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # final = mf.croptoblack(final)\n",
|
||||
" \n",
|
||||
" # cv2.imshow(\"postprocessed\", mf.ResizeWithAspectRatio(postprocessed, 1000))\n",
|
||||
" # cv2.imshow(\"final\", mf.ResizeWithAspectRatio(final, 1000))\n",
|
||||
" # cv2.waitKey(0)\n",
|
||||
" # cv2.destroyAllWindows()\n",
|
||||
" \n",
|
||||
" return final"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# print(img.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# prepped, scaler, hp, vp = mf.squareandthenresize(img, fill=255, width=1000, returnscalerinfo=True)\n",
|
||||
"outs = houghlineprocessing(img)\n",
|
||||
"# outs = prepimageforhoughline(img, returnrect=True)\n",
|
||||
"# print(img.shape)\n",
|
||||
"# outs = houghlinedeskewandcrop(img)\n",
|
||||
"# outs = outs[0]\n",
|
||||
"# print(croprect)\n",
|
||||
"# TODO: fix premorphCrop; it removes too much"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# shrunk, scaler, hp, vp = mf.squareandthenresize(img, fill=255, width=1000, returnscalerinfo=True)\n",
|
||||
"# shrunk1, croprect = mf.premorphCrop(shrunk)\n",
|
||||
"# print(croprect)\n",
|
||||
"# print(int(30*4.032 - 0))\n",
|
||||
"# # temp = img[100:, :, :]\n",
|
||||
"# temp = shrunk[croprect[1]:croprect[1]+croprect[3], croprect[0]:croprect[0]+croprect[2], :]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# cv2.imshow(\"temp\", mf.ResizeWithAspectRatio(out, height=1000))\n",
|
||||
"# # cv2.imshow(\"shrunk1\", mf.ResizeWithAspectRatio(shrunk1, height=1000))\n",
|
||||
"# cv2.waitKey(0)\n",
|
||||
"# cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"testall = True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if not testall:\n",
|
||||
" showimgs(outs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# # for out in outs:\n",
|
||||
"# # if (out.shape[0] > out.shape[1]):\n",
|
||||
"# # cv2.imshow(\"test1\", mf.ResizeWithAspectRatio(out, height=1000))\n",
|
||||
"# # else:\\\n",
|
||||
"# # cv2.imshow(\"test1\", mf.ResizeWithAspectRatio(out, width=1000))\n",
|
||||
"# # key = cv2.waitKey(0)\n",
|
||||
"# # cv2.destroyAllWindows()\n",
|
||||
"# # if (key == 107):\n",
|
||||
"# # break\n",
|
||||
"# if (isinstance(outs, np.ndarray)):\n",
|
||||
"# if (outs.shape[0] > outs.shape[1]):\n",
|
||||
"# cv2.imshow(\"test\", mf.ResizeWithAspectRatio(outs, height=1350))\n",
|
||||
"# else:\n",
|
||||
"# cv2.imshow(\"test\", mf.ResizeWithAspectRatio(outs, width=1000))\n",
|
||||
"# else:\n",
|
||||
"# for i, out in enumerate(outs):\n",
|
||||
"# if (out.shape[0] > out.shape[1]):\n",
|
||||
"# cv2.imshow(\"test\"+str(i), mf.ResizeWithAspectRatio(out, height=1350))\n",
|
||||
"# else:\n",
|
||||
"# cv2.imshow(\"test\"+str(i), mf.ResizeWithAspectRatio(out, width=1000))\n",
|
||||
"# cv2.waitKey(0)\n",
|
||||
"# cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0.9740282517223996\n",
|
||||
"-2.0053522829578814\n",
|
||||
"-0.9740282517223996\n",
|
||||
"0.0\n",
|
||||
"0.9740282517223996\n",
|
||||
"-0.9740282517223996\n",
|
||||
"-0.011669615052326776\n",
|
||||
"2.0053522829578814\n",
|
||||
"0.0\n",
|
||||
"0.0\n",
|
||||
"0.0\n",
|
||||
"-2.979380534680281\n",
|
||||
"0.0\n",
|
||||
"0.0\n",
|
||||
"-2.0053522829578814\n",
|
||||
"-11.000789666511807\n",
|
||||
"average time: 0.19967518746852875(s)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"if testall:\n",
|
||||
" results = testondataset(\"/mnt/dataset/baseimages/\", houghlineprocessing)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# if testall:\n",
|
||||
"# showimgs(results)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# print(results[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if testall:\n",
|
||||
" writeimgs(\"/mnt/code/autocropper/result_images/\", results)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -1,11 +0,0 @@
|
||||
#ifndef CROPPER_H
|
||||
#define CROPPER_H
|
||||
|
||||
#include <opencv2/opencv.hpp>
|
||||
|
||||
|
||||
// Declaration only; the definition is not part of this header.
// Parameters as declared:
//   src         - input array (cv::InputArray).
//   dst         - output array (cv::OutputArray).
//   fastsearch  - optional flag, defaults to true.
//   imageHeight - optional integer, defaults to 700.
// NOTE(review): judging by the name and the input/output parameters this
// presumably crops `src` into `dst`; the meaning of the bool result, of
// `fastsearch`, and of `imageHeight` is not visible here - confirm against
// the definition.
bool crop(cv::InputArray src, cv::OutputArray dst, bool fastsearch = true, int imageHeight = 700);
|
||||
|
||||
|
||||
|
||||
#endif //CROPPER_H
|
||||
@ -1 +0,0 @@
|
||||
#define DEBUG 1
|
||||
@ -1,43 +0,0 @@
|
||||
#include <cropper.h>
|
||||
|
||||
#include <opencv2/opencv.hpp>
|
||||
|
||||
// PLAN:
|
||||
// Implement selective search
|
||||
// Implement Canny edge detection and then find a good rectangle
|
||||
// Do L2 loss with the corners of the rectangle and choose the selective search rectangle with the lowest loss
|
||||
|
||||
|
||||
// For testing only; delete later.
|
||||
#include <iostream>
|
||||
|
||||
// Manual test driver for crop().
//
// Loads the image named by argv[1], runs crop() on it, scales the result to a
// fixed preview height, shows it in a window and writes it to
// ../testing_space/cropped.jpg.
//
// Returns 0 on success and -1 when the argument is missing, the image cannot
// be loaded, or crop() leaves the output empty.
int main(int argc, char** argv) {
    // argv[0] is not guaranteed to be set when argc is 0.
    const char* programName = (argc > 0 && argv[0] != nullptr) ? argv[0] : "cropper";

    if (argc < 2) {
        std::cerr << "Usage: " << programName << " <Input image>" << std::endl;
        return -1;
    }

    cv::Mat imOut, result;

    imOut = cv::imread(argv[1]);
    if (imOut.empty()) {
        std::cerr << "Could not open or find the image!\n" << std::endl;
        std::cerr << "Usage: " << programName << " <Input image>" << std::endl;
        return -1;
    }

    // The meaning of crop()'s bool result is not visible from its declaration,
    // so it is only reported; the decision to stop is based on the output
    // actually being empty.
    const bool cropStatus = crop(imOut, result, true, 1000);
    if (result.empty()) {
        // Without this guard the width computation below divides by
        // result.rows == 0.
        std::cerr << "Cropping produced no image (crop returned "
                  << std::boolalpha << cropStatus << ")" << std::endl;
        return -1;
    }

    // Scale to a fixed preview height while keeping the aspect ratio.
    int imageHeight = 800;
    int newWidth = result.cols * imageHeight / result.rows;
    if (newWidth < 1) {
        // Integer division can truncate to 0 for very tall, narrow results,
        // and cv::resize rejects an empty target size.
        newWidth = 1;
    }
    cv::resize(result, result, cv::Size(newWidth, imageHeight));

    cv::imshow("banana", result);
    if (!cv::imwrite("../testing_space/cropped.jpg", result)) {
        // Keep going so the preview window is still usable, but do not fail
        // silently.
        std::cerr << "Could not write ../testing_space/cropped.jpg" << std::endl;
    }
    cv::waitKey();
    return 0;
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@ -1,430 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# https://docs.opencv.org/3.4/d9/db0/tutorial_hough_lines.html\n",
|
||||
"# https://medium.com/@9sphere/machine-vision-recipes-deskewing-document-images-e17827894c34\n",
|
||||
"# https://towardsdatascience.com/pre-processing-in-ocr-fc231c6035a7"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#initially for deskewing and cropping. moving to a doc for just cropping now that deskewing"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cv2\n",
|
||||
"import numpy as np\n",
|
||||
"import math\n",
|
||||
"import myfunctions as mf\n",
|
||||
"\n",
|
||||
"import scipy.stats as st"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# def ResizeWithAspectRatio(image, width=None, height=None, inter=cv2.INTER_AREA, retscale=False):\n",
|
||||
"# dim = None\n",
|
||||
"# (h, w) = image.shape[:2]\n",
|
||||
"\n",
|
||||
"# if width is None and height is None:\n",
|
||||
"# if (retscale == True):\n",
|
||||
"# return (image, 1)\n",
|
||||
"# return image\n",
|
||||
"# if width is None:\n",
|
||||
"# r = height / float(h)\n",
|
||||
"# dim = (int(w * r), height)\n",
|
||||
"# else:\n",
|
||||
"# r = width / float(w)\n",
|
||||
"# dim = (width, int(h * r))\n",
|
||||
"\n",
|
||||
"# if (retscale == True):\n",
|
||||
"# # print(\"hi\")\n",
|
||||
"# return (cv2.resize(image, dim, interpolation=inter), 1/r)\n",
|
||||
"# return cv2.resize(image, dim, interpolation=inter)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# class SquarePad:\n",
|
||||
"# def __init__(self, fill):\n",
|
||||
"# self.fill = fill\n",
|
||||
" \n",
|
||||
"# def __call__(self, image):\n",
|
||||
"# w, h = image.shape[1], image.shape[0]\n",
|
||||
"# max_wh = np.max([w, h])\n",
|
||||
"# hp = int((max_wh - w) / 2)\n",
|
||||
"# vp = int((max_wh - h) / 2)\n",
|
||||
"# padding = (hp, vp, hp, vp)\n",
|
||||
"# return cv2.copyMakeBorder(image, vp, vp, hp, hp, cv2.BORDER_CONSTANT, self.fill)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" \n",
|
||||
"# def rotate(img, angle):\n",
|
||||
"# rows,cols = img.shape[0], img.shape[1]\n",
|
||||
"# M = cv2.getRotationMatrix2D((cols/2,rows/2),angle,1)\n",
|
||||
"# dst = cv2.warpAffine(img,M,(cols,rows))\n",
|
||||
"# return dst"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# def morphologyCrop(image):\n",
|
||||
"# # convert to grayscale\n",
|
||||
"# gray = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)\n",
|
||||
"\n",
|
||||
"# # threshold\n",
|
||||
"# thresh = cv2.threshold(gray, 170, 255, cv2.THRESH_BINARY)[1]\n",
|
||||
"\n",
|
||||
"# # apply morphology\n",
|
||||
"# kernel = np.ones((7,7), np.uint8)\n",
|
||||
"# morph = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)\n",
|
||||
"# kernel = np.ones((9,9), np.uint8)\n",
|
||||
"# morph = cv2.morphologyEx(morph, cv2.MORPH_ERODE, kernel)\n",
|
||||
"\n",
|
||||
"# # get largest contour\n",
|
||||
"# contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)\n",
|
||||
"# contours = contours[0] if len(contours) == 2 else contours[1]\n",
|
||||
"# area_thresh = 0\n",
|
||||
"# for c in contours:\n",
|
||||
"# area = cv2.contourArea(c)\n",
|
||||
"# if area > area_thresh:\n",
|
||||
"# area_thresh = area\n",
|
||||
"# big_contour = c\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# # get bounding box\n",
|
||||
"# x,y,w,h = cv2.boundingRect(big_contour)\n",
|
||||
"\n",
|
||||
"# # draw filled contour on black background\n",
|
||||
"# mask = np.zeros_like(gray)\n",
|
||||
"# mask = cv2.merge([mask,mask,mask])\n",
|
||||
"# cv2.drawContours(mask, [big_contour], -1, (255,255,255), cv2.FILLED)\n",
|
||||
"\n",
|
||||
"# # apply mask to input\n",
|
||||
"# result1 = image.copy()\n",
|
||||
"# result1 = cv2.bitwise_and(result1, mask)\n",
|
||||
"\n",
|
||||
"# # crop result\n",
|
||||
"# result2 = result1[y:y+h, x:x+w]\n",
|
||||
"# return result2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# x = -2*np.pi/3\n",
|
||||
"# print(x)\n",
|
||||
"# print(np.pi/3)\n",
|
||||
"# print(x % np.pi)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# def lineAngle(line):\n",
|
||||
"# # print(line)\n",
|
||||
"# angle = (math.atan2(line[3] - line[1], line[2] - line[0]) % np.pi) - (np.pi/2)\n",
|
||||
"# return angle\n",
|
||||
" \n",
|
||||
"# def WithinXDegrees(lines, margin):\n",
|
||||
"# # outlines = np.array([[]])\n",
|
||||
"# outlines = np.empty((0, 4))\n",
|
||||
"# # print(outlines.shape)\n",
|
||||
"# for line in lines:\n",
|
||||
"# # print(type(line))\n",
|
||||
"# # print(abs(lineAngle(line[0])))\n",
|
||||
"# if (np.rad2deg(abs(lineAngle(line[0]))) <= margin):\n",
|
||||
"# outlines = np.append(outlines, [line[0]], axis=0)\n",
|
||||
"# return outlines\n",
|
||||
"\n",
|
||||
"# def lineBoundingRect(lines):\n",
|
||||
"# maxvals = lines.max(0)\n",
|
||||
"# minvals = lines.min(0)\n",
|
||||
"# boundingrect = (min(minvals[0],minvals[2]), min(minvals[1],minvals[3]), max(maxvals[0],maxvals[2]),max(maxvals[1],maxvals[3]))\n",
|
||||
"# return boundingrect\n",
|
||||
"# # print(lines.max(0))\n",
|
||||
"# # print(type(lines))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"img = cv2.imread('./test_images/IMG_7605.jpg')\n",
|
||||
"img = mf.SquarePad(fill=255)(img)\n",
|
||||
"img = mf.rotate(img, 54)\n",
|
||||
"img = mf.morphologyCrop(mf.ResizeWithAspectRatio(img,1000))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)\n",
|
||||
"# img = cv2.threshold(img, 200, 255, cv2.THRESH_BINARY)[1]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cv2.imshow(\"Detected Lines (in red) - Standard Hough Line Transform\", mf.ResizeWithAspectRatio(mf.SquarePad(fill=255)(img), 1000))\n",
|
||||
"cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"resizedimg = mf.ResizeWithAspectRatio(mf.SquarePad(fill=255)(img), 500)\n",
|
||||
"\n",
|
||||
"# cv2.imshow(\"Detected Lines (in red) - Standard Hough Line Transform\", img)\n",
|
||||
"# cv2.waitKey(0)\n",
|
||||
"# cv2.destroyAllWindows()\n",
|
||||
"\n",
|
||||
"gray = cv2.cvtColor(resizedimg ,cv2.COLOR_BGR2GRAY)\n",
|
||||
"gray = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY)[1]\n",
|
||||
"cdst = resizedimg.copy()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"dst = cv2.Canny(gray, 50, 200, None, 3)\n",
|
||||
"lines = cv2.HoughLines(dst, 1, np.pi/180, 150, None, 0, 0)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"angles = np.zeros(len(lines))\n",
|
||||
"if lines is not None:\n",
|
||||
" for i in range(0, len(lines)):\n",
|
||||
" rho = lines[i][0][0]\n",
|
||||
" theta = lines[i][0][1]\n",
|
||||
" a = math.cos(theta)\n",
|
||||
" b = math.sin(theta)\n",
|
||||
" x0 = a * rho\n",
|
||||
" y0 = b * rho\n",
|
||||
" unroundedpt1 = (x0 + 1000*(-b), y0 + 1000*(a))\n",
|
||||
" unroundedpt2 = (x0 - 1000*(-b), y0 - 1000*(a))\n",
|
||||
" pt1 = (int(unroundedpt1[0]), int(unroundedpt1[1]))\n",
|
||||
" pt2 = (int(unroundedpt2[0]), int(unroundedpt2[1]))\n",
|
||||
" v1_theta = math.atan2(pt1[1], pt1[0])\n",
|
||||
" v2_theta = math.atan2(pt2[1], pt2[0])\n",
|
||||
" # print(math.atan2(unroundedpt2[1] - unroundedpt1[1], unroundedpt2[0] - unroundedpt1[0]) % np.pi)\n",
|
||||
" # print(lineAngle((unroundedpt1[0], unroundedpt1[1], unroundedpt2[0], unroundedpt2[1])))\n",
|
||||
" # angles[i] = math.atan2(unroundedpt2[1] - unroundedpt1[1], unroundedpt2[0] - unroundedpt1[0]) % np.pi\n",
|
||||
" angles[i] = mf.lineAngle((unroundedpt1[0], unroundedpt1[1], unroundedpt2[0], unroundedpt2[1]))\n",
|
||||
" cv2.line(cdst, pt1, pt2, (0,0,255), 3, cv2.LINE_AA)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"-56.7228217179515\n",
|
||||
"-56.7228217179515\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# print(st.mode(np.around(angles, decimals=1)))\n",
|
||||
"mode = st.mode(np.around(angles, decimals=2))[0]\n",
|
||||
"print(np.rad2deg(mode))\n",
|
||||
"# slope = math.tan(np.deg2rad(mode))\n",
|
||||
"# print(slope)\n",
|
||||
"# myy0 = 0\n",
|
||||
"# p1 = [0,myy0]\n",
|
||||
"# p2 = [0,myy0]\n",
|
||||
"# while (math.dist(p1, p2) < 5000):\n",
|
||||
"# p2[0] += 0.5\n",
|
||||
"# p2[1] += 0.5*slope*1000\n",
|
||||
"# p2[1] = int(p2[1])\n",
|
||||
"# print(p2)\n",
|
||||
"# cv2.line(cdst, p1, p2, (0,255,0), 3, cv2.LINE_AA)\n",
|
||||
"# rotationangle = np.rad2deg(mode)-90\n",
|
||||
"rotationangle = np.rad2deg(mode)\n",
|
||||
"print(rotationangle)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cv2.imshow(\"Detected Lines (in red) - Standard Hough Line Transform\", cdst)\n",
|
||||
"cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# cv2.imshow(\"Detected Lines (in red) - Standard Hough Line Transform\", rotate(cdst,rotationangle))\n",
|
||||
"# cv2.waitKey(0)\n",
|
||||
"# cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"rotatedimg = mf.SquarePad(fill=255)(mf.rotate(img, rotationangle))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# cv2.imshow(\"Rotated Image\", ResizeWithAspectRatio(rotatedimg, 1000))\n",
|
||||
"# cv2.waitKey(0)\n",
|
||||
"# cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"resizedrotatedimg = mf.ResizeWithAspectRatio(rotatedimg, 500)\n",
|
||||
"gray1 = cv2.cvtColor(resizedrotatedimg, cv2.COLOR_BGR2GRAY)\n",
|
||||
"dst1 = cv2.Canny(gray1, 0, 500, None, 3)\n",
|
||||
"cdstP = resizedrotatedimg.copy()\n",
|
||||
"cdstPmargin = cdstP.copy()\n",
|
||||
"linesP = cv2.HoughLinesP(dst1, 1, np.pi / 180, 30, None, 100, 30)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if linesP is not None:\n",
|
||||
" for i in range(0, len(linesP)):\n",
|
||||
" l = linesP[i][0]\n",
|
||||
" cv2.line(cdstP, (l[0], l[1]), (l[2], l[3]), (0,0,255), 3, cv2.LINE_AA)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cv2.imshow(\"Detected Lines (in red) - Standard Hough Line Transform\", cdstP)\n",
|
||||
"cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# print(linesP)\n",
|
||||
"marginlines = mf.WithinXDegrees(linesP, 2)\n",
|
||||
"# print(marginlines)\n",
|
||||
"if marginlines is not None:\n",
|
||||
" for i in range(0, len(marginlines)):\n",
|
||||
" l = marginlines[i]\n",
|
||||
" cv2.line(cdstPmargin, (int(l[0]), int(l[1])), (int(l[2]), int(l[3])), (0,0,255), 3, cv2.LINE_AA)\n",
|
||||
" \n",
|
||||
"# boundingrectout = mf.lineBoundingRect(marginlines)\n",
|
||||
"# # print(boundingrectout)\n",
|
||||
"# cdstPmargin = cv2.rectangle(cdstPmargin,(int(boundingrectout[0]),int(boundingrectout[1])),(int(boundingrectout[2]),int(boundingrectout[3])),(0,255,0),2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cv2.imshow(\"Detected Lines (in red) - Standard Hough Line Transform\", cdstPmargin)\n",
|
||||
"cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -1,385 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 137,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"version=2.0\n",
|
||||
"cachepath=\"../.cache/\"\n",
|
||||
"savepath=\"./savespot/\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 138,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as fn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"import torchvision.transforms.functional as tvf\n",
|
||||
"import torchvision.transforms.v2 as v2\n",
|
||||
"import torchvision.models as models\n",
|
||||
"import torchvision.transforms as t\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"import datasets as ds\n",
|
||||
"from tqdm.autonotebook import tqdm\n",
|
||||
"\n",
|
||||
"import random\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"torch.cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import cv2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 139,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# array = np.load(\"./testing_space/outputarray.npy\")\n",
|
||||
"# counter = np.load(\"./testing_space/counter.npy\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 140,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# print(array)\n",
|
||||
"# print(counter)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 141,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class RotationDeterminer(nn.Module):\n",
|
||||
" def __init__(self, new=False):\n",
|
||||
" super(RotationDeterminer,self).__init__()\n",
|
||||
" \n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" \n",
|
||||
" self.device = torch.device(\"cpu\")\n",
|
||||
" if torch.cuda.is_available:\n",
|
||||
" self.device = torch.device(\"cuda:0\")\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" self.appliers = [v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
|
||||
" v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0)], p=0.25), # maybe add fill=appliedFill\n",
|
||||
" v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5,9), sigma=(0.1,2.))],p=0.25),\n",
|
||||
" v2.RandomApply(transforms=[v2.RandomEqualize()],p=0.25)]\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # self.conv = nn.Sequential(nn.Conv2d(3, 9, kernel_size=11,stride=3), # 1100 x 1100 => 201 x 201\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.Conv2d(9, 18, kernel_size=5,stride=1),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.MaxPool2d(kernel_size=4, stride=2),\n",
|
||||
" # nn.Conv2d(18, 36, kernel_size=3,stride=2),\n",
|
||||
" # nn.BatchNorm2d(36),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.Conv2d(36, 72, kernel_size=3,stride=2),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.AvgPool2d(kernel_size=5, stride=3),\n",
|
||||
" # nn.Conv2d(72, 144, kernel_size=3,stride=1),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.Conv2d(144, 288, kernel_size=5,stride=1),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.MaxPool2d(kernel_size=4, stride=1),\n",
|
||||
" # nn.Conv2d(288, 192, kernel_size=3,stride=1),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.Conv2d(192, 192, kernel_size=3,stride=1), # => 1\n",
|
||||
" # nn.ReLU(inplace=True))\n",
|
||||
" # print(\"hi\")\n",
|
||||
" self.conv = models.resnet18(pretrained=new)\n",
|
||||
" \n",
|
||||
" self.classifier = nn.Sequential(nn.Linear(1000, 4096),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Linear(4096,1))\n",
|
||||
" \n",
|
||||
" self.lossfunc = nn.MSELoss()\n",
|
||||
" \n",
|
||||
" self.imageprep = v2.Compose([self.SquarePad(),v2.Resize(512),v2.Grayscale(num_output_channels=3),v2.CenterCrop(512),v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" class SquarePad:\n",
|
||||
" def __call__(self, image):\n",
|
||||
" # print(\"hi type:\", type(image))\n",
|
||||
" temp = image.size()\n",
|
||||
" w = temp[-2]\n",
|
||||
" h = temp[-1]\n",
|
||||
" max_wh = max([w, h])\n",
|
||||
" hp = int((max_wh - w) / 2)\n",
|
||||
" vp = int((max_wh - h) / 2)\n",
|
||||
" padding = (hp, vp, hp, vp)\n",
|
||||
" return tvf.pad(image, padding, 0, 'edge')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" \n",
|
||||
" def forward(self, image):\n",
|
||||
"\n",
|
||||
" transformedimage = self.imageprep(image)\n",
|
||||
" transformedimage = transformedimage.to(self.device)\n",
|
||||
"\n",
|
||||
" if (len(transformedimage.shape) != 4 and len(transformedimage.shape) != 3):\n",
|
||||
" raise Exception(\"Sorry, Dimension of image is incorrect (\", len(transformedimage.shape),\"). Expected a 3D (single image) or 4D (batch of images) tensor\")\n",
|
||||
"\n",
|
||||
" if (len(transformedimage.shape) == 3):\n",
|
||||
" x = transformedimage.unsqueeze(0)\n",
|
||||
" else:\n",
|
||||
" x = transformedimage\n",
|
||||
" \n",
|
||||
" x = self.conv(x)\n",
|
||||
" # print(x.shape)\n",
|
||||
" # x = nn.Flatten(start_dim=-1)(x)\n",
|
||||
" # print(x.shape)\n",
|
||||
" x = self.classifier(x)\n",
|
||||
" # print(x.shape)\n",
|
||||
" guessRotation = nn.Flatten(start_dim=0)(x)\n",
|
||||
" \n",
|
||||
" return guessRotation\n",
|
||||
" \n",
|
||||
" def loss(self, guess, trueAnswer):\n",
|
||||
" return self.lossfunc(guess, trueAnswer)\n",
|
||||
" \n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 142,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
|
||||
" warnings.warn(\n",
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
|
||||
" warnings.warn(msg)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = RotationDeterminer(new=True)\n",
|
||||
"device = torch.device(\"cpu\")\n",
|
||||
"if torch.cuda.is_available:\n",
|
||||
" device = torch.device(\"cuda:0\")\n",
|
||||
" model = model.to(device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 143,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# def ResizeWithAspectRatio(image, width=None, height=None, inter=cv2.INTER_AREA):\n",
|
||||
"# dim = None\n",
|
||||
"# (h, w) = image.shape[:2]\n",
|
||||
"\n",
|
||||
"# if width is None and height is None:\n",
|
||||
"# return image\n",
|
||||
"# if width is None:\n",
|
||||
"# r = height / float(h)\n",
|
||||
"# dim = (int(w * r), height)\n",
|
||||
"# else:\n",
|
||||
"# r = width / float(w)\n",
|
||||
"# dim = (width, int(h * r))\n",
|
||||
"\n",
|
||||
"# return cv2.resize(image, dim, interpolation=inter)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 163,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([1, 4032, 3024])\n",
|
||||
"torch.Size([3, 4032, 3024])\n",
|
||||
"0.7532281875610352\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"working_dataset = ds.load_from_disk(cachepath + \"datasets/customrotation/\")\n",
|
||||
"prepimage = v2.Compose([v2.Grayscale(num_output_channels=3),v2.Resize(512), v2.CenterCrop(512),v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
|
||||
"tensorize = v2.Compose([v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
|
||||
"grayscaler = v2.Grayscale(num_output_channels=3)\n",
|
||||
"working_dataset.set_transform(prepimage)\n",
|
||||
"counter = np.load(savepath + \"/v\"+str(version)+\"/counter.npy\")\n",
|
||||
"model.load_state_dict(torch.load(savepath + \"/v\"+str(version)+\"/modelsave\" + str(counter) +\"epochs\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 165,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([1, 800, 723])\n",
|
||||
"torch.Size([3, 800, 723])\n",
|
||||
"-1.3860492706298828\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"filereadimage = cv2.imread(\"./testing_space/cropped.jpg\", 0)\n",
|
||||
"# print(type(filereadimage))\n",
|
||||
"tensorizedimage = torch.unsqueeze(torch.from_numpy(filereadimage),0)\n",
|
||||
"print(tensorizedimage.shape)\n",
|
||||
"adjustedtensorizedimage = tensorize(grayscaler(t.ToPILImage()(tensorizedimage)))\n",
|
||||
"print(adjustedtensorizedimage.shape)\n",
|
||||
"rotation = model(adjustedtensorizedimage).item()\n",
|
||||
"print(rotation)\n",
|
||||
"rotatedimage = t.Resize(size=1000)(tvf.rotate(adjustedtensorizedimage, rotation))\n",
|
||||
"# imS = mf.ResizeWithAspectRatio(filereadimage, 1000)\n",
|
||||
"# imS = cv2.resize(filereadimage, (960, 540)) \n",
|
||||
"open_cv_image = np.array(t.ToPILImage()(rotatedimage))\n",
|
||||
"cv2.imshow(f'image', open_cv_image)\n",
|
||||
"key = cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"index = 0\n",
|
||||
"active_dataset = working_dataset['test']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# plt.imshow(t.ToPILImage()(working_dataset['test'][3]['image']), cmap='gray', vmin=0, vmax=255)\n",
|
||||
"# plt.show()\n",
|
||||
"# rotationapplier = model(working_dataset['test'][3]['image']).item()\n",
|
||||
"# print(rotationapplier)\n",
|
||||
"# img = tvf.rotate(working_dataset['test'][3]['image'], rotationapplier)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# plt.imshow(t.ToPILImage()(img), cmap='gray', vmin=0, vmax=255)\n",
|
||||
"# plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# # To call the model on a bunch of the images and rotate them back\n",
|
||||
"\n",
|
||||
"# while(True):\n",
|
||||
"# activeimage = active_dataset[index]['image']\n",
|
||||
"# # img = cv2.imread(active_dataset[index]['image'], 0)\n",
|
||||
"# activeimage = tvf.rotate(activeimage, model(activeimage).item())\n",
|
||||
"# open_cv_image = np.array(t.ToPILImage()(activeimage))\n",
|
||||
"# print(index)\n",
|
||||
"# cv2.imshow(f'current image', open_cv_image)\n",
|
||||
"# key = cv2.waitKey(0)\n",
|
||||
"\n",
|
||||
"# if key == ord('c'):\n",
|
||||
"# print(\"\\tCopying this one\")\n",
|
||||
"# elif key == ord('x'):\n",
|
||||
"# index -= 1\n",
|
||||
"# elif key == ord('v'):\n",
|
||||
"# index +=1\n",
|
||||
"# elif key == ord('q'):\n",
|
||||
"# break\n",
|
||||
"\n",
|
||||
"# cv2.destroyAllWindows()\n",
|
||||
"# cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# # for trying to call the model on the picture repeatedly to see if it will just get more and more straight if it's called multiple times\n",
|
||||
"\n",
|
||||
"# currentimage = working_dataset['test'][3]['image']\n",
|
||||
"# while(True):\n",
|
||||
"# rotationapplier = model(currentimage).item()\n",
|
||||
"# print(rotationapplier)\n",
|
||||
"# img = tvf.rotate(currentimage, rotationapplier)\n",
|
||||
"# open_cv_image = np.array(t.ToPILImage()(img))\n",
|
||||
"# cv2.imshow(f'current image', open_cv_image)\n",
|
||||
"# key = cv2.waitKey(0)\n",
|
||||
" \n",
|
||||
"# if key == ord('q'):\n",
|
||||
"# break\n",
|
||||
"# elif key == ord('v'):\n",
|
||||
"# currentimage = img\n",
|
||||
"# # cv2.destroyAllWindows()\n",
|
||||
"# cv2.destroyAllWindows()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -1,417 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# from datasets import load_dataset, Image\n",
|
||||
"import datasets as ds\n",
|
||||
"import PIL\n",
|
||||
"import torchvision.transforms.functional as tvf\n",
|
||||
"from torchvision.transforms import v2\n",
|
||||
"import random\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"original_dataset = ds.load_dataset(\"aharley/rvl_cdip\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create own dataset from the images of the original dataset but make the labels the float value for the rotation. do the random rotation on all of the training ones but the labels for the validation and test should/can be 0\n",
|
||||
"trainblacklist = [5, 102664, 102667, 277943]\n",
|
||||
"testblacklist = [6, 11, 14, 18, 27, 35, 37, 54, 33669] # 33669 is a corrupt image\n",
|
||||
"validationblacklist = []\n",
|
||||
"og_training_dataset = original_dataset['train'].select([i for i in range(len(original_dataset['train'])) if i not in trainblacklist])\n",
|
||||
"og_testing_dataset = original_dataset['test'].select([i for i in range(len(original_dataset['test'])) if i not in testblacklist])\n",
|
||||
"og_validation_dataset = original_dataset['validation'].select([i for i in range(len(original_dataset['validation'])) if i not in validationblacklist])\n",
|
||||
"\n",
|
||||
"# type(og_testing_dataset)\n",
|
||||
"\n",
|
||||
"# print(type(transform_picture(og_testing_dataset[0], params)))\n",
|
||||
"# out = transform_picture(og_testing_dataset[0], params)\n",
|
||||
"# print(out['image'])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"319998\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(len(og_training_dataset))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def has_valid_image(ex):\n",
|
||||
" print(type(ex))\n",
|
||||
" try:\n",
|
||||
" PIL.Image.open(ex[\"image\"][\"path\"])\n",
|
||||
" except Exception:\n",
|
||||
" print(\"hi\")\n",
|
||||
" return False\n",
|
||||
" return True\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# dataset = original_dataset.cast_column(\"image\", ds.Image(decode=False))\n",
|
||||
"# dataset = dataset.filter(has_valid_image)\n",
|
||||
"# filtered_dataset = dataset.cast_column(\"image\", ds.Image(decode=True))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Parameter Declaration\n",
|
||||
"minRotation=-180\n",
|
||||
"maxRotation=180\n",
|
||||
"minTranslation=0\n",
|
||||
"maxTranslation=150\n",
|
||||
"minScale = 0.4\n",
|
||||
"maxScale = 1\n",
|
||||
"minShear = 0\n",
|
||||
"maxShear = 0\n",
|
||||
"\n",
|
||||
"minFill=255\n",
|
||||
"maxFill=255\n",
|
||||
"\n",
|
||||
"params = {\"minRotation\":minRotation,\"maxRotation\":maxRotation,\"minTranslation\":minTranslation,\"maxTranslation\":maxTranslation,\"minScale\":minScale,\"maxScale\":maxScale,\"minShear\":minShear,\"maxShear\":maxShear,\"minFill\":minFill,\"maxFill\":maxFill}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class SquarePad:\n",
|
||||
" def __init__(self, fill):\n",
|
||||
" self.fill = fill\n",
|
||||
" \n",
|
||||
" def __call__(self, image):\n",
|
||||
" w, h = image.size\n",
|
||||
" max_wh = np.max([w, h])\n",
|
||||
" hp = int((max_wh - w) / 2)\n",
|
||||
" vp = int((max_wh - h) / 2)\n",
|
||||
" padding = (hp, vp, hp, vp)\n",
|
||||
" return tvf.pad(image, padding,fill=self.fill, padding_mode='constant')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def transform_picture(image_label, parameters):\n",
|
||||
" image = image_label['image']\n",
|
||||
"\n",
|
||||
" appliedRotation = random.uniform(parameters['minRotation'], parameters['maxRotation'])\n",
|
||||
" appliedXTranslation = random.uniform(parameters['minTranslation'], parameters['maxTranslation'])\n",
|
||||
" appliedYTranslation = random.uniform(parameters['minTranslation'], parameters['maxTranslation'])\n",
|
||||
" appliedScale = random.uniform(parameters['minScale'], parameters['maxScale'])\n",
|
||||
" appliedFill = random.uniform(parameters['minFill'], parameters['maxFill'])\n",
|
||||
" appliedXShear = random.uniform(parameters['minShear'], parameters['maxShear'])\n",
|
||||
" appliedYShear = random.uniform(parameters['minShear'], parameters['maxShear'])\n",
|
||||
" \n",
|
||||
" appliers = v2.Compose([v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
|
||||
" v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0, fill=appliedFill)], p=0.25), # maybe add fill=appliedFill\n",
|
||||
" v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5,9), sigma=(0.1,2.))],p=0.25),\n",
|
||||
" v2.RandomApply(transforms=[v2.RandomEqualize()],p=0.25),\n",
|
||||
" SquarePad(fill=appliedFill),v2.Resize(1100)])\n",
|
||||
" \n",
|
||||
" adjustedimage = tvf.affine(image, appliedRotation, [appliedXTranslation,appliedYTranslation], appliedScale, [appliedXShear, appliedYShear], fill=appliedFill)\n",
|
||||
"\n",
|
||||
" adjustedimage = appliers(adjustedimage)\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" return {'image':adjustedimage,'rotation':appliedRotation}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "bd95dc0201c2419e982f8167e16db6b5",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Map (num_proc=4): 0%| | 0/39999 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"new_testing_dataset = og_testing_dataset.map(transform_picture, fn_kwargs={'parameters':params}, num_proc=4)\n",
|
||||
"#33669 has bad EXIF data so it is ignored at load time"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "17c8b6a170ae4072b385b6d3e965d9e8",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Map (num_proc=4): 0%| | 0/40000 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"new_validation_dataset = og_validation_dataset.map(transform_picture, fn_kwargs={'parameters':params}, num_proc=4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "32d7adecb479420cb8a4b3eee898ec1b",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Map (num_proc=4): 0%| | 0/320000 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"new_training_dataset = og_training_dataset.map(transform_picture, fn_kwargs={'parameters':params}, num_proc=4)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def setlabelname(entry):\n",
|
||||
" return {'image':entry['image'], 'rotation':entry['label']}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# new_testing_dataset = new_testing_dataset.map(setlabelname, num_proc=4, batch_size=700, batched=True, writer_batch_size=700)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# new_training_dataset = new_training_dataset.remove_columns(\"label\")\n",
|
||||
"# new_testing_dataset = new_testing_dataset.remove_columns(\"label\")\n",
|
||||
"# new_validation_dataset = new_validation_dataset.remove_columns(\"label\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"new_dataset = ds.DatasetDict({'train': new_training_dataset,'test': new_testing_dataset, 'validation': new_validation_dataset})\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"DatasetDict({\n",
|
||||
" train: Dataset({\n",
|
||||
" features: ['image', 'label', 'rotation'],\n",
|
||||
" num_rows: 320000\n",
|
||||
" })\n",
|
||||
" test: Dataset({\n",
|
||||
" features: ['image', 'label', 'rotation'],\n",
|
||||
" num_rows: 39999\n",
|
||||
" })\n",
|
||||
" validation: Dataset({\n",
|
||||
" features: ['image', 'label', 'rotation'],\n",
|
||||
" num_rows: 40000\n",
|
||||
" })\n",
|
||||
"})\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(new_dataset)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"new_dataset = new_dataset.remove_columns(\"label\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"DatasetDict({\n",
|
||||
" train: Dataset({\n",
|
||||
" features: ['image', 'rotation'],\n",
|
||||
" num_rows: 320000\n",
|
||||
" })\n",
|
||||
" test: Dataset({\n",
|
||||
" features: ['image', 'rotation'],\n",
|
||||
" num_rows: 39999\n",
|
||||
" })\n",
|
||||
" validation: Dataset({\n",
|
||||
" features: ['image', 'rotation'],\n",
|
||||
" num_rows: 40000\n",
|
||||
" })\n",
|
||||
"})\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(new_dataset)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "a614c8b5206649f0b774dc25909bca75",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Saving the dataset (0/65 shards): 0%| | 0/320000 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "bc7984d6eafe443aa4980e43350fee03",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Saving the dataset (0/9 shards): 0%| | 0/39999 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "ab4cc0e859bd46599345129fd70bcb37",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Saving the dataset (0/9 shards): 0%| | 0/40000 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"new_dataset.save_to_disk(\"../.cache/huggingfaces/datasets/customrotation/\", max_shard_size=\"500MB\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -1,159 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# from datasets import load_dataset, Image\n",
|
||||
"import datasets as ds\n",
|
||||
"import PIL\n",
|
||||
"import torchvision.transforms.functional as tvf\n",
|
||||
"from torchvision.transforms import v2\n",
|
||||
"import random\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"import torchvision.utils as utils\n",
|
||||
"\n",
|
||||
"from tqdm.autonotebook import tqdm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"original_dataset = ds.load_dataset(\"aharley/rvl_cdip\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create own dataset from the images of the original dataset but make the labels the float value for the rotation. do the random rotation on all of the training ones but the labels for the validation and test should/can be 0\n",
|
||||
"trainblacklist = []\n",
|
||||
"testblacklist = [33669] # index 33669 is just corrupted\n",
|
||||
"validationblacklist = []\n",
|
||||
"og_training_dataset = original_dataset['train'].select([i for i in range(len(original_dataset['train'])) if i not in trainblacklist])\n",
|
||||
"og_testing_dataset = original_dataset['test'].select([i for i in range(len(original_dataset['test'])) if i not in testblacklist])\n",
|
||||
"og_validation_dataset = original_dataset['validation'].select([i for i in range(len(original_dataset['validation'])) if i not in validationblacklist])\n",
|
||||
"\n",
|
||||
"tensorize = v2.Compose([v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
|
||||
"\n",
|
||||
"og_training_dataset.set_transform(tensorize)\n",
|
||||
"og_testing_dataset.set_transform(tensorize)\n",
|
||||
"og_validation_dataset.set_transform(tensorize)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "755255ae5bea49cc866c96f0d291b570",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/39999 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pbar = tqdm(og_testing_dataset)\n",
|
||||
"\n",
|
||||
"for i, entry in enumerate(pbar):\n",
|
||||
" index = i\n",
|
||||
" if (i >= 33669):\n",
|
||||
" index = index + 1\n",
|
||||
" utils.save_image(entry['image'], \"./datasetimages/test/\"+str(index)+\".jpg\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "88288b649a64430bb52e2ae5720e4b1f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/320000 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pbar = tqdm(og_training_dataset)\n",
|
||||
"\n",
|
||||
"for i, entry in enumerate(pbar):\n",
|
||||
" utils.save_image(entry['image'], \"./datasetimages/train/\"+str(i)+\".jpg\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "ed6ce8bc3d224f278df6723fc0c41d72",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/40000 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pbar = tqdm(og_validation_dataset)\n",
|
||||
"\n",
|
||||
"for i, entry in enumerate(pbar):\n",
|
||||
" utils.save_image(entry['image'], \"./datasetimages/validation/\"+str(i)+\".jpg\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -1,430 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"## ORIGINAL FILE FOR SELECTIVE SEGMENTATION SEARCH"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 350,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cv2\n",
|
||||
"import numpy as np\n",
|
||||
"from queue import PriorityQueue\n",
|
||||
"import myfunctions as mf\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import random"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 351,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# def ResizeWithAspectRatio(image, width=None, height=None, inter=cv2.INTER_AREA):\n",
|
||||
"# dim = None\n",
|
||||
"# (h, w) = image.shape[:2]\n",
|
||||
"\n",
|
||||
"# if width is None and height is None:\n",
|
||||
"# return image\n",
|
||||
"# if width is None:\n",
|
||||
"# r = height / float(h)\n",
|
||||
"# dim = (int(w * r), height)\n",
|
||||
"# else:\n",
|
||||
"# r = width / float(w)\n",
|
||||
"# dim = (width, int(h * r))\n",
|
||||
"\n",
|
||||
"# return cv2.resize(image, dim, interpolation=inter)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 352,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import heapq as hq\n",
|
||||
"\n",
|
||||
"class MaxHeapObj(object):\n",
|
||||
" def __init__(self, val): self.val = val\n",
|
||||
" def __lt__(self, other): return self.val > other.val\n",
|
||||
" def __eq__(self, other): return self.val == other.val\n",
|
||||
" def __str__(self): return str(self.val)\n",
|
||||
" \n",
|
||||
"class MinHeap(object):\n",
|
||||
" def __init__(self): self.h = []\n",
|
||||
" def heappush(self, x): heapq.heappush(self.h, x)\n",
|
||||
" def heappop(self): return heapq.heappop(self.h)\n",
|
||||
" def __getitem__(self, i): return self.h[i]\n",
|
||||
" def __len__(self): return len(self.h)\n",
|
||||
" \n",
|
||||
"class MaxHeap(MinHeap):\n",
|
||||
" def heappush(self, x): heapq.heappush(self.h, MaxHeapObj(x))\n",
|
||||
" def heappop(self): return heapq.heappop(self.h).val\n",
|
||||
" def __getitem__(self, i): return self.h[i].val"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 353,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# def clip(n, lower, upper):\n",
|
||||
"# return max(lower, min(n, upper))\n",
|
||||
"\n",
|
||||
"# def colourscaler(n, min, max):\n",
|
||||
"# temp = n-min\n",
|
||||
"# diff = abs(max - min)\n",
|
||||
"# return clip((temp/diff)*255, 0, 255)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 354,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# inline double clip(double n, double lower, double upper) {\n",
|
||||
"# return std::max(lower, std::min(n, upper));\n",
|
||||
"# };\n",
|
||||
"\n",
|
||||
"# inline double colourscaler(double n, double min, double max) {\n",
|
||||
"# double temp = n - min;\n",
|
||||
"# double diff = std::abs(max - min);\n",
|
||||
"# return clip((temp / diff) * 255, 0, 255);\n",
|
||||
"# };"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 355,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# ## Test this code for the masking/colour squishing. It can essentially just speed up clipping the edges.\n",
|
||||
"# #!/usr/local/bin/python3\n",
|
||||
"# import cv2 as cv\n",
|
||||
"# import numpy as np\n",
|
||||
"\n",
|
||||
"# # Load the aerial image and convert to HSV colourspace\n",
|
||||
"# image = cv.imread(\"aerial.png\")\n",
|
||||
"# hsv=cv.cvtColor(image,cv.COLOR_BGR2HSV)\n",
|
||||
"\n",
|
||||
"# # Define lower and upper limits of what we call \"brown\"\n",
|
||||
"# brown_lo=np.array([10,0,0])\n",
|
||||
"# brown_hi=np.array([20,255,255])\n",
|
||||
"\n",
|
||||
"# # Mask image to only select browns\n",
|
||||
"# mask=cv.inRange(hsv,brown_lo,brown_hi)\n",
|
||||
"\n",
|
||||
"# # Change image to red where we found brown\n",
|
||||
"# image[mask>0]=(0,0,255)\n",
|
||||
"\n",
|
||||
"# cv.imwrite(\"result.png\",image)\n",
|
||||
"\n",
|
||||
"#CAN ALSO TRY USING NUMPY VECTORIZATION"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 356,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# def rotate(img, angle):\n",
|
||||
"# rows,cols = img.shape[0], img.shape[1]\n",
|
||||
"# M = cv2.getRotationMatrix2D((cols/2,rows/2),angle,1)\n",
|
||||
"# dst = cv2.warpAffine(img,M,(cols,rows))\n",
|
||||
"# return dst"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 357,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def crop(image, lower = 100, upper = 255, threshold1 = 50, threshold2 = 350):\n",
|
||||
" lower = max(0,lower)\n",
|
||||
" upper = min(255, upper)\n",
|
||||
" gray = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)\n",
|
||||
"\n",
|
||||
" scaled_gray = np.zeros(gray.shape, gray.dtype)\n",
|
||||
" \n",
|
||||
" # for y in range(0,gray.shape[0]):\n",
|
||||
" # for x in range(0,gray.shape[1]):\n",
|
||||
" # scaled_gray[y][x] = colourscaler(gray[y][x], lower, upper)\n",
|
||||
" scaled_gray = gray\n",
|
||||
" \n",
|
||||
" blurred = cv2.GaussianBlur(scaled_gray, (15,15),0)\n",
|
||||
" # blurred = scaled_gray\n",
|
||||
" edged = cv2.Canny(blurred, threshold1, threshold2)\n",
|
||||
" # meangrayscale = cv2.mean(scaled_gray)[0]\n",
|
||||
" # print(meangrayscale)\n",
|
||||
" # edged = cv2.Canny(blurred, int(meangrayscale*2), int(meangrayscale*4))\n",
|
||||
" return edged\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 358,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def selectiveSearchSegmentationImp(image):\n",
|
||||
" ss = cv2.ximgproc.segmentation.createSelectiveSearchSegmentation()\n",
|
||||
" ss.setBaseImage(image)\n",
|
||||
" ss.switchToSelectiveSearchFast()\n",
|
||||
" return ss.process()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 359,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"img = cv2.imread('./testing_space/final.jpg')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 360,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# def rectArea(rect):\n",
|
||||
"# # print(rect)\n",
|
||||
"# return rect[2]*rect[3]\n",
|
||||
"\n",
|
||||
"# def biggestRects(n, rects):\n",
|
||||
"# dict = {}\n",
|
||||
"# # outrects = np.zeros(shape=(n, 4))\n",
|
||||
"# for rect in rects:\n",
|
||||
"# dict[tuple(rect)] = mf.rectArea(rect)\n",
|
||||
"# # maxh.heappush(mf.rectArea(rect))\n",
|
||||
"# # print(maxh[0])\n",
|
||||
" \n",
|
||||
" \n",
|
||||
"# heap = [(-value, key) for key,value in dict.items()]\n",
|
||||
"# largest = hq.nsmallest(n, heap)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"# # hq.heapify(list(dict.items()))\n",
|
||||
"# # for i in range(0,n):\n",
|
||||
"# # outrects[i] = maxh.heappop()\n",
|
||||
"# # print(outrects)\n",
|
||||
"# return [key for value, key in largest]\n",
|
||||
"\n",
|
||||
"# def overlapRect(rects):\n",
|
||||
"# leftwall = -1\n",
|
||||
"# rightwall = -1\n",
|
||||
"# topwall = -1\n",
|
||||
"# bottomwall = -1\n",
|
||||
"# for (x, y, w, h) in rects:\n",
|
||||
"# if (leftwall == -1):\n",
|
||||
"# leftwall = x\n",
|
||||
"# rightwall = x + w\n",
|
||||
"# topwall = y\n",
|
||||
"# bottomwall = y + h\n",
|
||||
"# continue\n",
|
||||
"# leftwall = max(leftwall, x)\n",
|
||||
"# rightwall = min(rightwall, x+w)\n",
|
||||
"# topwall = max(topwall, y)\n",
|
||||
"# bottomwall = min(bottomwall, y+h)\n",
|
||||
" \n",
|
||||
"# if (topwall >= bottomwall or leftwall >= rightwall):\n",
|
||||
"# return (-1, -1, -1, -1)\n",
|
||||
"# return (leftwall, topwall, rightwall-leftwall, bottomwall-topwall)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 344,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(-1, -1, -1, -1)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# rect = crop(img)\n",
|
||||
"\n",
|
||||
"# _, thresholded = cv2.threshold(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 200, 255, cv2.THRESH_BINARY)\n",
|
||||
"\n",
|
||||
"rects = selectiveSearchSegmentationImp(cv2.GaussianBlur(ResizeWithAspectRatio(img,300), (15,15),0))\n",
|
||||
"# mf.rectArea(rects[0])\n",
|
||||
"bigRects = mf.biggestRects(20, rects)\n",
|
||||
"# print(bigRects)\n",
|
||||
"\n",
|
||||
"finalrect = mf.overlapRect(bigRects)\n",
|
||||
"print(finalrect)\n",
|
||||
"output = ResizeWithAspectRatio(img,300)\n",
|
||||
"for (x, y, w, h) in [finalrect]:\n",
|
||||
"\t\t# draw the region proposal bounding box on the image\n",
|
||||
"\t\tcolor = [random.randint(0, 255) for j in range(0, 3)]\n",
|
||||
"\t\tcv2.rectangle(output, (x, y), (x + w, y + h), color, 2)\n",
|
||||
"\n",
|
||||
"# edges = cv2.Canny(cv2.GaussianBlur(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), (15,15),0),255 / 4, 255)\n",
|
||||
"\n",
|
||||
"# plt.imshow(edges, cmap='gray', vmin=0, vmax=255)\n",
|
||||
"# plt.show()\n",
|
||||
"\n",
|
||||
"cv2.imshow(\"banana\", output)\n",
|
||||
"cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# print(range(0,img.shape[1]))\n",
|
||||
"# for i in range(0,img.shape[1]):\n",
|
||||
"# print(i)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 389,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"temp = ResizeWithAspectRatio(crop(img, threshold1=150, threshold2=350),500)\n",
|
||||
"contours, _ = cv2.findContours(temp, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
|
||||
"# print(type(contours))\n",
|
||||
"# max(cv2.contourArea(contours))\n",
|
||||
"# areas = list(map(cv2.contourArea, contours))\n",
|
||||
"# print(areas)\n",
|
||||
"contourindex = np.argmax(list(map(cv2.contourArea, contours)))\n",
|
||||
"temp = cv2.drawContours(temp, contours, contourindex, (255,0,0), 2)\n",
|
||||
"cv2.imshow(\"banana\", temp)\n",
|
||||
"cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()\n",
|
||||
"print(contourindex)\n",
|
||||
"rect = cv2.boundingRect(contours[contourindex])\n",
|
||||
"color = (random.randint(0,256), random.randint(0,256), random.randint(0,256))\n",
|
||||
"result = cv2.rectangle(ResizeWithAspectRatio(img,500), rect, color, 3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 362,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# print(contourindex)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 371,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cv2.imshow(\"banana\", result)\n",
|
||||
"cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 348,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# HSV = cv2.cvtColor(ResizeWithAspectRatio(img,500), cv2.COLOR_BGR2HSV)\n",
|
||||
"# low = np.array([0,0,10])\n",
|
||||
"# high = np.array([179,10,255])\n",
|
||||
"\n",
|
||||
"# mask = cv2.inRange(HSV,low,high)\n",
|
||||
"\n",
|
||||
"# cv2.imshow(\"banana\", mask)\n",
|
||||
"# cv2.waitKey(0)\n",
|
||||
"# cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 349,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
" # cv::Mat gray, scaled_gray, blurred, edged;\n",
|
||||
"\n",
|
||||
" # lower = std::max(lower, 0);\n",
|
||||
" # upper = std::min(upper, 255);\n",
|
||||
"\n",
|
||||
" # cv::cvtColor(src, gray, cv::COLOR_BGR2GRAY);\n",
|
||||
" # scaled_gray = cv::Mat::zeros(gray.size(), gray.type());\n",
|
||||
"\n",
|
||||
" # for (int y = 0; y < gray.rows; y++) {\n",
|
||||
" # for (int x = 0; x < gray.cols; x++) {\n",
|
||||
" # scaled_gray.at<uchar>(y, x) =\n",
|
||||
" # cv::saturate_cast<uchar>(colourscaler(gray.at<uchar>(y, x), lower, upper));\n",
|
||||
" # }\n",
|
||||
" # }\n",
|
||||
"\n",
|
||||
" # cv::GaussianBlur(scaled_gray, blurred, cv::Size(15, 15), 0);\n",
|
||||
" # cv::Canny(blurred, edged, threshold1, threshold2);\n",
|
||||
"\n",
|
||||
" # std::vector<std::vector<cv::Point>> contours;\n",
|
||||
" # std::vector<cv::Vec4i> heirarchy;\n",
|
||||
" # cv::Mat approx;\n",
|
||||
"\n",
|
||||
" # cv::findContours(edged, contours, heirarchy, cv::RETR_TREE, cv::CHAIN_APPROX_SIMPLE);\n",
|
||||
"\n",
|
||||
" # cv::cvtColor(gray, gray, cv::COLOR_GRAY2BGR);\n",
|
||||
"\n",
|
||||
" # std::sort(contours.begin(), contours.end(), [](std::vector<cv::Point> a, std::vector<cv::Point> b) {\n",
|
||||
" # return cv::arcLength(a, false) > cv::arcLength(b, false); });\n",
|
||||
"\n",
|
||||
" # int numContours = contours.size();\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # return cv::boundingRect(contours[0]);"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -1,94 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cv2\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torchvision.transforms.functional as tvf\n",
|
||||
"import torchvision.transforms.v2 as v2\n",
|
||||
"import torchvision.transforms as t\n",
|
||||
"import myfunctions as mf\n",
|
||||
"\n",
|
||||
"from skimage import io\n",
|
||||
"from matplotlib import pyplot as plt\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"import myfunctions as mf"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# read image (no flag is passed to cv2.imread, so it is loaded in colour, not grayscale)\n",
|
||||
"img = cv2.imread('./test_images/IMG_7594.jpg')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"True"
|
||||
]
|
||||
},
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"cropped = mf.morphologyCrop(img)\n",
|
||||
"# rotated = deskew(cropped)\n",
|
||||
"# cropped2 = morphologyCrop(rotated)\n",
|
||||
"# cropped2 = selectiveSearchCrop(rotated)\n",
|
||||
"# cropped3 = cannyEdgeCrop(cropped2)\n",
|
||||
"cv2.imwrite(\"./testing_space/final.jpg\", cropped)\n",
|
||||
"# final = rotate(cropped2, 180) # need to implement the code to determine if a doc is upside down"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"### Deskew seems to work \n",
|
||||
"# TODO: note the licensing for the deskew package (this note was left unfinished)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -1,316 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This can probably be deleted or moved elsewhere; it was the original code for rowsumdeskew."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cv2\n",
|
||||
"import numpy as np\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"src = 255 - cv2.imread('./testing_space/cropped1.jpg',0)\n",
|
||||
"scores = []\n",
|
||||
"\n",
|
||||
"h,w = src.shape\n",
|
||||
"small_dimention = min(h,w)\n",
|
||||
"src = src[:small_dimention, :small_dimention]\n",
|
||||
"\n",
|
||||
"out = cv2.VideoWriter('./temp/video.avi',\n",
|
||||
" cv2.VideoWriter_fourcc('M','J','P','G'),\n",
|
||||
" 15, (320,320))\n",
|
||||
"\n",
|
||||
"src = cv2.threshold(src, 100, 255, cv2.THRESH_BINARY)[1]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def rotate(img, angle):\n",
|
||||
" rows,cols = img.shape\n",
|
||||
" M = cv2.getRotationMatrix2D((cols/2,rows/2),angle,1)\n",
|
||||
" dst = cv2.warpAffine(img,M,(cols,rows))\n",
|
||||
" return dst\n",
|
||||
"\n",
|
||||
"def sum_rows(img):\n",
|
||||
" # Create a list to store the row sums\n",
|
||||
" row_sums = []\n",
|
||||
" # Iterate through the rows\n",
|
||||
" for r in range(img.shape[0]-1):\n",
|
||||
" # Sum the row\n",
|
||||
" row_sum = sum(sum(img[r:r+1,:]))\n",
|
||||
" # Add the sum to the list\n",
|
||||
" row_sums.append(row_sum)\n",
|
||||
" # Normalize range to (0,255)\n",
|
||||
" row_sums = (row_sums/max(row_sums)) * 255\n",
|
||||
" # Return\n",
|
||||
" return row_sums\n",
|
||||
"\n",
|
||||
"def display_data(roi, row_sums, buffer): \n",
|
||||
" # Create background to draw transform on\n",
|
||||
" bg = np.zeros((buffer*2, buffer*2), np.uint8) \n",
|
||||
" # Iterate through the rows and draw on the background\n",
|
||||
" for row in range(roi.shape[0]-1):\n",
|
||||
" row_sum = row_sums[row]\n",
|
||||
" bg[row:row+1, :] = row_sum\n",
|
||||
" left_side = int(buffer/3)\n",
|
||||
" bg[:, left_side:] = roi[:,left_side:] \n",
|
||||
" cv2.imshow('bg1', bg)\n",
|
||||
" k = cv2.waitKey(1)\n",
|
||||
" out.write(cv2.cvtColor(cv2.resize(bg, (320,320)), cv2.COLOR_GRAY2BGR))\n",
|
||||
" return k\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"count = 0\n",
|
||||
"othercount = 0\n",
|
||||
"goodangle = 0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# cv2.imshow('bg1', src)\n",
|
||||
"# cv2.waitKey(0)\n",
|
||||
"# cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"found optimal rotation\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n",
|
||||
"found optimal rotation\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Rotate the image around in a circle\n",
|
||||
"angle = 0\n",
|
||||
"while angle <= 360:\n",
|
||||
" # Rotate the source image\n",
|
||||
" img = rotate(src, angle) \n",
|
||||
" # Crop the center 1/3rd of the image (roi is filled with text)\n",
|
||||
" h,w = img.shape\n",
|
||||
" buffer = min(h, w) - int(min(h,w)/1.5)\n",
|
||||
" roi = img[int(h/2-buffer):int(h/2+buffer), int(w/2-buffer):int(w/2+buffer)]\n",
|
||||
" # Create background to draw transform on\n",
|
||||
" bg = np.zeros((buffer*2, buffer*2), np.uint8)\n",
|
||||
" # Compute the sums of the rows\n",
|
||||
" row_sums = sum_rows(roi)\n",
|
||||
" # High score --> Zebra stripes\n",
|
||||
" score = np.count_nonzero(row_sums)\n",
|
||||
" scores.append(score)\n",
|
||||
" othercount = othercount + 1\n",
|
||||
" # Image has best rotation\n",
|
||||
" if score <= min(scores):\n",
|
||||
" count = count + 1\n",
|
||||
"        # Save the rotated image\n",
|
||||
" print('found optimal rotation')\n",
|
||||
" best_rotation = img.copy()\n",
|
||||
" goodangle = angle\n",
|
||||
" k = display_data(roi, row_sums, buffer)\n",
|
||||
" if k == 27: break\n",
|
||||
" # Increment angle and try again\n",
|
||||
" angle += .75\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"25\n",
|
||||
"481\n",
|
||||
"349.5\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(count)\n",
|
||||
"print(othercount)\n",
|
||||
"print(goodangle)\n",
|
||||
"cv2.imshow('bg1', best_rotation)\n",
|
||||
"cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"start\")\n",
|
||||
"\n",
|
||||
"# Rotate the image around in a circle\n",
|
||||
"angle = 0\n",
|
||||
"while angle <= 360:\n",
|
||||
" # Rotate the source image\n",
|
||||
" img = rotate(src, angle) \n",
|
||||
" # Crop the center 1/3rd of the image (roi is filled with text)\n",
|
||||
" h,w = img.shape\n",
|
||||
" buffer = min(h, w) - int(min(h,w)/1.5)\n",
|
||||
" #roi = img.copy()\n",
|
||||
" roi = img[int(h/2-buffer):int(h/2+buffer), int(w/2-buffer):int(w/2+buffer)]\n",
|
||||
" # Create background to draw transform on\n",
|
||||
" bg = np.zeros((buffer*2, buffer*2), np.uint8)\n",
|
||||
" # Threshold image\n",
|
||||
" _, roi = cv2.threshold(roi, 140, 255, cv2.THRESH_BINARY)\n",
|
||||
" # Compute the sums of the rows\n",
|
||||
" row_sums = sum_rows(roi)\n",
|
||||
" # High score --> Zebra stripes\n",
|
||||
" score = np.count_nonzero(row_sums)\n",
|
||||
" if sum(row_sums) < 100000: scores.append(angle)\n",
|
||||
" k = display_data(roi, row_sums, buffer)\n",
|
||||
" if k == 27: break\n",
|
||||
" # Increment angle and try again\n",
|
||||
" angle += .5\n",
|
||||
" print(\"loop\")\n",
|
||||
"cv2.destroyAllWindows()\n",
|
||||
"\n",
|
||||
"print(\"endofrotate\")\n",
|
||||
"\n",
|
||||
"# Create images for display purposes\t\n",
|
||||
"display = src.copy()\n",
|
||||
"# Create an image that contains bins. \n",
|
||||
"bins_image = np.zeros_like(display)\n",
|
||||
"for angle in scores:\n",
|
||||
" # Rotate the image and draw a line on it\n",
|
||||
" display = rotate(display, angle) \n",
|
||||
" cv2.line(display, (0,int(h/2)), (w,int(h/2)), 255, 1)\n",
|
||||
" display = rotate(display, -angle)\n",
|
||||
" # Rotate the bins image\n",
|
||||
" bins_image = rotate(bins_image, angle)\n",
|
||||
" # Draw a line on a temporary image\n",
|
||||
" temp = np.zeros_like(bins_image)\n",
|
||||
" cv2.line(temp, (0,int(h/2)), (w,int(h/2)), 50, 1)\n",
|
||||
" # 'Fill' up the bins\n",
|
||||
" bins_image += temp\n",
|
||||
" bins_image = rotate(bins_image, -angle)\n",
|
||||
" \n",
|
||||
"print(\"endofbins\")\n",
|
||||
"\n",
|
||||
"# Find the most filled bin\n",
|
||||
"for col in range(bins_image.shape[0]-1):\n",
|
||||
"\tcolumn = bins_image[:, col:col+1]\n",
|
||||
"\tif np.amax(column) == np.amax(bins_image): x = col\n",
|
||||
"for col in range(bins_image.shape[0]-1):\n",
|
||||
"\tcolumn = bins_image[:, col:col+1]\n",
|
||||
"\tif np.amax(column) == np.amax(bins_image): y = col\n",
|
||||
"# Draw circles showing the most filled bin\n",
|
||||
"cv2.circle(display, (x,y), 560, 255, 5)\n",
|
||||
"\n",
|
||||
"print(\"plotting\")\n",
|
||||
"\n",
|
||||
"# Plot with Matplotlib\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.image as mpimg\n",
|
||||
"f, axarr = plt.subplots(1,3, sharex=True)\n",
|
||||
"axarr[0].imshow(src)\n",
|
||||
"axarr[1].imshow(display)\n",
|
||||
"axarr[2].imshow(bins_image)\n",
|
||||
"axarr[0].set_title('Source Image')\n",
|
||||
"axarr[1].set_title('Output')\n",
|
||||
"axarr[2].set_title('Bins Image')\n",
|
||||
"axarr[0].axis('off')\n",
|
||||
"axarr[1].axis('off')\n",
|
||||
"axarr[2].axis('off')\n",
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"cv2.waitKey()\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,777 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
|
||||
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n",
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
|
||||
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as fn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"import torchvision.transforms.functional as tvf\n",
|
||||
"from torchvision.transforms import v2\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"import datasets as ds\n",
|
||||
"from tqdm.autonotebook import tqdm\n",
|
||||
"\n",
|
||||
"import random\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"import numpy as np\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# original_dataset = ds.load_dataset(\"aharley/rvl_cdip\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"working_dataset = ds.load_from_disk(\"../.cache/huggingfaces/datasets/customrotation/\")\n",
|
||||
"prepimage = v2.Compose([v2.Grayscale(num_output_channels=3),v2.Resize(1100), v2.CenterCrop(1100),v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
|
||||
"working_dataset.set_transform(prepimage)\n",
|
||||
"torch.cuda.empty_cache()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Parameter Declaration\n",
|
||||
"minRotation=-180\n",
|
||||
"maxRotation=180\n",
|
||||
"minTranslation=0\n",
|
||||
"maxTranslation=150\n",
|
||||
"minScale = 0.4\n",
|
||||
"maxScale = 1\n",
|
||||
"minShear = 0\n",
|
||||
"maxShear = 0\n",
|
||||
"\n",
|
||||
"minFill=0\n",
|
||||
"maxFill=255\n",
|
||||
"\n",
|
||||
"params = {\"minRotation\":minRotation,\"maxRotation\":maxRotation,\"minTranslation\":minTranslation,\"maxTranslation\":maxTranslation,\"minScale\":minScale,\"maxScale\":maxScale,\"minShear\":minShear,\"maxShear\":maxShear,\"minFill\":minFill,\"maxFill\":maxFill}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def transform_picture(image_label, parameters):\n",
|
||||
" image = image_label['image']\n",
|
||||
"\n",
|
||||
" appliedRotation = random.uniform(parameters['minRotation'], parameters['maxRotation'])\n",
|
||||
" appliedXTranslation = random.uniform(parameters['minTranslation'], parameters['maxTranslation'])\n",
|
||||
" appliedYTranslation = random.uniform(parameters['minTranslation'], parameters['maxTranslation'])\n",
|
||||
" appliedScale = random.uniform(parameters['minScale'], parameters['maxScale'])\n",
|
||||
" appliedFill = random.uniform(parameters['minFill'], parameters['maxFill'])\n",
|
||||
" appliedXShear = random.uniform(parameters['minShear'], parameters['maxShear'])\n",
|
||||
" appliedYShear = random.uniform(parameters['minShear'], parameters['maxShear'])\n",
|
||||
" \n",
|
||||
" appliers = [v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
|
||||
" v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0, fill=appliedFill)], p=0.25), # maybe add fill=appliedFill\n",
|
||||
" v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5,9), sigma=(0.1,2.))],p=0.25),\n",
|
||||
" v2.RandomApply(transforms=[v2.RandomEqualize()],p=0.25)]\n",
|
||||
" \n",
|
||||
" adjustedimage = tvf.affine(image, appliedRotation, [appliedXTranslation,appliedYTranslation], appliedScale, [appliedXShear, appliedYShear], fill=appliedFill)\n",
|
||||
"\n",
|
||||
" for applier in appliers:\n",
|
||||
" adjustedimage = applier(adjustedimage)\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" adjustedimage = tvf.resize(adjustedimage, size=[1100,1100])\n",
|
||||
" \n",
|
||||
" return {'image':adjustedimage,'label':appliedRotation}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# # Create own dataset from the images of the original dataset but make the labels the float value for the rotation. do the random rotation on all of the training ones but the labels for the validation and test should/can be 0\n",
|
||||
"# og_training_dataset = original_dataset['train']\n",
|
||||
"# og_testing_dataset = original_dataset['test']\n",
|
||||
"# og_validation_dataset = original_dataset['validation']\n",
|
||||
"\n",
|
||||
"# type(og_testing_dataset[0]['label'])\n",
|
||||
"\n",
|
||||
"# # type(transform_picture(og_testing_dataset[0], params))\n",
|
||||
"# new_testing_dataset = og_testing_dataset.map(transform_picture, fn_kwargs={'parameters':params})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# class WorkaroundDataset(torch.utils.data.Dataset):\n",
|
||||
"# def __init__(self, dataset):\n",
|
||||
"# self._dataset = dataset\n",
|
||||
"\n",
|
||||
"# def __len__(self):\n",
|
||||
"# return len(self._dataset)\n",
|
||||
"\n",
|
||||
"# def __getitem__(self, idx):\n",
|
||||
"# return v2.Compose([v2.ToImageTensor(), v2.ConvertImageDtype()])(self._dataset[idx]['image'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# # type(image_dataset['train'][0]['image'])\n",
|
||||
"# # print(image_dataset['train'][0]['image'])\n",
|
||||
"# img = image_dataset['train'][2]['image']\n",
|
||||
"# # img\n",
|
||||
"# # print(img.size)\n",
|
||||
"# crop = tvf.resize(img, size=[500])\n",
|
||||
"# # crop\n",
|
||||
"# # print(crop.size)\n",
|
||||
"# newimg = tvf.affine(crop, 180, [0,0], 0.7, 0)\n",
|
||||
"# newimg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# appliedRotation = random.uniform(minRotation, maxRotation)\n",
|
||||
"# appliedXTranslation = random.uniform(minTranslation, maxTranslation)\n",
|
||||
"# appliedYTranslation = random.uniform(minTranslation, maxTranslation)\n",
|
||||
"# appliedScale = random.uniform(minScale, maxScale)\n",
|
||||
"# appliedFill = random.uniform(minFill, maxFill)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# newimg = tvf.affine(crop, appliedRotation, [appliedXTranslation,appliedYTranslation], appliedScale, shear, fill=appliedFill)\n",
|
||||
"# newimg\n",
|
||||
"\n",
|
||||
"# appliers = [v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
|
||||
"# v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0, fill=appliedFill)], p=0.25),\n",
|
||||
"# v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5,9), sigma=(0.1,2.))],p=0.25),\n",
|
||||
"# v2.RandomApply(transforms=[v2.RandomEqualize()],p=0.25)]\n",
|
||||
"\n",
|
||||
"# for applier in appliers:\n",
|
||||
"# newimg = applier(newimg)\n",
|
||||
" \n",
|
||||
"# # newimg\n",
|
||||
"# newimg= tvf.resize(newimg, size=[1000,1000])\n",
|
||||
"# newimg\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# class SquarePad:\n",
|
||||
"# \tdef __call__(self, image):\n",
|
||||
"# \t\tw, h = image.size\n",
|
||||
"# \t\tmax_wh = np.max([w, h])\n",
|
||||
"# \t\thp = int((max_wh - w) / 2)\n",
|
||||
"# \t\tvp = int((max_wh - h) / 2)\n",
|
||||
"# \t\tpadding = (hp, vp, hp, vp)\n",
|
||||
"# \t\treturn tvf.pad(image, padding, 0, 'constant')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"class RotationDeterminer(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super().__init__()\n",
|
||||
" \n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" \n",
|
||||
" self.device = torch.device(\"cpu\")\n",
|
||||
" if torch.cuda.is_available:\n",
|
||||
" self.device = torch.device(\"cuda:0\")\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" self.appliers = [v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
|
||||
" v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0)], p=0.25), # maybe add fill=appliedFill\n",
|
||||
" v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5,9), sigma=(0.1,2.))],p=0.25),\n",
|
||||
" v2.RandomApply(transforms=[v2.RandomEqualize()],p=0.25)]\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" self.conv = nn.Sequential(nn.Conv2d(3, 9, kernel_size=11,stride=3), # 1100 x 1100 => 201 x 201\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Conv2d(9, 18, kernel_size=5,stride=1),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.MaxPool2d(kernel_size=4, stride=2),\n",
|
||||
" nn.Conv2d(18, 36, kernel_size=3,stride=2),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Conv2d(36, 72, kernel_size=3,stride=2),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.AvgPool2d(kernel_size=5, stride=3),\n",
|
||||
" nn.Conv2d(72, 144, kernel_size=3,stride=1),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Conv2d(144, 288, kernel_size=5,stride=1),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.MaxPool2d(kernel_size=4, stride=1),\n",
|
||||
" nn.Conv2d(288, 192, kernel_size=3,stride=1),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Conv2d(192, 192, kernel_size=3,stride=1), # => 1\n",
|
||||
" nn.ReLU(inplace=True))\n",
|
||||
" \n",
|
||||
" self.classifier = nn.Sequential(nn.Dropout(),\n",
|
||||
" nn.Linear(192, 2048),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Dropout(),\n",
|
||||
" nn.Linear(2048,2048),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Linear(2048,1))\n",
|
||||
" \n",
|
||||
" self.lossfunc = nn.MSELoss()\n",
|
||||
" \n",
|
||||
" self.imageprep = v2.Compose([self.SquarePad(),v2.Resize(1100),v2.Grayscale(num_output_channels=3),v2.CenterCrop(1100),v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" class SquarePad:\n",
|
||||
" def __call__(self, image):\n",
|
||||
" # print(\"hi type:\", type(image))\n",
|
||||
" temp = image.size()\n",
|
||||
" w = temp[-2]\n",
|
||||
" h = temp[-1]\n",
|
||||
" max_wh = max([w, h])\n",
|
||||
" hp = int((max_wh - w) / 2)\n",
|
||||
" vp = int((max_wh - h) / 2)\n",
|
||||
" padding = (hp, vp, hp, vp)\n",
|
||||
" return tvf.pad(image, padding, 0, 'edge')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" \n",
|
||||
" def forward(self, image):\n",
|
||||
"\n",
|
||||
" transformedimage = self.imageprep(image)\n",
|
||||
" transformedimage = transformedimage.to(self.device)\n",
|
||||
"\n",
|
||||
" x = self.conv(transformedimage)\n",
|
||||
" x = nn.Flatten(start_dim=-3)(x)\n",
|
||||
" x = self.classifier(x)\n",
|
||||
" guessRotation = nn.Flatten(start_dim=0)(x)\n",
|
||||
" \n",
|
||||
" return guessRotation\n",
|
||||
" \n",
|
||||
" def loss(self, guess, trueAnswer):\n",
|
||||
" return self.lossfunc(guess, trueAnswer)\n",
|
||||
" \n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# def batchmaker(entries, batchsize):\n",
|
||||
"# random.shuffle(entries)\n",
|
||||
"# listing = []\n",
|
||||
"# for i in range(0,len(entries), batchsize):\n",
|
||||
"# listing.append(entries[i:i+batchsize])\n",
|
||||
"# return listing"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# print(type(v2.Compose([v2.ToImageTensor(), v2.ConvertImageDtype()])(image_dataset['train'][0]['image'])))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# a, b, x = working_dataset['train'][0]['image'].size()\n",
|
||||
"# print(x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train(model, dataset, batchsize, num_epochs, stepsize, totalnumiters = -1):\n",
|
||||
" device = torch.device(\"cpu\")\n",
|
||||
" if torch.cuda.is_available:\n",
|
||||
" device = torch.device(\"cuda:0\")\n",
|
||||
" model = model.cuda()\n",
|
||||
" optimizer = optim.Adam(model.parameters(), lr=stepsize)\n",
|
||||
" \n",
|
||||
" counter = totalnumiters\n",
|
||||
" model = model.train()\n",
|
||||
" \n",
|
||||
" breakearly = True\n",
|
||||
" if totalnumiters == -1:\n",
|
||||
" print(\"hi\")\n",
|
||||
" breakearly = False\n",
|
||||
" totalnumiters = len(dataset) + 1\n",
|
||||
" \n",
|
||||
" for e in range(num_epochs):\n",
|
||||
" \n",
|
||||
" train_dataloader = DataLoader(dataset, batch_size=batchsize, shuffle=True)\n",
|
||||
" \n",
|
||||
" pbar = tqdm(train_dataloader)\n",
|
||||
" \n",
|
||||
" for i, batch in enumerate(pbar):\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" images, truerotations = batch['image'], batch['rotation']\n",
|
||||
" images = images.to(device)\n",
|
||||
" truerotations = truerotations.to(device)\n",
|
||||
"\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" \n",
|
||||
" guessRotation = model(images)\n",
|
||||
" \n",
|
||||
" truerotations = truerotations.float()\n",
|
||||
" \n",
|
||||
" loss = model.loss(guessRotation, truerotations)\n",
|
||||
" \n",
|
||||
" loss.backward()\n",
|
||||
" \n",
|
||||
" optimizer.step()\n",
|
||||
" counter = counter - batchsize\n",
|
||||
" if counter <= 0 and breakearly:\n",
|
||||
" print(\"endearly\")\n",
|
||||
" return\n",
|
||||
"\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"testimage = working_dataset['train'][10]['image']\n",
|
||||
"\n",
|
||||
"# testimage = v2.Compose([v2.Grayscale(num_output_channels=3),v2.ToTensor(),])(testimage)\n",
|
||||
"# testimage.size()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# plt.imshow(testimage)\n",
|
||||
"# plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# temp = testimage.size()\n",
|
||||
"# print(temp[-3])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = RotationDeterminer()\n",
|
||||
"device = torch.device(\"cpu\")\n",
|
||||
"if torch.cuda.is_available:\n",
|
||||
" device = torch.device(\"cuda:0\")\n",
|
||||
" model = model.cuda()\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# output = model(testimage)\n",
|
||||
"# print(output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# train_dataloader = DataLoader(working_dataset['test'], batch_size=100, shuffle=True)\n",
|
||||
"# hold = next(iter(train_dataloader))\n",
|
||||
"# images1, labels1 = hold['image'], hold['rotation']\n",
|
||||
"# # print(images1)\n",
|
||||
"# print(labels1.size())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"hi\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "faff1411ea0d485b9321271ebe6820db",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/12800 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "d26b6872bee74eaab8be6e7cfe53b190",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/12800 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"outputarray = np.array([working_dataset['train'][10]['rotation']])\n",
|
||||
"model = model.eval()\n",
|
||||
"output = model(testimage)\n",
|
||||
"outputarray = np.append(outputarray, output.detach().cpu().numpy())\n",
|
||||
"counter = 0\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"train(model, working_dataset['train'], 25, 2, 5e-3)\n",
|
||||
"\n",
|
||||
"model = model.eval()\n",
|
||||
"\n",
|
||||
"counter = 2 + counter\n",
|
||||
"output = model(testimage)\n",
|
||||
"outputarray = np.append(outputarray, output.detach().cpu().numpy())\n",
|
||||
"np.save(\"./testing_space/outputarray\", outputarray)\n",
|
||||
"np.save(\"./testing_space/counter\", counter)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"-164.93280103082208\n",
|
||||
"8.194759720936418e-05\n",
|
||||
"-0.1751984804868698\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(outputarray[0])\n",
|
||||
"print(outputarray[1])\n",
|
||||
"print(outputarray[2])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"3"
|
||||
]
|
||||
},
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"len(outputarray)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.save(model.state_dict(), \"./testing_space/modelsave\" + str(counter) +\" epochs\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#load model\n",
|
||||
"# model.load_state_dict(torch.load(\"./testing_space/modelsave2epochs\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"hi\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "e144d16317094603b328e2db88a4853a",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/12800 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "bb6cf6e77ed34628bdfa6ed2a64ef284",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/12800 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"train(model, working_dataset['train'], 25, 2, 1e-3)\n",
|
||||
"counter = 2 + counter"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# outputarray = []"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = model.eval()\n",
|
||||
"output = model(testimage)\n",
|
||||
"outputarray = np.append(outputarray, output.detach().cpu().numpy())\n",
|
||||
"np.save(\"./testing_space/outputarray\", outputarray)\n",
|
||||
"np.save(\"./testing_space/counter\", counter)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.save(model.state_dict(), \"./testing_space/modelsave\" + str(counter) +\" epochs\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"hi\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c270b991cacd4abc996c602748e742f7",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/12800 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "1515223a29da4cfea86be156155fd06e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/12800 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"train(model, working_dataset['train'], 25, 2, 1e-2)\n",
|
||||
"counter = 2 + counter"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = model.eval()\n",
|
||||
"output = model(testimage)\n",
|
||||
"outputarray = np.append(outputarray, output.detach().cpu().numpy())\n",
|
||||
"np.save(\"./testing_space/outputarray\", outputarray)\n",
|
||||
"np.save(\"./testing_space/counter\", counter)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.save(model.state_dict(), \"./testing_space/modelsave\" + str(counter) +\" epochs\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[-1.64932801e+02 8.19475972e-05 -1.75198480e-01 -2.21363053e-01\n",
|
||||
" -2.17262208e-01]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(outputarray)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -1,645 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"version=2.0\n",
|
||||
"cachepath=\"../.cache/\"\n",
|
||||
"savepath=\"./savespot/\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
|
||||
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n",
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
|
||||
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as fn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"import torchvision.transforms.functional as tvf\n",
|
||||
"import torchvision.transforms.v2 as v2\n",
|
||||
"import torchvision.models as models\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"import datasets as ds\n",
|
||||
"from tqdm.autonotebook import tqdm\n",
|
||||
"\n",
|
||||
"import random\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"import os"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# models.list_models()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.cuda.empty_cache()\n",
|
||||
"working_dataset = ds.load_from_disk(cachepath + \"datasets/customrotation/\")\n",
|
||||
"prepimage = v2.Compose([v2.Grayscale(num_output_channels=3),v2.Resize(512), v2.CenterCrop(512),v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
|
||||
"working_dataset.set_transform(prepimage)\n",
|
||||
"testsample = working_dataset['train'][10]\n",
|
||||
"testimage = testsample['image']\n",
|
||||
"torch.cuda.empty_cache()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# print(models.resnet18(pretrained=True))\n",
|
||||
"# temp = models.resnet18(pretrained=True)\n",
|
||||
"# print(temp(testimage.unsqueeze(0)).shape)\n",
|
||||
"# device = torch.device(\"cpu\")\n",
|
||||
"# if torch.cuda.is_available:\n",
|
||||
"# device = torch.device(\"cuda:0\")\n",
|
||||
"# temp = temp.to(device)\n",
|
||||
"\n",
|
||||
"#to be deleted"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# print(temp(testimage).shape)\n",
|
||||
"#to be deleted"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class RotationDeterminer(nn.Module):\n",
|
||||
" def __init__(self, new=False):\n",
|
||||
" super(RotationDeterminer,self).__init__()\n",
|
||||
" \n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" \n",
|
||||
" self.device = torch.device(\"cpu\")\n",
|
||||
" if torch.cuda.is_available():\n",
|
||||
" self.device = torch.device(\"cuda:0\")\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" self.appliers = [v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
|
||||
" v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0)], p=0.25), # maybe add fill=appliedFill\n",
|
||||
" v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5,9), sigma=(0.1,2.))],p=0.25),\n",
|
||||
" v2.RandomApply(transforms=[v2.RandomEqualize()],p=0.25)]\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # self.conv = nn.Sequential(nn.Conv2d(3, 9, kernel_size=11,stride=3), # 1100 x 1100 => 201 x 201\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.Conv2d(9, 18, kernel_size=5,stride=1),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.MaxPool2d(kernel_size=4, stride=2),\n",
|
||||
" # nn.Conv2d(18, 36, kernel_size=3,stride=2),\n",
|
||||
" # nn.BatchNorm2d(36),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.Conv2d(36, 72, kernel_size=3,stride=2),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.AvgPool2d(kernel_size=5, stride=3),\n",
|
||||
" # nn.Conv2d(72, 144, kernel_size=3,stride=1),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.Conv2d(144, 288, kernel_size=5,stride=1),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.MaxPool2d(kernel_size=4, stride=1),\n",
|
||||
" # nn.Conv2d(288, 192, kernel_size=3,stride=1),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.Conv2d(192, 192, kernel_size=3,stride=1), # => 1\n",
|
||||
" # nn.ReLU(inplace=True))\n",
|
||||
" # print(\"hi\")\n",
|
||||
" self.conv = models.resnet18(pretrained=new)\n",
|
||||
" \n",
|
||||
" self.classifier = nn.Sequential(nn.Linear(1000, 4096),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Linear(4096,1))\n",
|
||||
" \n",
|
||||
" self.lossfunc = nn.MSELoss()\n",
|
||||
" \n",
|
||||
" self.imageprep = v2.Compose([self.SquarePad(),v2.Resize(512),v2.Grayscale(num_output_channels=3),v2.CenterCrop(512),v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" class SquarePad:\n",
|
||||
" def __call__(self, image):\n",
|
||||
" # print(\"hi type:\", type(image))\n",
|
||||
" temp = image.size()\n",
|
||||
" w = temp[-1]\n",
|
||||
" h = temp[-2]\n",
|
||||
" max_wh = max([w, h])\n",
|
||||
" hp = int((max_wh - w) / 2)\n",
|
||||
" vp = int((max_wh - h) / 2)\n",
|
||||
" padding = (hp, vp, hp, vp)\n",
|
||||
" return tvf.pad(image, padding, 0, 'edge')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" \n",
|
||||
" def forward(self, image):\n",
|
||||
"\n",
|
||||
" transformedimage = self.imageprep(image)\n",
|
||||
" transformedimage = transformedimage.to(self.device)\n",
|
||||
"\n",
|
||||
" if (len(transformedimage.shape) != 4 and len(transformedimage.shape) != 3):\n",
|
||||
" raise Exception(\"Sorry, Dimension of image is incorrect (\" + str(len(transformedimage.shape)) + \"). Expected a 3D (single image) or 4D (batch of images) tensor\")\n",
|
||||
"\n",
|
||||
" if (len(transformedimage.shape) == 3):\n",
|
||||
" x = transformedimage.unsqueeze(0)\n",
|
||||
" else:\n",
|
||||
" x = transformedimage\n",
|
||||
" \n",
|
||||
" x = self.conv(x)\n",
|
||||
" # print(x.shape)\n",
|
||||
" # x = nn.Flatten(start_dim=-1)(x)\n",
|
||||
" # print(x.shape)\n",
|
||||
" x = self.classifier(x)\n",
|
||||
" # print(x.shape)\n",
|
||||
" guessRotation = nn.Flatten(start_dim=0)(x)\n",
|
||||
" \n",
|
||||
" return guessRotation\n",
|
||||
" \n",
|
||||
" def loss(self, guess, trueAnswer):\n",
|
||||
" return self.lossfunc(guess, trueAnswer)\n",
|
||||
" \n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train(model, dataset, batchsize, num_epochs, stepsize, totalnumiters = -1):\n",
|
||||
" device = torch.device(\"cpu\")\n",
|
||||
" if torch.cuda.is_available():\n",
|
||||
" device = torch.device(\"cuda:0\")\n",
|
||||
" model = model.cuda()\n",
|
||||
" optimizer = optim.Adam(model.parameters(), lr=stepsize)\n",
|
||||
" \n",
|
||||
" counter = totalnumiters\n",
|
||||
" model = model.train()\n",
|
||||
" \n",
|
||||
" breakearly = True\n",
|
||||
" if totalnumiters == -1:\n",
|
||||
" print(\"hi\")\n",
|
||||
" breakearly = False\n",
|
||||
" totalnumiters = len(dataset) + 1\n",
|
||||
" \n",
|
||||
" for e in range(num_epochs):\n",
|
||||
" \n",
|
||||
" train_dataloader = DataLoader(dataset, batch_size=batchsize, shuffle=True)\n",
|
||||
" \n",
|
||||
" pbar = tqdm(train_dataloader)\n",
|
||||
" \n",
|
||||
" for i, batch in enumerate(pbar):\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" images, truerotations = batch['image'], batch['rotation']\n",
|
||||
" images = images.to(device)\n",
|
||||
" truerotations = truerotations.to(device)\n",
|
||||
"\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" \n",
|
||||
" guessRotation = model(images)\n",
|
||||
" \n",
|
||||
" truerotations = truerotations.float()\n",
|
||||
" \n",
|
||||
" loss = model.loss(guessRotation, truerotations)\n",
|
||||
" \n",
|
||||
" loss.backward()\n",
|
||||
" \n",
|
||||
" optimizer.step()\n",
|
||||
" counter = counter - batchsize\n",
|
||||
" if counter <= 0 and breakearly:\n",
|
||||
" print(\"endearly\")\n",
|
||||
" return\n",
|
||||
"\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def measure(model, dataset):\n",
|
||||
" total=0\n",
|
||||
" within30=0\n",
|
||||
" within15=0\n",
|
||||
" within10=0\n",
|
||||
" within5=0\n",
|
||||
" within1=0\n",
|
||||
" withintenth=0\n",
|
||||
" model = model.eval()\n",
|
||||
" pbar = tqdm(dataset)\n",
|
||||
" for i, sample in enumerate(pbar):\n",
|
||||
" if (i % 100 == 0):\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" images, truerotations = sample['image'], sample['rotation']\n",
|
||||
" output = model(images)\n",
|
||||
" outputvalue = output.item()\n",
|
||||
" total = total + 1\n",
|
||||
" if (abs(outputvalue - truerotations) < 0.1):\n",
|
||||
" withintenth = withintenth + 1\n",
|
||||
" within1 = within1 + 1\n",
|
||||
" within5 = within5 + 1\n",
|
||||
" within10 = within10 + 1\n",
|
||||
" within15 = within15 + 1\n",
|
||||
" within30 = within30 + 1\n",
|
||||
" elif (abs(outputvalue - truerotations) < 1):\n",
|
||||
" within1 = within1 + 1\n",
|
||||
" within5 = within5 + 1\n",
|
||||
" within10 = within10 + 1\n",
|
||||
" within15 = within15 + 1\n",
|
||||
" within30 = within30 + 1\n",
|
||||
" elif (abs(outputvalue - truerotations) < 5):\n",
|
||||
" within5 = within5 + 1\n",
|
||||
" within10 = within10 + 1\n",
|
||||
" within15 = within15 + 1\n",
|
||||
" within30 = within30 + 1\n",
|
||||
" elif (abs(outputvalue - truerotations) < 10):\n",
|
||||
" within10 = within10 + 1\n",
|
||||
" within15 = within15 + 1\n",
|
||||
" within30 = within30 + 1\n",
|
||||
" elif (abs(outputvalue - truerotations) < 15):\n",
|
||||
" within15 = within15 + 1\n",
|
||||
" within30 = within30 + 1\n",
|
||||
" elif (abs(outputvalue - truerotations) < 30):\n",
|
||||
" within30 = within30 + 1\n",
|
||||
" # print(\"Hi\")\n",
|
||||
" return {\"Within 30 Degrees\": within30/total, \"Within 15 Degrees\": within15/total, \"Within 10 Degrees\": within10/total, \"Within 5 Degrees\": within5/total, \"Within 1 Degree\": within1/total, \"Within 0.1 Degree\": withintenth/total}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
|
||||
" warnings.warn(\n",
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
|
||||
" warnings.warn(msg)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = RotationDeterminer(new=True)\n",
|
||||
"device = torch.device(\"cpu\")\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" device = torch.device(\"cuda:0\")\n",
|
||||
" model = model.to(device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# # used when starting a new model training\n",
|
||||
"# counter = 0\n",
|
||||
"# outputarray = np.array([])\n",
|
||||
"# tempdict = {\"Epochs Done\": counter}\n",
|
||||
"# tempdict.update(measure(model, working_dataset['validation']))\n",
|
||||
"# outputarray = np.append(outputarray, tempdict)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# load values\n",
|
||||
"counter = np.load(savepath + \"/v\"+str(version)+\"/counter.npy\")\n",
|
||||
"model.load_state_dict(torch.load(savepath + \"/v\"+str(version)+\"/modelsave\" + str(counter) +\"epochs\"))\n",
|
||||
"outputarray = np.load(savepath + \"/v\"+str(version)+\"/outputarray.npy\", allow_pickle=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# # used to rollback the model one training loop\n",
|
||||
"# counter = 6\n",
|
||||
"# outputarray = #removed the 7th element, will go from the 6th epoch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"hi\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "3048e1546e12444193f99b15781768d9",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/12800 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "1a36a12b123e4b24bf00a8eeec2e396a",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/12800 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# train\n",
|
||||
"numepochs = 2\n",
|
||||
"batchsize = 25\n",
|
||||
"stepsize = 1e-3\n",
|
||||
"train(model, working_dataset['train'], batchsize, numepochs, stepsize)\n",
|
||||
"# model = model.eval()\n",
|
||||
"# output = model(testimage)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# print(output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "6b99a98f480745c4a375bf1e713708ed",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/40000 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# outputarray = np.append(outputarray, output.detach().cpu().numpy())\n",
|
||||
"counter = numepochs + counter\n",
|
||||
"tempdict = {\"Epochs Done\": counter}\n",
|
||||
"tempdict.update(measure(model, working_dataset['validation']))\n",
|
||||
"outputarray = np.append(outputarray, tempdict)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# save values\n",
|
||||
"torch.save(model.state_dict(), savepath + \"/v\"+str(version)+\"/modelsave\" + str(counter) +\"epochs\")\n",
|
||||
"np.save(savepath + \"/v\"+str(version)+\"/outputarray\", outputarray)\n",
|
||||
"np.save(savepath + \"/v\"+str(version)+\"/counter\", counter)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[{'Epochs Done': 0, 'Within 30 Degrees': 0.162575, 'Within 15 Degrees': 0.080575, 'Within 10 Degrees': 0.053725, 'Within 5 Degrees': 0.027125, 'Within 1 Degree': 0.00545, 'Within 0.1 Degree': 0.00075}\n",
|
||||
" {'Epochs Done': 1, 'Within 30 Degrees': 0.7764, 'Within 15 Degrees': 0.65105, 'Within 10 Degrees': 0.538875, 'Within 5 Degrees': 0.322625, 'Within 1 Degree': 0.070375, 'Within 0.1 Degree': 0.00805}\n",
|
||||
" {'Epochs Done': 5, 'Within 30 Degrees': 0.891675, 'Within 15 Degrees': 0.8042, 'Within 10 Degrees': 0.673275, 'Within 5 Degrees': 0.415725, 'Within 1 Degree': 0.092375, 'Within 0.1 Degree': 0.009275}\n",
|
||||
" {'Epochs Done': 8, 'Within 30 Degrees': 0.928125, 'Within 15 Degrees': 0.881625, 'Within 10 Degrees': 0.7686, 'Within 5 Degrees': 0.4791, 'Within 1 Degree': 0.102925, 'Within 0.1 Degree': 0.009975}\n",
|
||||
" {'Epochs Done': 11, 'Within 30 Degrees': 0.9417, 'Within 15 Degrees': 0.91265, 'Within 10 Degrees': 0.86655, 'Within 5 Degrees': 0.633125, 'Within 1 Degree': 0.14265, 'Within 0.1 Degree': 0.01495}\n",
|
||||
" {'Epochs Done': 13, 'Within 30 Degrees': 0.941575, 'Within 15 Degrees': 0.917375, 'Within 10 Degrees': 0.889125, 'Within 5 Degrees': 0.735525, 'Within 1 Degree': 0.1992, 'Within 0.1 Degree': 0.019875}]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(outputarray)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "d339e6a22ccf4812bdad90dd3d546c68",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/39999 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'Within 30 Degrees': 0.9433985849646241,\n",
|
||||
" 'Within 15 Degrees': 0.9174979374484362,\n",
|
||||
" 'Within 10 Degrees': 0.889422235555889,\n",
|
||||
" 'Within 5 Degrees': 0.737118427960699,\n",
|
||||
" 'Within 1 Degree': 0.1995799894997375,\n",
|
||||
" 'Within 0.1 Degree': 0.020050501262531564}"
|
||||
]
|
||||
},
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"measure(model, working_dataset['test'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# first epoch 25 batchsize, 1e-3 stepsize # GOOD PROGRESS SO FAR\n",
|
||||
"# epoch 2-11 25 batchsize, 1e-3 stepsize"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# model1 = RotationDeterminer(new=False)\n",
|
||||
"# device = torch.device(\"cpu\")\n",
|
||||
"# if torch.cuda.is_available:\n",
|
||||
"# device = torch.device(\"cuda:0\")\n",
|
||||
"# model1 = model1.to(device)\n",
|
||||
"# measurementarray=np.array([])\n",
|
||||
"# for i in range(counter+1):\n",
|
||||
"# print(i)\n",
|
||||
"# if (i == 0 or i == 1 or i == 5 or i == 8 or i == 11):\n",
|
||||
"# tempdict = {\"Epochs Done\": i}\n",
|
||||
"# model1.load_state_dict(torch.load(savepath + \"/v\"+str(version)+\"/modelsave\" + str(i) +\"epochs\"))\n",
|
||||
"# tempdict.update(measure(model1, working_dataset['validation']))\n",
|
||||
"# measurementarray = np.append(measurementarray, tempdict)\n",
|
||||
" \n",
|
||||
"# print(\"hi\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# print(measurementarray)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# np.save(savepath + \"/v\"+str(version)+\"/outputarray\", measurementarray)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# measurementarraycopy = measurementarray\n",
|
||||
"# tempdict = {\"Epochs Done\": 1}\n",
|
||||
"# tempdict.update(measurementarraycopy[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# print(tempdict)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -1,144 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"version=2.0\n",
|
||||
"cachepath=\"../.cache/\"\n",
|
||||
"savepath=\"./savespot/\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
|
||||
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n",
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
|
||||
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as fn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"import torchvision.transforms.functional as tvf\n",
|
||||
"import torchvision.transforms.v2 as v2\n",
|
||||
"import torchvision.models as models\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"import datasets as ds\n",
|
||||
"from tqdm.autonotebook import tqdm\n",
|
||||
"\n",
|
||||
"import random\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"import cv2\n",
|
||||
"import numpy as np\n",
|
||||
"import myfunctions as mf\n",
|
||||
"\n",
|
||||
"torch.cuda.empty_cache()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# array = np.load(\"./testing_space/outputarray.npy\")\n",
|
||||
"# counter = np.load(\"./testing_space/counter.npy\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# print(array)\n",
|
||||
"# print(counter)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"img = cv2.imread('./test_images/IMG_7605.jpg')\n",
|
||||
"# img = mf.ResizeWithAspectRatio(img, 1000)\n",
|
||||
"# img = mf.ResizeWithAspectRatio(mf.SquarePad(fill=255)(img),1000)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"rotatedimg = mf.houghlinedeskewandcrop(img)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# out = mf.morphologyCrop(img)\n",
|
||||
"# out = cv2.cvtColor(out, cv2.COLOR_BGR2GRAY)\n",
|
||||
"# out = cv2.threshold(out, 200, 255, cv2.THRESH_BINARY)[1]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cv2.imshow(\"result1\", rotatedimg)\n",
|
||||
"# cv2.imshow(\"result2\", result2)\n",
|
||||
"cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -1,345 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# ORIGINAL DOCUMENT FOR MORPHOLOGY CROP can maybe be deleted"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cv2\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as fn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"import torchvision.transforms.functional as tvf\n",
|
||||
"import torchvision.transforms.v2 as v2\n",
|
||||
"import torchvision.models as models\n",
|
||||
"import torchvision.transforms as t\n",
|
||||
"\n",
|
||||
"import myfunctions as mf\n",
|
||||
"\n",
|
||||
"from PIL import Image"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# read image (cv2.imread's default flag loads 3-channel BGR; grayscale conversion happens in a later cell)\n",
|
||||
"img = cv2.imread('./test_images/IMG_7640.jpg')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# def ResizeWithAspectRatio(image, width=None, height=None, inter=cv2.INTER_AREA):\n",
|
||||
"# dim = None\n",
|
||||
"# (h, w) = image.shape[:2]\n",
|
||||
"\n",
|
||||
"# if width is None and height is None:\n",
|
||||
"# return image\n",
|
||||
"# if width is None:\n",
|
||||
"# r = height / float(h)\n",
|
||||
"# dim = (int(w * r), height)\n",
|
||||
"# else:\n",
|
||||
"# r = width / float(w)\n",
|
||||
"# dim = (width, int(h * r))\n",
|
||||
"\n",
|
||||
"# return cv2.resize(image, dim, interpolation=inter)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# convert to grayscale\n",
|
||||
"gray = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)\n",
|
||||
"\n",
|
||||
"# threshold\n",
|
||||
"thresh = cv2.threshold(gray, 190, 255, cv2.THRESH_BINARY)[1]\n",
|
||||
"\n",
|
||||
"# apply morphology\n",
|
||||
"kernel = np.ones((7,7), np.uint8)\n",
|
||||
"morph = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)\n",
|
||||
"kernel = np.ones((9,9), np.uint8)\n",
|
||||
"morph = cv2.morphologyEx(morph, cv2.MORPH_ERODE, kernel)\n",
|
||||
"\n",
|
||||
"# get largest contour\n",
|
||||
"contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)\n",
|
||||
"contours = contours[0] if len(contours) == 2 else contours[1]\n",
|
||||
"area_thresh = 0\n",
|
||||
"for c in contours:\n",
|
||||
" area = cv2.contourArea(c)\n",
|
||||
" if area > area_thresh:\n",
|
||||
" area_thresh = area\n",
|
||||
" big_contour = c\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# get bounding box\n",
|
||||
"x,y,w,h = cv2.boundingRect(big_contour)\n",
|
||||
"\n",
|
||||
"# draw filled contour on black background\n",
|
||||
"mask = np.zeros_like(gray)\n",
|
||||
"mask = cv2.merge([mask,mask,mask])\n",
|
||||
"cv2.drawContours(mask, [big_contour], -1, (255,255,255), cv2.FILLED)\n",
|
||||
"\n",
|
||||
"# apply mask to input\n",
|
||||
"result1 = img.copy()\n",
|
||||
"result1 = cv2.bitwise_and(result1, mask)\n",
|
||||
"\n",
|
||||
"# crop result\n",
|
||||
"result2 = result1[y:y+h, x:x+w]\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# view result\n",
|
||||
"# cv2.imshow(\"threshold\", thresh)\n",
|
||||
"# cv2.imshow(\"morph\", morph)\n",
|
||||
"# cv2.imshow(\"mask\", mask)\n",
|
||||
"# cv2.imshow(\"result1\", result1)\n",
|
||||
"resizedresult2 = mf.ResizeWithAspectRatio(result2, 1000)\n",
|
||||
"cv2.imwrite(\"./testing_space/cropped1.jpg\", resizedresult2)\n",
|
||||
"cv2.imshow(\"result2\", resizedresult2)\n",
|
||||
"cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class RotationDeterminer(nn.Module):\n",
|
||||
" def __init__(self, new=False):\n",
|
||||
" super(RotationDeterminer,self).__init__()\n",
|
||||
" \n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" \n",
|
||||
" self.device = torch.device(\"cpu\")\n",
|
||||
" if torch.cuda.is_available:\n",
|
||||
" self.device = torch.device(\"cuda:0\")\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" self.appliers = [v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
|
||||
" v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0)], p=0.25), # maybe add fill=appliedFill\n",
|
||||
" v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5,9), sigma=(0.1,2.))],p=0.25),\n",
|
||||
" v2.RandomApply(transforms=[v2.RandomEqualize()],p=0.25)]\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # self.conv = nn.Sequential(nn.Conv2d(3, 9, kernel_size=11,stride=3), # 1100 x 1100 => 201 x 201\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.Conv2d(9, 18, kernel_size=5,stride=1),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.MaxPool2d(kernel_size=4, stride=2),\n",
|
||||
" # nn.Conv2d(18, 36, kernel_size=3,stride=2),\n",
|
||||
" # nn.BatchNorm2d(36),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.Conv2d(36, 72, kernel_size=3,stride=2),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.AvgPool2d(kernel_size=5, stride=3),\n",
|
||||
" # nn.Conv2d(72, 144, kernel_size=3,stride=1),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.Conv2d(144, 288, kernel_size=5,stride=1),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.MaxPool2d(kernel_size=4, stride=1),\n",
|
||||
" # nn.Conv2d(288, 192, kernel_size=3,stride=1),\n",
|
||||
" # nn.ReLU(inplace=True),\n",
|
||||
" # nn.Conv2d(192, 192, kernel_size=3,stride=1), # => 1\n",
|
||||
" # nn.ReLU(inplace=True))\n",
|
||||
" # print(\"hi\")\n",
|
||||
" self.conv = models.resnet18(pretrained=new)\n",
|
||||
" \n",
|
||||
" self.classifier = nn.Sequential(nn.Linear(1000, 4096),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Linear(4096,1))\n",
|
||||
" \n",
|
||||
" self.lossfunc = nn.MSELoss()\n",
|
||||
" \n",
|
||||
" self.imageprep = v2.Compose([self.SquarePad(),v2.Resize(512),v2.Grayscale(num_output_channels=3),v2.CenterCrop(512),v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" class SquarePad:\n",
|
||||
" def __call__(self, image):\n",
|
||||
" # print(\"hi type:\", type(image))\n",
|
||||
" temp = image.size()\n",
|
||||
" w = temp[-2]\n",
|
||||
" h = temp[-1]\n",
|
||||
" max_wh = max([w, h])\n",
|
||||
" hp = int((max_wh - w) / 2)\n",
|
||||
" vp = int((max_wh - h) / 2)\n",
|
||||
" padding = (hp, vp, hp, vp)\n",
|
||||
" return tvf.pad(image, padding, 0, 'edge')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" \n",
|
||||
" def forward(self, image):\n",
|
||||
"\n",
|
||||
" transformedimage = self.imageprep(image)\n",
|
||||
" transformedimage = transformedimage.to(self.device)\n",
|
||||
"\n",
|
||||
" if (len(transformedimage.shape) != 4 and len(transformedimage.shape) != 3):\n",
|
||||
" raise Exception(\"Sorry, Dimension of image is incorrect (\", len(transformedimage.shape),\"). Expected a 3D (single image) or 4D (batch of images) tensor\")\n",
|
||||
"\n",
|
||||
" if (len(transformedimage.shape) == 3):\n",
|
||||
" x = transformedimage.unsqueeze(0)\n",
|
||||
" else:\n",
|
||||
" x = transformedimage\n",
|
||||
" \n",
|
||||
" x = self.conv(x)\n",
|
||||
" # print(x.shape)\n",
|
||||
" # x = nn.Flatten(start_dim=-1)(x)\n",
|
||||
" # print(x.shape)\n",
|
||||
" x = self.classifier(x)\n",
|
||||
" # print(x.shape)\n",
|
||||
" guessRotation = nn.Flatten(start_dim=0)(x)\n",
|
||||
" \n",
|
||||
" return guessRotation\n",
|
||||
" \n",
|
||||
" def loss(self, guess, trueAnswer):\n",
|
||||
" return self.lossfunc(guess, trueAnswer)\n",
|
||||
" \n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
|
||||
" warnings.warn(\n",
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
|
||||
" warnings.warn(msg)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = RotationDeterminer(new=True)\n",
|
||||
"device = torch.device(\"cpu\")\n",
|
||||
"if torch.cuda.is_available:\n",
|
||||
" device = torch.device(\"cuda:0\")\n",
|
||||
" model = model.to(device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([1, 1174, 1000])\n",
|
||||
"torch.Size([3, 1174, 1000])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"-0.1470905989408493\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tensorize = v2.Compose([v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
|
||||
"grayscaler = v2.Grayscale(num_output_channels=3)\n",
|
||||
"\n",
|
||||
"imagetobeprocessed = cv2.cvtColor(resizedresult2,cv2.COLOR_BGR2GRAY)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tensorizedimage = torch.unsqueeze(torch.from_numpy(imagetobeprocessed),0)\n",
|
||||
"print(tensorizedimage.shape)\n",
|
||||
"adjustedtensorizedimage = tensorize(grayscaler(t.ToPILImage()(tensorizedimage)))\n",
|
||||
"print(adjustedtensorizedimage.shape)\n",
|
||||
"rotation = model(adjustedtensorizedimage).item()\n",
|
||||
"print(rotation)\n",
|
||||
"rotatedimage = t.Resize(size=1000)(tvf.rotate(adjustedtensorizedimage, rotation))\n",
|
||||
"# imS = mf.ResizeWithAspectRatio(filereadimage, 1000)\n",
|
||||
"# imS = cv2.resize(filereadimage, (960, 540)) \n",
|
||||
"open_cv_image = np.array(t.ToPILImage()(rotatedimage))\n",
|
||||
"cv2.imshow(f'image', open_cv_image)\n",
|
||||
"key = cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# # save result\n",
|
||||
"# cv2.imwrite(\"paper_thresh.jpg\", thresh)\n",
|
||||
"# cv2.imwrite(\"paper_morph.jpg\", morph)\n",
|
||||
"# cv2.imwrite(\"paper_mask.jpg\", mask)\n",
|
||||
"# cv2.imwrite(\"paper_result1.jpg\", result1)\n",
|
||||
"# cv2.imwrite(\"paper_result2.jpg\", result2)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -1,387 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 772,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cv2\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"import myfunctions as mf\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"import scipy.stats as st\n",
|
||||
"import math"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 773,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# read image (cv2.imread's default flag loads 3-channel BGR; grayscale conversion happens in a later cell)\n",
|
||||
"img = cv2.imread('./test_images/IMG_7605.jpg')\n",
|
||||
"# img = mf.ResizeWithAspectRatio(img,1000)\n",
|
||||
"# img = mf.rotate(img, 54)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 774,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prepped = mf.ResizeWithAspectRatio(mf.SquarePad(fill=255)(img),1000)\n",
|
||||
"prepped = mf.premorphCrop(prepped)\n",
|
||||
"prepped = mf.ResizeWithAspectRatio(mf.SquarePad(fill=255)(prepped),1000)\n",
|
||||
"# kernel = np.ones((5,5), np.uint8)\n",
|
||||
"# prepped = cv2.dilate(prepped, kernel, iterations=1)\n",
|
||||
"gray1 = cv2.cvtColor(prepped, cv2.COLOR_BGR2GRAY)\n",
|
||||
"dst1 = cv2.Canny(gray1, 0, 500, None, 3)\n",
|
||||
"\n",
|
||||
"kernel = np.ones((5,5), np.uint8)\n",
|
||||
"out = cv2.morphologyEx(dst1, cv2.MORPH_DILATE, kernel)\n",
|
||||
"out = cv2.blur(out, (5,5))\n",
|
||||
"kernel = np.ones((6,6), np.uint8)\n",
|
||||
"dst1 = cv2.morphologyEx(out, cv2.MORPH_ERODE, kernel)\n",
|
||||
"\n",
|
||||
"dst1 = cv2.Canny(dst1, 0, 500, None, 3)\n",
|
||||
"\n",
|
||||
"cdstP = prepped.copy()\n",
|
||||
"cdstPmargin = cdstP.copy()\n",
|
||||
"basecdstP = cdstP.copy()\n",
|
||||
"linesP = cv2.HoughLinesP(dst1, 1, np.pi / 180, 30, None, 90, 30)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 779,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# # testing = dst1.copy()\n",
|
||||
"# # kernel = np.ones((5,5), np.uint8)\n",
|
||||
"# # out = cv2.morphologyEx(testing, cv2.MORPH_DILATE, kernel)\n",
|
||||
"# # out = cv2.blur(out, (5,5))\n",
|
||||
"# # kernel = np.ones((3,3), np.uint8)\n",
|
||||
"# # out = cv2.morphologyEx(out, cv2.MORPH_ERODE, kernel)\n",
|
||||
"cv2.imshow(\"result1\", dst1)\n",
|
||||
"cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 758,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"angles = np.zeros(len(linesP))\n",
|
||||
"if linesP is not None:\n",
|
||||
" for i in range(0, len(linesP)):\n",
|
||||
" l = linesP[i][0]\n",
|
||||
" angles[i] = mf.lineAngle(l)\n",
|
||||
" cv2.line(cdstP, (l[0], l[1]), (l[2], l[3]), (0,0,255), 3, cv2.LINE_AA)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 759,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# cv2.imshow(\"result1\", cdstP)\n",
|
||||
"# cv2.waitKey(0)\n",
|
||||
"# cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 760,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"-3.093972093706445\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"mode = st.mode(np.around(angles, decimals=3))[0]\n",
|
||||
"rotationangle = np.rad2deg(mode)\n",
|
||||
"print(rotationangle)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 761,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"rotatedcdstP = mf.rotate(basecdstP, rotationangle)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 762,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vmarginlines = mf.WithinXDegrees(linesP, 7, baseangle=rotationangle)\n",
|
||||
"hmarginlines = mf.WithinXDegrees(linesP, 7, baseangle=90+rotationangle)\n",
|
||||
"vrect = mf.lineBoundingRect(vmarginlines,asRect=False, returnint=True)\n",
|
||||
"hmarginlines = mf.lineswithinrange(hmarginlines, (vrect[0], vrect[1]), (vrect[2],vrect[3]), x=True, y=False)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"if (hmarginlines != []):\n",
|
||||
" marginlines = np.append(vmarginlines, hmarginlines, axis=0)\n",
|
||||
"else:\n",
|
||||
" marginlines = vmarginlines\n",
|
||||
" \n",
|
||||
"rect = mf.lineBoundingRect(marginlines,asRect=False, returnint=True)\n",
|
||||
"cdstP = cv2.rectangle(cdstP, (rect[0],rect[1]), (rect[2],rect[3]), (0,255,0), 3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 763,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cv2.imshow(\"result1\", cdstP)\n",
|
||||
"cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 764,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#####NEED TO WORK ON SCORING THE LINES SO IT PICKS THE CORRECT ORIENTATION (horizontal vs vertical) AND SO THAT THE CROPPING RECTANGLE MOVES/GET TRANSFORMED WITH IT"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 780,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def rotatePoint(img, pt, angle, returnint=True):\n",
|
||||
" rotateaxisx = img.shape[0]/2\n",
|
||||
" rotateaxisy = img.shape[1]/2\n",
|
||||
" tempx = pt[0] - rotateaxisx\n",
|
||||
" tempy = pt[1] - rotateaxisy\n",
|
||||
" rotatedx = tempx*math.cos(np.deg2rad(-angle)) - tempy*math.sin(np.deg2rad(-angle))\n",
|
||||
" rotatedy = tempx*math.sin(np.deg2rad(-angle)) + tempy*math.cos(np.deg2rad(-angle))\n",
|
||||
" finalx = rotatedx + rotateaxisx\n",
|
||||
" finaly = rotatedy + rotateaxisy\n",
|
||||
" if (returnint):\n",
|
||||
" finalx = int(finalx)\n",
|
||||
" finaly = int(finaly)\n",
|
||||
" return (finalx, finaly)\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 766,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def rotateRect(img, rect, angle, returnint=True, asRect=False):\n",
|
||||
" if (asRect):\n",
|
||||
" pt1 = rotatePoint(img, (rect[0],rect[1]), angle, returnint)\n",
|
||||
" pt2 = rotatePoint(img, (rect[0]+rect[2],rect[1]+rect[3]), angle, returnint)\n",
|
||||
" return (pt1[0], pt1[1], pt2[0]-pt1[0], pt2[1]-pt1[1])\n",
|
||||
" else:\n",
|
||||
" pt1 = rotatePoint(img, (rect[0],rect[1]), angle, returnint)\n",
|
||||
" pt2 = rotatePoint(img, (rect[2],rect[3]), angle, returnint)\n",
|
||||
" return (pt1[0], pt1[1], pt2[0], pt2[1])\n",
|
||||
"\n",
|
||||
"def rotateLine(img, line, angle, returnint=True):\n",
|
||||
" pt1 = rotatePoint(img, (line[0],line[1]), angle, returnint)\n",
|
||||
" pt2 = rotatePoint(img, (line[2],line[3]), angle, returnint)\n",
|
||||
" return (pt1[0], pt1[1], pt2[0], pt2[1])\n",
|
||||
" \n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 767,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# print(linesP.shape)\n",
|
||||
"rotatedlines = [rotateLine(rotatedcdstP, line[0], rotationangle) for line in linesP]\n",
|
||||
"rotatedlines = np.reshape(rotatedlines, (len(rotatedlines),1,4))\n",
|
||||
"# rotatedlines = linesP\n",
|
||||
"# print(rotatedlines.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 768,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vmarginlines = mf.WithinXDegrees(rotatedlines, 7)\n",
|
||||
"hmarginlines = mf.WithinXDegrees(rotatedlines, 7, baseangle=90)\n",
|
||||
"vrect = mf.lineBoundingRect(vmarginlines,asRect=False, returnint=True)\n",
|
||||
"hmarginlines = mf.lineswithinrange(hmarginlines, (vrect[0], vrect[1]), (vrect[2],vrect[3]), x=True, y=False)\n",
|
||||
"\n",
|
||||
"if (hmarginlines != []):\n",
|
||||
" marginlines = np.append(vmarginlines, hmarginlines, axis=0)\n",
|
||||
"else:\n",
|
||||
" marginlines = vmarginlines\n",
|
||||
" \n",
|
||||
"rect = mf.lineBoundingRect(marginlines,asRect=False, returnint=True)\n",
|
||||
"# rect = vrect\n",
|
||||
"rotatedcdstP = cv2.rectangle(rotatedcdstP, (rect[0],rect[1]), (rect[2],rect[3]), (0,255,0), 3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 769,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if rotatedlines is not None:\n",
|
||||
" for i in range(0, len(rotatedlines)):\n",
|
||||
" l = rotatedlines[i][0]\n",
|
||||
" cv2.line(rotatedcdstP, (l[0], l[1]), (l[2], l[3]), (0,0,255), 3, cv2.LINE_AA)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 771,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cv2.imshow(\"result1\", rotatedcdstP)\n",
|
||||
"# cv2.imshow(\"result1\", cdstP)\n",
|
||||
"cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 394,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vmarginlines = mf.WithinXDegrees(linesP, 7)\n",
|
||||
"hmarginlines = mf.WithinXDegrees(linesP, 7, baseangle=90)\n",
|
||||
"vrect = mf.lineBoundingRect(vmarginlines,asRect=False, returnint=True)\n",
|
||||
"hmarginlines = mf.lineswithinrange(hmarginlines, (vrect[0], vrect[1]), (vrect[2],vrect[3]), x=True, y=False)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"if (hmarginlines != []):\n",
|
||||
" marginlines = np.append(vmarginlines, hmarginlines, axis=0)\n",
|
||||
"else:\n",
|
||||
" marginlines = vmarginlines\n",
|
||||
"\n",
|
||||
"rect = mf.lineBoundingRect(marginlines,asRect=False, returnint=True)\n",
|
||||
"cdstP = cv2.rectangle(cdstP, (rect[0],rect[1]), (rect[2],rect[3]), (0,255,0), 3)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# rotatedrect = rotateRect(cdstP, rect, -rotationangle)\n",
|
||||
"\n",
|
||||
"# rotatedcdstP = cv2.rectangle(rotatedcdstP, (rotatedrect[0],rotatedrect[1]), (rotatedrect[2],rotatedrect[3]), (0,255,0), 3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 395,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"###figure out how to rotate rectangle"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 396,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cv2.imshow(\"result1\", cdstP)\n",
|
||||
"cv2.waitKey(0)\n",
|
||||
"cv2.destroyAllWindows()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 397,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# vmarginlines = mf.WithinXDegrees(linesP, 7)\n",
|
||||
"# hmarginlines = mf.WithinXDegrees(linesP, 7, baseangle=90)\n",
|
||||
"# vrect = mf.lineBoundingRect(vmarginlines,asRect=False, returnint=True)\n",
|
||||
"# hmarginlines = mf.lineswithinrange(hmarginlines, (vrect[0], vrect[1]), (vrect[2],vrect[3]), x=True, y=False)\n",
|
||||
"# # print(hmarginlines)\n",
|
||||
"# if (hmarginlines != []):\n",
|
||||
"# marginlines = np.append(vmarginlines, hmarginlines, axis=0)\n",
|
||||
"# else:\n",
|
||||
"# marginlines = vmarginlines\n",
|
||||
"\n",
|
||||
"# # print(marginlines)\n",
|
||||
"# rect = mf.lineBoundingRect(marginlines,asRect=False, returnint=True)\n",
|
||||
"# # print(rect)\n",
|
||||
"# cdstP = cv2.rectangle(cdstP, (rect[0],rect[1]), (rect[2],rect[3]), (0,255,0), 3)\n",
|
||||
"# # print(cdstP.shape)\n",
|
||||
"# # cropped = cdstP[rect[1]:rect[3], rect[0]:rect[2],:]\n",
|
||||
"\n",
|
||||
"# if marginlines is not None:\n",
|
||||
"# for i in range(0, len(marginlines)):\n",
|
||||
"# l = marginlines[i]\n",
|
||||
"# cv2.line(cdstP, (int(l[0]), int(l[1])), (int(l[2]), int(l[3])), (255,0,0), 3, cv2.LINE_AA)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 398,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# # view result\n",
|
||||
"# # cv2.imshow(\"threshold\", thresh)\n",
|
||||
"# # cv2.imshow(\"morph\", morph)\n",
|
||||
"# # cv2.imshow(\"mask\", mask)\n",
|
||||
"# cv2.imshow(\"result1\", mf.ResizeWithAspectRatio(cdstP,height=1000))\n",
|
||||
"# # cv2.imshow(\"result2\", cropped)\n",
|
||||
"# cv2.waitKey(0)\n",
|
||||
"# cv2.destroyAllWindows()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -1,165 +0,0 @@
|
||||
#include "cropper.h"
|
||||
|
||||
#include <opencv2/ximgproc/segmentation.hpp>
|
||||
#include <opencv2/core.hpp>
|
||||
#include <opencv2/imgproc.hpp>
|
||||
|
||||
|
||||
using namespace cv::ximgproc::segmentation;
|
||||
|
||||
|
||||
|
||||
// Returns the top-left corner (x, y) of `rect`.
inline cv::Point topLeft(cv::Rect rect) {
    return rect.tl();
}
|
||||
|
||||
// Returns the bottom-left corner (x, y + height) of `rect`.
inline cv::Point bottomLeft(cv::Rect rect) {
    const int bottom = rect.y + rect.height;
    return cv::Point(rect.x, bottom);
}
|
||||
|
||||
// Returns the top-right corner (x + width, y) of `rect`.
inline cv::Point topRight(cv::Rect rect) {
    const int right = rect.x + rect.width;
    return cv::Point(right, rect.y);
}
|
||||
|
||||
// Returns the bottom-right corner (x + width, y + height) of `rect`.
inline cv::Point bottomRight(cv::Rect rect) {
    return rect.br();
}
|
||||
|
||||
// Euclidean distance between two integer pixel coordinates.
inline double distanceBetweenPoints(cv::Point p1, cv::Point p2) {
    const double dx = p1.x - p2.x;
    const double dy = p1.y - p2.y;
    return std::sqrt(dx * dx + dy * dy);
}
|
||||
|
||||
// Rescales `r` in place from the coordinate system of an image that is
// `currentheight` pixels tall to the same image at `originalheight` pixels tall.
//
// The ratio is computed in floating point: the previous integer division
// truncated it (1000 / 800 became 1, and any down-scale became 0), which left
// the rectangle unscaled or collapsed it to nothing.
//
// The two opposite corners are scaled and rounded, and the size is derived
// from them, so the right/bottom edges do not drift from independent rounding
// of position and size.
//
// If `currentheight` is not positive there is no meaningful ratio and `r` is
// left untouched.
inline void scaleRect(cv::Rect& r, int originalheight, int currentheight) {
    if (currentheight <= 0) {
        return;
    }

    const double scalingFactor = static_cast<double>(originalheight) / currentheight;

    const int left = cvRound(r.x * scalingFactor);
    const int top = cvRound(r.y * scalingFactor);
    const int right = cvRound((r.x + r.width) * scalingFactor);
    const int bottom = cvRound((r.y + r.height) * scalingFactor);

    r.x = left;
    r.y = top;
    r.width = right - left;
    r.height = bottom - top;
}
|
||||
|
||||
|
||||
// uses the L2 loss of the corners of the rectangles
|
||||
double MSELossRect(cv::Rect r1, cv::Rect r2) {
|
||||
return (distanceBetweenPoints(topLeft(r1), topLeft(r2)) +
|
||||
distanceBetweenPoints(bottomLeft(r1), bottomLeft(r2)) +
|
||||
distanceBetweenPoints(topRight(r1), topRight(r2)) +
|
||||
distanceBetweenPoints(bottomRight(r1), bottomRight(r2))) / 4.0;
|
||||
}
|
||||
|
||||
// Runs OpenCV selective-search segmentation on `src` and returns the region
// proposals it produces.
//
//   src         - image handed to the segmenter as its base image.
//   fast        - true selects the "fast" search strategy, false the "quality" one.
//   imageHeight - accepted for interface compatibility; not used by this function.
//
// Side effect: enables OpenCV optimised code paths and sets the OpenCV thread
// count to 4 for the whole process.
std::vector<cv::Rect> selectiveSearchSegmentationActor(cv::InputArray src, bool fast = true, int imageHeight = 800) {
    cv::setUseOptimized(true);
    cv::setNumThreads(4);

    const cv::Mat baseImage = src.getMat();

    auto searcher = cv::ximgproc::segmentation::createSelectiveSearchSegmentation();
    searcher->setBaseImage(baseImage);

    if (!fast) {
        searcher->switchToSelectiveSearchQuality();
    } else {
        searcher->switchToSelectiveSearchFast();
    }

    std::vector<cv::Rect> proposals;
    searcher->process(proposals);
    return proposals;
}
|
||||
|
||||
// Limits `n` to the interval [lower, upper]: values above `upper` become
// `upper`, values below `lower` become `lower`.
inline double clip(double n, double lower, double upper) {
    const double capped = std::min(n, upper);
    return std::max(lower, capped);
}
|
||||
|
||||
inline double colourscaler(double n, double min, double max) {
|
||||
double temp = n - min;
|
||||
double diff = std::abs(max - min);
|
||||
return clip((temp / diff) * 255, 0, 255);
|
||||
};
|
||||
|
||||
|
||||
// Finds the bounding rectangle of the longest Canny edge contour in `src`.
//
//   src                    - BGR image; expected to be 8-bit (the intensity
//                            remapping below is a 256-entry lookup).
//   lower, upper           - gray-level window stretched to [0, 255] before
//                            edge detection; clamped to [0, 255].
//   threshold1, threshold2 - hysteresis thresholds passed to cv::Canny.
//
// Returns an empty cv::Rect when no contour is found. Previously this case
// indexed into an empty vector, which is undefined behaviour.
cv::Rect cannyEdgeRectangle(cv::InputArray src, int lower = 100, int upper = 255, double threshold1 = 50, double threshold2 = 350) {
    lower = std::max(lower, 0);
    upper = std::min(upper, 255);

    cv::Mat gray;
    cv::cvtColor(src, gray, cv::COLOR_BGR2GRAY);

    // The contrast stretch depends only on the pixel value, so compute it once
    // per possible intensity instead of once per pixel.
    cv::Mat stretchTable(1, 256, CV_8UC1);
    for (int value = 0; value < 256; ++value) {
        stretchTable.at<uchar>(0, value) =
            cv::saturate_cast<uchar>(colourscaler(value, lower, upper));
    }

    cv::Mat scaled_gray;
    cv::LUT(gray, stretchTable, scaled_gray);

    cv::Mat blurred;
    cv::Mat edged;
    cv::GaussianBlur(scaled_gray, blurred, cv::Size(15, 15), 0);
    cv::Canny(blurred, edged, threshold1, threshold2);

    std::vector<std::vector<cv::Point>> contours;
    std::vector<cv::Vec4i> hierarchy;
    cv::findContours(edged, contours, hierarchy, cv::RETR_TREE, cv::CHAIN_APPROX_SIMPLE);

    // Pick the contour with the greatest open arc length. A single pass
    // replaces the full sort and measures each contour exactly once.
    const std::vector<cv::Point>* longest = nullptr;
    double longestLength = -1.0;
    for (const std::vector<cv::Point>& contour : contours) {
        const double length = cv::arcLength(contour, false);
        if (length > longestLength) {
            longestLength = length;
            longest = &contour;
        }
    }

    if (longest == nullptr) {
        return cv::Rect();  // no edges detected
    }
    return cv::boundingRect(*longest);
}
|
||||
|
||||
// Crops `src` to its detected content region and writes the result to `dst`.
//
// The image is resized to `imageHeight` rows for analysis. Two candidates are
// computed on the resized copy: the bounding box of the longest Canny contour,
// and the selective-search proposal whose corners lie closest to that box.
// The larger of the two (by area) is scaled back to the original resolution
// and used as the crop window.
//
//   src         - input image (BGR).
//   dst         - receives the cropped region at the original resolution.
//   fastsearch  - use the fast selective-search strategy instead of the quality one.
//   imageHeight - height in pixels of the working copy used for analysis.
//
// Returns true on success. Returns false, leaving `dst` untouched, when the
// input is empty, `imageHeight` is not positive, or the resulting crop window
// is empty. Most of these cases previously divided by zero or indexed out of
// range.
bool crop(cv::InputArray src, cv::OutputArray dst, bool fastsearch, int imageHeight) { //add other params or maybe overload or something
    const cv::Mat original = src.getMat();
    if (original.empty() || imageHeight <= 0) {
        return false;
    }

    // Keep the aspect ratio; never let the working width collapse to zero.
    const int newWidth = std::max(1, original.cols * imageHeight / original.rows);
    cv::Mat working;
    cv::resize(original, working, cv::Size(newWidth, imageHeight));

    const cv::Rect cannyRect = cannyEdgeRectangle(working, 100, 255, 255 / 4, 255);
    const std::vector<cv::Rect> rects = selectiveSearchSegmentationActor(working, fastsearch);

    // Default to the Canny box; selective search may return no proposals.
    cv::Rect finalRect = cannyRect;
    if (!rects.empty()) {
        std::size_t indexOfMin = 0;
        double currentMin = std::numeric_limits<double>::max();
        for (std::size_t i = 0; i < rects.size(); ++i) {
            const double loss = MSELossRect(rects[i], cannyRect);
            if (loss < currentMin) {
                currentMin = loss;
                indexOfMin = i;
            }
        }

        const cv::Rect& goodRect = rects[indexOfMin];
        if (goodRect.area() > cannyRect.area()) {
            finalRect = goodRect;
        }
    }

    // Map the window back to full resolution, then clamp it to the image so
    // that taking the ROI cannot throw on an out-of-bounds rectangle.
    scaleRect(finalRect, original.rows, working.rows);
    finalRect &= cv::Rect(0, 0, original.cols, original.rows);
    if (finalRect.empty()) {
        return false;
    }

    original(finalRect).copyTo(dst);
    return true;
}
|
||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
|
Before Width: | Height: | Size: 3.0 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 2.1 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 5.3 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 2.3 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 139 KiB |
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,62 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.22)

project(imagemanipulation_libraries
  VERSION 0.1
  DESCRIPTION "Libraries for image preprocessing"
  LANGUAGES CXX)

include(GNUInstallDirs)

find_package(OpenCV REQUIRED)

# imagemanipulation_add_shared_library(<name> <source>)
#
# Defines the shared library target <name> built from the single source file
# <source> (relative to this directory) with the settings common to every
# library in this project:
#   - compiled as C++20,
#   - output written to <this directory>/lib,
#   - linked against OpenCV,
#   - include/ and the OpenCV headers on the include path.
#
# Example:
#   imagemanipulation_add_shared_library(rect src/rectangle.cpp)
function(imagemanipulation_add_shared_library name source)
  add_library(${name} SHARED "${source}")

  target_compile_features(${name} PRIVATE cxx_std_20)

  # VERSION is deliberately not set on the target: git can't deal with the
  # symlinks that a versioned shared library produces, for some reason.
  # set_target_properties(${name} PROPERTIES VERSION ${PROJECT_VERSION})
  # set_target_properties(${name} PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/${name}_lib.h)

  # The library is placed in the source tree's lib/ directory.
  set_target_properties(${name} PROPERTIES
    LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/lib")

  # Explicit PRIVATE keyword instead of the legacy keyword-less signature;
  # no other target in this file links against these libraries.
  target_link_libraries(${name} PRIVATE ${OpenCV_LIBS})

  target_include_directories(${name}
    PRIVATE
      "${CMAKE_CURRENT_SOURCE_DIR}/include"
      ${OpenCV_INCLUDE_DIRS})
endfunction()

# RECTANGLE
imagemanipulation_add_shared_library(rect src/rectangle.cpp)

# LINE
imagemanipulation_add_shared_library(line src/line.cpp)

# Disabled install rule, kept for reference:
# install(TARGETS rect
#         LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
#         PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})

# Disabled leftover for a CropperEx target that is not defined in this file:
# find_package(OpenCV REQUIRED)
# target_include_directories(CropperEx
#         PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include
#         PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../externallibraries/stbimagehelpers
#         PRIVATE ${OpenCV_INCLUDE_DIRS})
|
||||
@ -1,20 +0,0 @@
|
||||
#ifndef LINE_H
|
||||
#define LINE_H
|
||||
|
||||
|
||||
// Line: placeholder type. It currently declares no data members and no
// member functions; the access sections below are empty and only mark where
// future members are expected to go.
class Line {
public:
    // (no public members yet)

private:
    // (no private members yet)
};
|
||||
|
||||
|
||||
#endif //LINE_H
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user