Merge pull request 'WIP backend' (#27) from backend into main

Reviewed-on: #27
This commit is contained in:
Ethan Wellenreiter 2025-05-06 18:40:45 -04:00
commit 971914442c
152 changed files with 2432 additions and 23375 deletions

52
backend/.air.toml Normal file
View File

@ -0,0 +1,52 @@
root = "."
testdata_dir = "testdata"
tmp_dir = "bin"
[build]
args_bin = []
bin = "./bin/main.exe"
cmd = "go build -o ./bin/main.exe ./cmd/api"
delay = 1000
exclude_dir = ["assets", "bin", "vendor", "testdata", "docs", "scripts"]
exclude_file = []
exclude_regex = ["_test.go"]
exclude_unchanged = false
follow_symlink = false
full_bin = ""
include_dir = []
include_ext = ["go", "tpl", "tmpl", "html"]
include_file = []
kill_delay = "0s"
log = "build-errors.log"
poll = false
poll_interval = 0
post_cmd = []
pre_cmd = []
rerun = false
rerun_delay = 500
send_interrupt = false
stop_on_error = false
[color]
app = ""
build = "yellow"
main = "magenta"
runner = "green"
watcher = "cyan"
[log]
main_only = false
silent = false
time = false
[misc]
clean_on_exit = false
[proxy]
app_port = 0
enabled = false
proxy_port = 0
[screen]
clear_on_rebuild = false
keep_scroll = true

View File

@ -0,0 +1 @@
exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1

BIN
backend/bin/main.exe Normal file

Binary file not shown.

View File

@ -0,0 +1,14 @@
package main
// init currently performs no package-level setup; it is an empty placeholder.
func init() {
}
// transformImage is a placeholder: its body contains only a commented-out
// call to imgtransform.ResizeImage, so it currently does nothing.
func transformImage() {
// imgtransform.ResizeImage(,10, 10)
}
// main is an empty entry point; running this command currently does nothing.
func main() {
}

View File

@ -0,0 +1 @@
Will use MinIO or S3 Object Lambda transform functions. These can serve thumbnails or other resized versions of an image, so that the full-size object does not have to be downloaded from S3 every time.

View File

@ -0,0 +1,2 @@
package main

266
backend/cmd/api/api.go Normal file
View File

@ -0,0 +1,266 @@
package main
import (
"context"
"errors"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/httprate"
// "git.ewellenr.ca/receipt_indexer/backend/internal/auth"
"git.ewellenr.ca/receipt_indexer/backend/internal/env"
"git.ewellenr.ca/receipt_indexer/backend/internal/logger"
"git.ewellenr.ca/receipt_indexer/backend/internal/ratelimiter"
"git.ewellenr.ca/receipt_indexer/backend/internal/storage"
auth_storage "git.ewellenr.ca/receipt_indexer/backend/internal/storage/auth"
"git.ewellenr.ca/receipt_indexer/backend/internal/storage/cache"
)
// application bundles the dependencies shared by the HTTP handlers and
// middleware defined in this package.
type application struct {
// config holds the listen address plus the rate-limiter and Redis settings.
config config
// auth is the auth storage; handlers reach users, sessions and roles through it.
auth auth_storage.AuthStorage
// store is the primary storage; receipts and images are read through it.
store storage.Storage
// logger is used for request/error logging and server lifecycle messages.
logger logger.Logger
// cacheStorage is the cache layer used by getUser, getReceipt and getImage
// when config.redisCfg.enabled is true.
cacheStorage cache.Storage
// rateLimiter is the custom limiter used by RateLimiterMiddleware.
rateLimiter ratelimiter.Limiter
// environment gives access to environment settings; it is not referenced
// by any code visible in this file.
environment env.Environment
}
// config holds the runtime settings for the API server.
type config struct {
// addr is the listen address handed to http.Server (main uses ":8080").
addr string
// rateLimiter configures request rate limiting; mount reads its Enabled,
// RequestsPerTimeFrame and TimeFrame fields.
rateLimiter ratelimiter.Config
// redisCfg holds the Redis cache settings.
redisCfg redisConfig
// TODO: further settings (store, authenticator, ...) are expected to live here.
}
// redisConfig holds the settings for the Redis-backed cache.
// Only enabled is read by the code visible here; the other fields are
// presumably connection parameters — confirm against the cache setup code.
type redisConfig struct {
// addr is presumably the Redis server address — TODO confirm.
addr string
// pw is presumably the Redis password — TODO confirm.
pw string
// db is presumably the Redis database index — TODO confirm.
db int
// enabled toggles use of the cache layer in getUser/getReceipt/getImage.
enabled bool
}
// mount builds the chi router for the API and returns it as an http.Handler.
//
// Global middleware is registered first; all versioned routes live under /v1:
//   - /v1/user/{userID}/...  per-user resources (session, CSRF and matching-user checks)
//   - /v1/groups, /v1/users  session-protected listings (admin role required for /users)
//   - /v1/auth/...           public login/registration endpoints
func (app *application) mount() http.Handler {
	r := chi.NewRouter()

	// RequestID and RealIP run before Logger so that the access log contains
	// the request ID and the client's real address.
	r.Use(middleware.RequestID)
	r.Use(middleware.RealIP)
	r.Use(middleware.Logger)
	r.Use(middleware.CleanPath)
	r.Use(middleware.Recoverer)
	r.Use(middleware.Throttle(100)) // temporary or removable: caps the whole API at 100 in-flight requests

	// r.Use(cors.Handler(cors.Options{
	// 	AllowedOrigins:   []string{env.GetString("CORS_ALLOWED_ORIGIN", "http://localhost:5174")},
	// 	AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
	// 	AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
	// 	ExposedHeaders:   []string{"Link"},
	// 	AllowCredentials: false,
	// 	MaxAge:           300, // Maximum value not ignored by any of major browsers
	// }))

	// use either the chi built rate limiter or the custom built one
	if app.config.rateLimiter.Enabled {
		// r.Use(app.RateLimiterMiddleware)
		r.Use(httprate.LimitByRealIP(app.config.rateLimiter.RequestsPerTimeFrame, app.config.rateLimiter.TimeFrame))
	}

	// Set a timeout value on the request context (ctx), that will signal
	// through ctx.Done() that the request has timed out and further
	// processing should be stopped.
	r.Use(middleware.Timeout(60 * time.Second))
	r.Use(middleware.Heartbeat("/ping"))

	// v1 of api
	r.Route("/v1", func(r chi.Router) {
		// SEPARATE STUFF FOR THE LOGIN RELATED STUFF. CONSIDER A GROUP
		// FIX API. NEED TO ALSO CONSIDER GROUPS AND STUFF

		// Operations
		// r.Get("/health", app.healthCheckHandler)
		// r.With(app.BasicAuthMiddleware()).Get("/debug/vars", expvar.Handler().ServeHTTP)
		// docsURL := fmt.Sprintf("%s/swagger/doc.json", app.config.addr)
		// r.Get("/swagger/*", httpSwagger.Handler(httpSwagger.URL(docsURL)))

		// Need to sign in as a user. Then, you can see the groups you're in, your role in the groups,
		r.Route("/user", func(r chi.Router) {
			r.Route("/{userID}", func(r chi.Router) {
				r.Use(app.AuthSessionMiddleware, app.CSRFCheckMiddleware, app.CheckUserMatchingMiddleware)
				r.Get("/", app.getUserHandler)

				r.Route("/groups", func(r chi.Router) {
					r.Get("/", app.getUsersGroupsHandler)
					r.Route("/{groupID}", func(r chi.Router) {
						r.Get("/", app.getUsersGroupHandler)
						r.Delete("/", app.removeUserGroupHandler) // maybe this should expect authentication headers to reverify the password when deleting a group you own.
						r.Put("/moderator", app.addGroupModeratorHandler)
						r.Delete("/moderator/{secondaryuserID}", app.removeModeratorPriviligesHandler)
						r.Get("/users", app.getGroupUsersHandler)
						r.Delete("/users/{secondaryuserID}", app.removeUserFromGroupHandler)
						r.Put("/owner", app.setGroupOwnerHandler)
					})
				})

				r.Route("/receipts", func(r chi.Router) {
					r.Get("/", app.getReceiptsHandler)
					r.Route("/{receiptID}", func(r chi.Router) {
						r.Get("/", app.getReceiptHandler)
						r.Delete("/", app.deleteReceiptHandler)
						r.Route("/images", func(r chi.Router) {
							r.Get("/", app.getReceiptImagesHandler)
							r.Put("/", app.addReceiptImageHandler)
							r.Route("/{imageID}", func(r chi.Router) {
								r.Get("/", app.getReceiptImageHandler)
								r.Put("/", app.changeReceiptImageHandler)
								r.Delete("/", app.deleteReceiptImageHandler)
							})
						})
					})
				})
			})
		})

		// Do not call r.Use on this router here: chi panics when middleware is
		// added to a mux after routes have been registered on it. CSRF checks
		// are attached per route group instead, which also keeps the public
		// /auth endpoints free of them.
		r.Group(func(r chi.Router) {
			r.Use(app.AuthSessionMiddleware)
			r.Use(app.CSRFCheckMiddleware)

			r.Route("/groups", func(r chi.Router) {
				r.Get("/", app.getGroupsHandler)
				r.Route("/{groupID}", func(r chi.Router) {
					r.Get("/", app.getGroupHandler)
				})
			})

			r.Route("/users", func(r chi.Router) {
				r.With(app.CheckRoleMiddleware("admin")).Get("/", app.getUsersHandler)
				r.Route("/{userID}", func(r chi.Router) {
					// DELETE must remove the user, not fetch it.
					r.With(app.CheckRoleMiddleware("admin")).Delete("/", app.deleteUserHandler)
				})
			})
		})

		// r.Route("/users", func(r chi.Router) {
		// 	// r.Put("/activate/{token}", app.activateUserHandler)
		// 	// r.Get("/", app.check)
		// 	r.Route("/{userID}", func(r chi.Router) {
		// 		r.Use(app.AuthSessionMiddleware)
		// 		r.Use(app.CSRFCheckMiddleware)
		// 		// r.Use(app.CheckUserMatchingMiddleware)
		// 		r.Get("/", app.getUserHandler)
		// 		r.Route("/receipts", func(r chi.Router) {
		// 			r.With(app.Paginate).Get("/", app.getReceiptsHandler)
		// 			r.Post("/", app.createReceiptHandler)
		// 			r.Route("/{receiptID}", func(r chi.Router) {
		// 				r.Use(app.receiptsContextMiddleware)
		// 				r.Get("/", app.getReceiptHandler)
		// 				r.Patch("/", app.updateReceiptHandler)
		// 				r.Delete("/", app.checkReceiptOwnership("admin", app.deleteReceiptHandler))
		// 				r.Route("/images", func(r chi.Router) {
		// 					r.Post("/", app.addImageHandler)
		// 					r.Delete("/{imageID}", app.deleteImageHandler)
		// 				})
		// 			})
		// 		})
		// 	})
		// })

		// // Admin page routes
		// r.Route("/admin", func(r chi.Router) {
		// 	r.Use(app.AuthSessionMiddleware)
		// 	r.Use(app.CheckRoleMiddleware("admin"))
		// 	r.Route("/users", func(r chi.Router) {
		// 		r.Get("/", app.getUsersHandler)
		// 		r.Delete("/{userID}", app.deleteUserHandler)
		// 	})
		// })

		// Public routes
		r.Route("/auth", func(r chi.Router) {
			r.Post("/login", app.loginHandler)
			r.Post("/newuser", app.registerUserHandler)
			r.Post("/refreshtoken", app.refreshTokenHandler)
			r.Post("/logout", app.logoutHandler)
		})
	})

	return r
}
// run starts the HTTP server on app.config.addr with mux as its handler and
// blocks until the server stops.
//
// On SIGINT/SIGTERM the server is shut down gracefully with a 5 second
// deadline. run returns nil after a clean shutdown, the ListenAndServe error
// if the server failed for any reason other than being closed, or the
// Shutdown error if graceful shutdown failed.
func (app *application) run(mux http.Handler) error {
	srv := &http.Server{
		Addr:         app.config.addr,
		Handler:      mux,
		WriteTimeout: time.Second * 30,
		ReadTimeout:  time.Second * 10,
		IdleTimeout:  time.Minute,
	}

	// Buffered so the signal goroutine can always deliver its result.
	shutdown := make(chan error, 1)

	// done is closed when run returns, so the signal goroutine does not
	// outlive a server that failed to start.
	done := make(chan struct{})
	defer close(done)

	go func() {
		quit := make(chan os.Signal, 1)
		signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
		defer signal.Stop(quit)

		select {
		case s := <-quit:
			app.logger.Info("signal caught", "signal", s.String())
		case <-done:
			return
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		shutdown <- srv.Shutdown(ctx)
	}()

	app.logger.Info("server has started", "addr", app.config.addr) //, "env", app.config.env)

	// ListenAndServe always returns a non-nil error; ErrServerClosed means
	// Shutdown was called and is the only "expected" one.
	err := srv.ListenAndServe()
	if !errors.Is(err, http.ErrServerClosed) {
		return err
	}

	// Wait for the graceful shutdown to finish and surface its result.
	if err := <-shutdown; err != nil {
		return err
	}

	app.logger.Info("server has stopped", "addr", app.config.addr) //, "env", app.config.env)
	return nil
}

91
backend/cmd/api/auth.go Normal file
View File

@ -0,0 +1,91 @@
package main
import (
"encoding/base64"
"fmt"
"net/http"
"strings"
)
// loginHandler authenticates a user from an HTTP Basic Authorization header
// and, on success, creates a session and returns its token in the
// "Session_token" cookie.
//
// Missing, malformed or rejected credentials produce a 401 with a
// WWW-Authenticate challenge. A failure to create the session is a server
// fault and produces a 500.
func (app *application) loginHandler(w http.ResponseWriter, r *http.Request) {
	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		app.unauthorizedBasicErrorResponse(w, r, fmt.Errorf("authorization header is missing"))
		return
	}

	ctx := r.Context()

	// Expect "<scheme> <base64>"; the auth scheme is case-insensitive.
	parts := strings.Split(authHeader, " ")
	if len(parts) != 2 || !strings.EqualFold(parts[0], "Basic") {
		app.unauthorizedBasicErrorResponse(w, r, fmt.Errorf("authorization header is malformed"))
		return
	}

	decoded, err := base64.StdEncoding.DecodeString(parts[1])
	if err != nil {
		app.unauthorizedBasicErrorResponse(w, r, err)
		return
	}

	// Split on the first ':' only, so passwords may contain colons.
	creds := strings.SplitN(string(decoded), ":", 2)
	if len(creds) != 2 {
		app.unauthorizedBasicErrorResponse(w, r, fmt.Errorf("invalid credentials"))
		return
	}
	username, pass := creds[0], creds[1]

	valid, user, err := app.auth.Users.SigninUser(ctx, username, pass)
	if err != nil {
		// Still answer 401, but keep the underlying cause in the log.
		app.unauthorizedBasicErrorResponse(w, r, fmt.Errorf("invalid credentials: %w", err))
		return
	}
	if !valid {
		app.unauthorizedBasicErrorResponse(w, r, fmt.Errorf("invalid credentials"))
		return
	}

	token, err := app.auth.Sessions.AddSession(ctx, user.ID)
	if err != nil {
		// The credentials were fine; failing to persist the session is a
		// server-side problem, not an authentication failure.
		app.internalServerError(w, r, fmt.Errorf("failed to add session: %w", err))
		return
	}

	w.Header().Add("Vary", "Cookie")
	w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`)
	http.SetCookie(w, &http.Cookie{
		Name:     "Session_token",
		Value:    token,
		Path:     "/",
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteStrictMode,
		// TODO: set an expiry
	})

	// also need to set the csrf or cors or refresh stuff here. Probably just the refresh token stuff
	// (actually, with sessions, there will be no refresh token. Just updating the expiry)

	// Warm the cache only when it is enabled. Caching is best-effort: the
	// session already exists, so a cache failure must not fail the login.
	if app.config.redisCfg.enabled {
		if err := app.cacheStorage.Users.Set(ctx, user); err != nil {
			app.logger.Warn("failed to add user to cache", "method", r.Method, "path", r.URL.Path, "error", err.Error())
		}
	}

	// since it could be a different storage system, idk if it makes sense to do any caching or to just deal
	// with the overhead of logging in. It doesn't happen that often anyways so caching shouldn't make much
	// of a difference. Maybe load it into cache after.
	// consider using the cached user stuff instead. that user will have the hashed password which is fine to
	// move around. And then it can just do the compare stuff using whatever auth. Maybe instead, the
	// checkpass can take in a user
}
// logoutHandler is not implemented yet: the body is empty, so nothing is
// written to the response and no session is invalidated.
func (app *application) logoutHandler(w http.ResponseWriter, r *http.Request) {
// TODO: presumably this should end the session and clear/expire the session cookie — confirm.
}
// registerUserHandler is not implemented yet: the body is empty, so nothing
// is written to the response and no user is created.
func (app *application) registerUserHandler(w http.ResponseWriter, r *http.Request) {
}
// refreshTokenHandler is not implemented yet: the body is empty, so nothing
// is written to the response.
func (app *application) refreshTokenHandler(w http.ResponseWriter, r *http.Request) {
}

57
backend/cmd/api/errors.go Normal file
View File

@ -0,0 +1,57 @@
package main
import (
"net/http"
)
// internalServerError records err through app.logger.Error together with the
// request method and path, then replies with a 500 and a generic message so
// that internal details are not sent to the client.
func (app *application) internalServerError(w http.ResponseWriter, r *http.Request, err error) {
	method, path := r.Method, r.URL.Path
	app.logger.Error("internal error", "method", method, "path", path, "error", err.Error())

	const clientMessage = "the server encountered a problem"
	writeJSONError(w, http.StatusInternalServerError, clientMessage)
}
// forbiddenResponse logs the rejected request's method and path through
// app.logger.Warn and replies with a 403 JSON error.
func (app *application) forbiddenResponse(w http.ResponseWriter, r *http.Request) {
	// Key/value arguments must come in pairs; there is no error value to log
	// here, so no "error" key is emitted.
	app.logger.Warn("forbidden", "method", r.Method, "path", r.URL.Path)
	writeJSONError(w, http.StatusForbidden, "forbidden")
}
// badRequestResponse logs err through app.logger.Warn and replies with a 400
// whose JSON error message is the text of err.
func (app *application) badRequestResponse(w http.ResponseWriter, r *http.Request, err error) {
	method, path := r.Method, r.URL.Path
	app.logger.Warn("bad request", "method", method, "path", path, "error", err.Error())

	// The error text is returned to the caller verbatim.
	clientMessage := err.Error()
	writeJSONError(w, http.StatusBadRequest, clientMessage)
}
// conflictResponse logs err through app.logger.Error and replies with a 409
// whose JSON error message is the text of err.
func (app *application) conflictResponse(w http.ResponseWriter, r *http.Request, err error) {
	method, path := r.Method, r.URL.Path
	app.logger.Error("conflict response", "method", method, "path", path, "error", err.Error())

	// The error text is returned to the caller verbatim.
	clientMessage := err.Error()
	writeJSONError(w, http.StatusConflict, clientMessage)
}
// notFoundResponse logs err through app.logger.Warn and replies with a 404
// carrying the fixed message "not found".
func (app *application) notFoundResponse(w http.ResponseWriter, r *http.Request, err error) {
	method, path := r.Method, r.URL.Path
	app.logger.Warn("not found error", "method", method, "path", path, "error", err.Error())

	const clientMessage = "not found"
	writeJSONError(w, http.StatusNotFound, clientMessage)
}
// unauthorizedErrorResponse logs err through app.logger.Warn and replies with
// a 401 carrying the fixed message "unauthorized".
func (app *application) unauthorizedErrorResponse(w http.ResponseWriter, r *http.Request, err error) {
	method, path := r.Method, r.URL.Path
	app.logger.Warn("unauthorized error", "method", method, "path", path, "error", err.Error())

	const clientMessage = "unauthorized"
	writeJSONError(w, http.StatusUnauthorized, clientMessage)
}
// unauthorizedBasicErrorResponse logs err through app.logger.Warn, sets a
// Basic WWW-Authenticate challenge header and replies with a 401 carrying
// the fixed message "unauthorized".
func (app *application) unauthorizedBasicErrorResponse(w http.ResponseWriter, r *http.Request, err error) {
	method, path := r.Method, r.URL.Path
	app.logger.Warn("unauthorized basic error", "method", method, "path", path, "error", err.Error())

	// The challenge header has to be set before the status line is written.
	const challenge = `Basic realm="restricted", charset="UTF-8"`
	w.Header().Set("WWW-Authenticate", challenge)

	writeJSONError(w, http.StatusUnauthorized, "unauthorized")
}
// rateLimitExceededResponse logs the throttled request through
// app.logger.Warn, sets the Retry-After header to retryAfter and replies
// with a 429 whose message repeats the retry hint.
func (app *application) rateLimitExceededResponse(w http.ResponseWriter, r *http.Request, retryAfter string) {
	method, path := r.Method, r.URL.Path
	app.logger.Warn("rate limit exceeded", "method", method, "path", path)

	w.Header().Set("Retry-After", retryAfter)

	clientMessage := "rate limit exceeded, retry after: " + retryAfter
	writeJSONError(w, http.StatusTooManyRequests, clientMessage)
}

View File

@ -0,0 +1,8 @@
package main
import "net/http"
// healthCheckHandler answers with the plain body "ok". No status is set
// explicitly, so the default status of the first Write applies.
func (app *application) healthCheckHandler(w http.ResponseWriter, r *http.Request) {
	body := []byte("ok")
	// A failed write only means the client went away; nothing to recover.
	_, _ = w.Write(body)
}

44
backend/cmd/api/images.go Normal file
View File

@ -0,0 +1,44 @@
package main
import (
"context"
"net/http"
"git.ewellenr.ca/receipt_indexer/backend/internal/storage"
)
// getImage returns the image with the given ID.
//
// With the cache disabled the image is read straight from the store.
// Otherwise the cache is consulted first; on a miss (nil result) the image is
// loaded from the store and written to the cache before being returned. Any
// error from the cache or the store is returned as-is.
func (app *application) getImage(ctx context.Context, imageID int64) (*storage.Image, error) {
	if !app.config.redisCfg.enabled {
		return app.store.Images.GetByID(ctx, imageID)
	}

	cached, err := app.cacheStorage.ReceiptImage.Get(ctx, imageID)
	if err != nil {
		return nil, err
	}
	if cached != nil {
		return cached, nil
	}

	// Cache miss: fall back to the store and populate the cache.
	fetched, err := app.store.Images.GetByID(ctx, imageID)
	if err != nil {
		return nil, err
	}
	if err := app.cacheStorage.ReceiptImage.Set(ctx, fetched); err != nil {
		return nil, err
	}
	return fetched, nil
}
// addImageHandler is not implemented yet; the body contains only design notes.
func (app *application) addImageHandler(w http.ResponseWriter, r *http.Request) {
// Plan: create a new image and attach it to the receipt as one transaction,
// implemented as a single function.
// It should start the database transaction, attempt the upload, and then abort or commit the transaction depending on the result of the upload. Uploading directly to S3/MinIO would be nice, but it does not provide the transactionality that is required.
}
// deleteImageHandler is not implemented yet; the body contains only design notes.
func (app *application) deleteImageHandler(w http.ResponseWriter, r *http.Request) {
// Plan: delete the image and remove it from the cache.
// This also needs to be transactional: first the database transaction, then attempt the delete from object storage, and then abort or commit the transaction depending on the result of that delete.
}

20
backend/cmd/api/json.go Normal file
View File

@ -0,0 +1,20 @@
package main
import (
"encoding/json"
"net/http"
)
// writeJSON sends data as a JSON response with the given HTTP status.
//
// The Content-Type header and status line are written before encoding, so an
// encoding error is returned to the caller after the status has already been
// sent.
func writeJSON(w http.ResponseWriter, status int, data any) error {
	headers := w.Header()
	headers.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	encoder := json.NewEncoder(w)
	if err := encoder.Encode(data); err != nil {
		return err
	}
	return nil
}
// writeJSONError sends message to the client as {"error": message} with the
// given HTTP status, returning any error reported by writeJSON.
func writeJSONError(w http.ResponseWriter, status int, message string) error {
	type envelope struct {
		Error string `json:"error"`
	}

	payload := &envelope{Error: message}
	return writeJSON(w, status, payload)
}

21
backend/cmd/api/main.go Normal file
View File

@ -0,0 +1,21 @@
package main
import (
"log"
)
// main builds the application with its (currently hard-coded) configuration,
// mounts the routes and runs the HTTP server until it stops.
func main() {
	// set up the application config here
	cfg := config{
		addr: ":8080",
	}

	// NOTE(review): only config is populated; logger, auth, store, cache and
	// rate limiter are left at their zero values — confirm they are wired up
	// before the server is actually run.
	app := application{
		config: cfg,
	}

	mux := app.mount()

	// Only a non-nil error is fatal. Passing run's result straight to
	// log.Fatal would print "<nil>" and exit with status 1 even after a
	// clean shutdown.
	if err := app.run(mux); err != nil {
		log.Fatal(err)
	}
}

View File

@ -0,0 +1,197 @@
package main
import (
"context"
"fmt"
"net"
"net/http"
"strconv"
auth_storage "git.ewellenr.ca/receipt_indexer/backend/internal/storage/auth"
"github.com/go-chi/chi/v5"
)
// sessionKey is the private type of the context key for the session token.
// A dedicated type keeps the key from colliding with context values stored
// by other packages, which a plain string key cannot guarantee.
type sessionKey string

// sessionCtx is the context key under which AuthSessionMiddleware stores the
// session token.
const sessionCtx sessionKey = "session_id"
// AuthSessionMiddleware authenticates the request from its session cookie.
//
// It reads the "Session_token" cookie, validates the token through the
// session store, loads the owning user and stores both the user (userCtx)
// and the token (sessionCtx) in the request context before calling next.
// Any failure along the way results in a 401 and next is not called.
func (app *application) AuthSessionMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Must match the cookie name issued by loginHandler.
		cookie, err := r.Cookie("Session_token")
		if err != nil {
			app.unauthorizedErrorResponse(w, r, fmt.Errorf("Session cookie is missing"))
			return
		}

		token := cookie.Value
		if token == "" {
			app.unauthorizedErrorResponse(w, r, fmt.Errorf("Empty session token"))
			return
		}

		ctx := r.Context()

		valid, userID, err := app.auth.Sessions.CheckSession(ctx, token) // should have a different function for this
		if err != nil {
			// Do not let a lookup failure pass as a valid session; keep the
			// cause in the log.
			app.unauthorizedErrorResponse(w, r, fmt.Errorf("Invalid session token: %w", err))
			return
		}
		if !valid {
			app.unauthorizedErrorResponse(w, r, fmt.Errorf("Invalid session token"))
			return
		}

		// userID comes from the session check.
		user, err := app.getUser(ctx, userID)
		if err != nil {
			app.unauthorizedErrorResponse(w, r, err)
			return
		}
		if user == nil {
			// Downstream middleware dereferences the user, so never store nil.
			app.unauthorizedErrorResponse(w, r, fmt.Errorf("no user found for session"))
			return
		}

		ctx = context.WithValue(ctx, userCtx, user)
		ctx = context.WithValue(ctx, sessionCtx, token)

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// CSRFCheckMiddleware is the hook for CSRF validation.
//
// TODO: no CSRF validation is implemented yet. Until it is, the request is
// passed through unchanged; an empty handler would swallow every request on
// the routes that use this middleware and answer with an empty 200.
func (app *application) CSRFCheckMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(w, r)
	})
}
// Paginate is the hook for pagination parameters.
//
// TODO: add the page and size to the request context. Until that exists the
// request is passed through unchanged so that the wrapped handler still runs.
func (app *application) Paginate(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(w, r)
	})
}
// Also include the stuff like getting the receipt and stuff
// checkRole wraps next so that it only runs when the authenticated user's
// role level is at least that of the role named roleName.
//
// Responses: 401 when no user is present in the request context, 500 when
// the role lookup fails, 403 when the user's level is too low.
func (app *application) checkRole(next http.Handler, roleName string) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		user := getUserFromContext(r)
		if user == nil {
			// getUserFromContext yields nil when no user was stored in the
			// context; refuse instead of dereferencing it.
			app.unauthorizedErrorResponse(w, r, fmt.Errorf("no authenticated user in request context"))
			return
		}

		allowed, err := app.checkRolePrecedence(r.Context(), user, roleName)
		if err != nil {
			app.internalServerError(w, r, err)
			return
		}
		if !allowed {
			app.forbiddenResponse(w, r)
			return
		}

		next.ServeHTTP(w, r)
	})
}
// CheckRoleMiddleware is a middleware factory: it returns a middleware that
// applies checkRole with the given roleName to whatever handler it wraps.
func (app *application) CheckRoleMiddleware(roleName string) func(next http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		return app.checkRole(next, roleName)
	}
	return wrap
}
// CheckUserMatchingMiddleware only lets a request through when the {userID}
// URL parameter equals the ID of the authenticated user.
//
// Responses: 400 when the URL parameter is not an integer, 401 when no user
// is present in the request context, 403 when the IDs differ.
func (app *application) CheckUserMatchingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		urlUserID, err := strconv.ParseInt(chi.URLParam(r, "userID"), 10, 64)
		if err != nil {
			app.badRequestResponse(w, r, fmt.Errorf("Invalid url user ID - Not an integer"))
			return
		}

		user := getUserFromContext(r)
		if user == nil {
			// getUserFromContext yields nil when no user was stored in the
			// context; refuse instead of dereferencing it.
			app.unauthorizedErrorResponse(w, r, fmt.Errorf("no authenticated user in request context"))
			return
		}

		if urlUserID != user.ID {
			app.forbiddenResponse(w, r)
			return
		}

		next.ServeHTTP(w, r)
	})
}
// RateLimiterMiddleware applies the application's own rate limiter, keyed by
// the host part of r.RemoteAddr, when rate limiting is enabled in the config.
//
// A RemoteAddr that cannot be split into host and port yields a 500; a
// request rejected by the limiter yields a 429 with the limiter's retry-after
// value. With rate limiting disabled the request is passed straight through.
func (app *application) RateLimiterMiddleware(next http.Handler) http.Handler {
	limited := func(w http.ResponseWriter, r *http.Request) {
		if !app.config.rateLimiter.Enabled {
			next.ServeHTTP(w, r)
			return
		}

		host, _, splitErr := net.SplitHostPort(r.RemoteAddr)
		if splitErr != nil {
			app.internalServerError(w, r, splitErr)
			return
		}

		allowed, retryAfter := app.rateLimiter.Allow(host)
		if !allowed {
			app.rateLimitExceededResponse(w, r, retryAfter.String())
			return
		}

		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(limited)
}
// receiptsContextMiddleware loads the receipt named by the {receiptID} URL
// parameter, verifies that the authenticated user owns it and stores it in
// the request context under receiptCtx before calling next.
//
// Responses: 400 when the URL parameter is not an integer, 500 when loading
// the receipt fails, 404 when no receipt is returned, 401 when no user is
// present in the request context, 403 when the user is not the owner.
func (app *application) receiptsContextMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		receiptID, err := strconv.ParseInt(chi.URLParam(r, "receiptID"), 10, 64)
		if err != nil {
			app.badRequestResponse(w, r, fmt.Errorf("Invalid url receipt ID - Not an integer"))
			return
		}

		ctx := r.Context()

		receipt, err := app.getReceipt(ctx, receiptID)
		if err != nil {
			app.internalServerError(w, r, err)
			return
		}
		if receipt == nil {
			// Guard against a lookup that reports "nothing found" as nil, nil.
			app.notFoundResponse(w, r, fmt.Errorf("receipt %d not found", receiptID))
			return
		}

		user := getUserFromContext(r)
		if user == nil {
			// getUserFromContext yields nil when no user was stored in the
			// context; refuse instead of dereferencing it.
			app.unauthorizedErrorResponse(w, r, fmt.Errorf("no authenticated user in request context"))
			return
		}

		if receipt.OwnerID != user.ID {
			app.forbiddenResponse(w, r)
			return
		}

		ctx = context.WithValue(ctx, receiptCtx, receipt)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// checkRolePrecedence reports whether user's role level is greater than or
// equal to the level of the role named roleName. The role is looked up
// through the auth storage; a lookup error is returned with false.
func (app *application) checkRolePrecedence(ctx context.Context, user *auth_storage.User, roleName string) (bool, error) {
	required, err := app.auth.Roles.GetByName(ctx, roleName)
	if err != nil {
		return false, err
	}

	meetsLevel := user.Role.Level >= required.Level
	return meetsLevel, nil
}
// checkReceiptOwnership wraps next so that it only runs when the
// authenticated user owns the receipt stored in the request context, or
// otherwise holds at least the role named requiredRole.
//
// Responses: 401 when no user is present in the request context, 500 when no
// receipt is present (the receipt middleware was not applied) or the role
// lookup fails, 403 when the user is neither owner nor sufficiently
// privileged.
func (app *application) checkReceiptOwnership(requiredRole string, next http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		user := getUserFromContext(r)
		if user == nil {
			app.unauthorizedErrorResponse(w, r, fmt.Errorf("no authenticated user in request context"))
			return
		}

		receipt := getReceiptFromContext(r)
		if receipt == nil {
			app.internalServerError(w, r, fmt.Errorf("no receipt in request context"))
			return
		}

		// Owners are always allowed.
		if receipt.OwnerID == user.ID {
			next.ServeHTTP(w, r)
			return
		}

		// Non-owners need the required role (or a higher one).
		allowed, err := app.checkRolePrecedence(r.Context(), user, requiredRole)
		if err != nil {
			app.internalServerError(w, r, err)
			return
		}
		if !allowed {
			app.forbiddenResponse(w, r)
			return
		}

		next.ServeHTTP(w, r)
	})
}

View File

@ -0,0 +1,70 @@
package main
import (
"context"
"net/http"
"git.ewellenr.ca/receipt_indexer/backend/internal/storage"
)
// receiptKey is the private type of the context key for a loaded receipt.
type receiptKey string
// receiptCtx is the context key under which receiptsContextMiddleware stores
// the receipt and from which getReceiptFromContext reads it.
const receiptCtx receiptKey = "receipt"
// getReceiptsHandler is not implemented yet; the body contains only notes.
func (app *application) getReceiptsHandler(w http.ResponseWriter, r *http.Request) {
// TODO: read the page and size from the request context,
// falling back to default values when they are absent.
}
// getReceiptHandler is not implemented yet: the body is empty, so nothing is
// written to the response.
func (app *application) getReceiptHandler(w http.ResponseWriter, r *http.Request) {
}
// createReceiptHandler is not implemented yet; the body contains only a note.
func (app *application) createReceiptHandler(w http.ResponseWriter, r *http.Request) {
// TODO: handle receipt creation here.
}
// updateReceiptHandler is not implemented yet; the body contains only notes.
func (app *application) updateReceiptHandler(w http.ResponseWriter, r *http.Request) {
// TODO: handle receipt updates here.
// Open question: which receipt fields may actually be updated through the API still needs to be decided.
}
// deleteReceiptHandler is not implemented yet; the body contains only design notes.
func (app *application) deleteReceiptHandler(w http.ResponseWriter, r *http.Request) {
// Plan: delete the receipt.
// This should be as simple as getting the receipt and then calling the delete function.
// The receipt should also be removed from the cache when the cache is enabled.
// The delete function should behave like a transaction covering the receipt and its images in the database.
// It should then try to delete all the images in the S3/MinIO bucket, and abort or commit the transaction depending on the result.
}
// getReceipt returns the receipt with the given ID.
//
// With the cache disabled the receipt is read straight from the store.
// Otherwise the cache is consulted first; on a miss (nil result) the receipt
// is loaded from the store and written to the cache before being returned.
// Any error from the cache or the store is returned as-is.
func (app *application) getReceipt(ctx context.Context, receiptID int64) (*storage.Receipt, error) {
	if !app.config.redisCfg.enabled {
		return app.store.Receipts.GetByID(ctx, receiptID)
	}

	cached, err := app.cacheStorage.Receipts.Get(ctx, receiptID)
	if err != nil {
		return nil, err
	}
	if cached != nil {
		return cached, nil
	}

	// Cache miss: fall back to the store and populate the cache.
	fetched, err := app.store.Receipts.GetByID(ctx, receiptID)
	if err != nil {
		return nil, err
	}
	if err := app.cacheStorage.Receipts.Set(ctx, fetched); err != nil {
		return nil, err
	}
	return fetched, nil
}
// getReceiptFromContext returns the receipt stored in the request context
// under receiptCtx, or nil when none is stored or the stored value has a
// different type.
func getReceiptFromContext(r *http.Request) *storage.Receipt {
	stored := r.Context().Value(receiptCtx)
	if receipt, ok := stored.(*storage.Receipt); ok {
		return receipt
	}
	return nil
}

50
backend/cmd/api/users.go Normal file
View File

@ -0,0 +1,50 @@
package main
import (
"context"
"net/http"
auth_storage "git.ewellenr.ca/receipt_indexer/backend/internal/storage/auth"
)
// userKey is the private type of the context key for the authenticated user.
type userKey string
// userCtx is the context key under which AuthSessionMiddleware stores the
// user and from which getUserFromContext reads it.
const userCtx userKey = "user"
// getUserHandler is not implemented yet: the body is empty, so nothing is
// written to the response.
func (app *application) getUserHandler(w http.ResponseWriter, r *http.Request) {
}
// getUsersHandler is not implemented yet: the body is empty, so nothing is
// written to the response.
func (app *application) getUsersHandler(w http.ResponseWriter, r *http.Request) {
}
// deleteUserHandler is not implemented yet: the body is empty, so nothing is
// written to the response and no user is deleted.
func (app *application) deleteUserHandler(w http.ResponseWriter, r *http.Request) {
}
// getUser returns the user with the given ID.
//
// With the cache disabled the user is read straight from the auth storage.
// Otherwise the cache is consulted first; on a miss (nil result) the user is
// loaded from the auth storage and written to the cache before being
// returned. Any error from the cache or the storage is returned as-is.
func (app *application) getUser(ctx context.Context, userID int64) (*auth_storage.User, error) {
	if !app.config.redisCfg.enabled {
		return app.auth.Users.GetByID(ctx, userID)
	}

	cached, err := app.cacheStorage.Users.Get(ctx, userID)
	if err != nil {
		return nil, err
	}
	if cached != nil {
		return cached, nil
	}

	// Cache miss: fall back to the auth storage and populate the cache.
	fetched, err := app.auth.Users.GetByID(ctx, userID)
	if err != nil {
		return nil, err
	}
	if err := app.cacheStorage.Users.Set(ctx, fetched); err != nil {
		return nil, err
	}
	return fetched, nil
}
// getUserFromContext returns the user stored in the request context under
// userCtx, or nil when none is stored or the stored value has a different
// type.
func getUserFromContext(r *http.Request) *auth_storage.User {
	stored := r.Context().Value(userCtx)
	if user, ok := stored.(*auth_storage.User); ok {
		return user
	}
	return nil
}

22
backend/go.mod Normal file
View File

@ -0,0 +1,22 @@
module git.ewellenr.ca/receipt_indexer/backend
go 1.24.2
require (
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/go-chi/chi/v5 v5.2.1 // indirect
github.com/go-chi/httprate v0.15.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
github.com/redis/go-redis/v9 v9.7.3 // indirect
github.com/rs/zerolog v1.34.0 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
go.uber.org/multierr v1.10.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/crypto v0.37.0 // indirect
golang.org/x/sys v0.32.0 // indirect
golang.org/x/time v0.11.0 // indirect
)

40
backend/go.sum Normal file
View File

@ -0,0 +1,40 @@
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/go-chi/httprate v0.15.0 h1:j54xcWV9KGmPf/X4H32/aTH+wBlrvxL7P+SdnRqxh5g=
github.com/go-chi/httprate v0.15.0/go.mod h1:rzGHhVrsBn3IMLYDOZQsSU4fJNWcjui4fWKJcCId1R4=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=

View File

@ -0,0 +1,76 @@
// Use DBML to define your database structure
// Docs: https://dbml.dbdiagram.io/docs
Table users {
  id integer [primary key]
  // Column settings such as unique must be written inside brackets.
  username varchar [unique]
  email varchar [unique]
  password varchar
  created_at timestamp
  is_active bool
  role integer
  personalgroup integer
}
Table groups {
id integer [primary key]
name varchar
owner integer
}
Table groupMembership {
membershipid integer [primary key]
groupid integer
userid integer
moderator bool
}
Table roles {
id integer [primary key]
name varchar
description varchar
level integer
}
Table reciepts {
id integer [primary key]
groupid integer
data nvarchar
created_at timestamp
updated_at timestamp
}
// Table imageOwnership {
// ownershipid integer [primary key]
// receiptid integer
// imageid integer
// }
Table images {
id integer [primary key]
receiptid integer
created_at timestamp
path varchar
added bool
}
Ref: "users"."personalgroup" > "groups"."id"
Ref: "groups"."owner" > "users"."id"
Ref: "groups"."id" < "groupMembership"."groupid"
// Many memberships point at one user: users.id is the "one" side.
Ref: "groupMembership"."userid" > "users"."id"
Ref: "roles"."id" < "users"."role"
// Ref: "reciepts"."id" < "imageOwnership"."receiptid"
// Ref: "images"."id" < "imageOwnership"."imageid"
Ref: "reciepts"."id" < "images"."receiptid"
// The reciepts table declares groupid (there is no groupownerid column).
Ref: "groups"."id" < "reciepts"."groupid"

18
backend/internal/env/env.go vendored Normal file
View File

@ -0,0 +1,18 @@
package env
// Environment is the interface for reading configuration values by key.
// Each getter takes a defaultValue; presumably it is returned when the key is
// missing or cannot be converted — TODO confirm against the implementations.
type Environment interface {
// Initialize prepares the environment from path.
// NOTE(review): presumably path names a config/env file — confirm.
Initialize(path string)
// GetString looks up key as a string, with defaultValue as the fallback.
GetString(key string, defaultValue string) string
// GetBool looks up key as a bool, with defaultValue as the fallback.
GetBool(key string, defaultValue bool) bool
// GetInt looks up key as an int, with defaultValue as the fallback.
GetInt(key string, defaultValue int) int
}
// func GetString(key, fallback string) string {
// value, ok := os.LookupEnv(key)
// if !ok {
// return fallback
// }
// return value
// }
// repeat for bool, int, and so on

View File

@ -0,0 +1,7 @@
package env
// EnvironmentVariables is an empty placeholder type with no fields.
// NOTE(review): presumably intended to implement Environment on top of
// process environment variables — no methods are visible here; confirm.
type EnvironmentVariables struct {
// nothing. just empty
}
// also add one for a file

1
backend/internal/env/jsonvars.go vendored Normal file
View File

@ -0,0 +1 @@
package env

1
backend/internal/env/yamlvars.go vendored Normal file
View File

@ -0,0 +1 @@
package env

View File

@ -0,0 +1,125 @@
package lcrypto
import (
"encoding/base64"
"errors"
"fmt"
"strings"
"golang.org/x/crypto/argon2"
)
var (
// ErrInvalidHash is returned when an encoded hash string does not have the
// expected number of "$"-separated parts.
ErrInvalidHash = errors.New("the encoded hash is not in the correct format")
// ErrIncompatibleVersion is returned when the version in an encoded hash
// differs from argon2.Version.
ErrIncompatibleVersion = errors.New("incompatible argon version")
// ErrWrongHashAlgo is returned when the algorithm name in an encoded hash
// is not algorithmName.
ErrWrongHashAlgo = errors.New("wrong hash algorithm")
// algorithmName is the algorithm identifier written into, and required in,
// encoded hash strings.
algorithmName = "argon2id"
)
// Argon2id holds the Argon2id cost parameters used for hashing and recorded
// in encoded hash strings.
type Argon2id struct {
// Memory is passed to argon2.IDKey as its memory parameter and encoded as "m=".
Memory uint32
// Iterations is passed to argon2.IDKey as its time parameter and encoded as "t=".
Iterations uint32
// Parallelism is passed to argon2.IDKey as its threads parameter and encoded as "p=".
Parallelism uint8
}
// getAlgoString renders the parameter prefix of an encoded hash:
// "$<algorithm>$v=<version>$m=<memory>,t=<iterations>,p=<parallelism>".
func (a *Argon2id) getAlgoString() string {
	const layout = "$%s$v=%d$m=%d,t=%d,p=%d"
	return fmt.Sprintf(layout, algorithmName, argon2.Version, a.Memory, a.Iterations, a.Parallelism)
}
// decodeAlgoParams parses the parameter prefix of an encoded hash string of
// the form "$<algorithm>$v=<version>$m=<m>,t=<t>,p=<p>[$<payload>...]".
//
// It returns the decoded parameters and the remaining "$"-separated payload
// fields (leftovers), which the caller is expected to validate.
//
// Errors: ErrInvalidHash when the string has too few fields or does not start
// with "$", ErrWrongHashAlgo when the algorithm name is not algorithmName,
// ErrIncompatibleVersion when the version differs from argon2.Version, or the
// scan error when the version or parameter field cannot be parsed.
func (a *Argon2id) decodeAlgoParams(encodedString string) (algo Argon2id, leftovers []string, err error) {
	// Splitting "$name$v=..$m=..$payload" yields "", name, version, params,
	// payload...; at least the first four must be present before indexing.
	vals := strings.Split(encodedString, "$")
	if len(vals) < 4 || vals[0] != "" {
		return Argon2id{}, nil, ErrInvalidHash
	}

	if vals[1] != algorithmName {
		return Argon2id{}, nil, ErrWrongHashAlgo
	}

	var version int
	if _, err = fmt.Sscanf(vals[2], "v=%d", &version); err != nil {
		return Argon2id{}, nil, err
	}
	if version != argon2.Version {
		return Argon2id{}, nil, ErrIncompatibleVersion
	}

	if _, err = fmt.Sscanf(vals[3], "m=%d,t=%d,p=%d", &algo.Memory, &algo.Iterations, &algo.Parallelism); err != nil {
		return Argon2id{}, nil, err
	}

	return algo, vals[4:], nil
}
// EncodeHashAndSalt renders hash and salt as
// "<params>$<base64 hash>$<base64 salt>", where <params> is the output of
// getAlgoString and both values use unpadded standard base64. The returned
// error is always nil.
func (a *Argon2id) EncodeHashAndSalt(hash []byte, salt []byte) (string, error) {
	enc := base64.RawStdEncoding
	encodedHash, encodedSalt := enc.EncodeToString(hash), enc.EncodeToString(salt)

	// Note the field order: hash first, then salt.
	return fmt.Sprintf("%s$%s$%s", a.getAlgoString(), encodedHash, encodedSalt), nil
}
// EncodeHash returns "<prefix>$<hash>", where prefix comes from getAlgoString
// and the hash is encoded as unpadded standard base64. No salt is included.
// The returned error is always nil.
func (a *Argon2id) EncodeHash(hash []byte) (string, error) {
	encodedHash := base64.RawStdEncoding.EncodeToString(hash)
	return a.getAlgoString() + "$" + encodedHash, nil
}
// DecodeHashAndSalt parses a string produced by EncodeHashAndSalt. It returns
// the decoded parameters, the hash and the salt. ErrInvalidHash is returned
// unless exactly two payload fields follow the parameter prefix; base64
// decoding is strict and its errors are returned unchanged.
func (a *Argon2id) DecodeHashAndSalt(encodedString string) (algo Argon2id, hash []byte, salt []byte, err error) {
	var payload []string
	if algo, payload, err = a.decodeAlgoParams(encodedString); err != nil {
		return algo, nil, nil, err
	}
	if len(payload) != 2 {
		return algo, nil, nil, ErrInvalidHash
	}

	strict := base64.RawStdEncoding.Strict()
	if hash, err = strict.DecodeString(payload[0]); err != nil {
		return algo, nil, nil, err
	}
	if salt, err = strict.DecodeString(payload[1]); err != nil {
		return algo, nil, nil, err
	}
	return algo, hash, salt, nil
}
// DecodeHash parses a string produced by EncodeHash. It returns the decoded
// parameters and the hash. ErrInvalidHash is returned unless exactly one
// payload field follows the parameter prefix; base64 decoding is strict and
// its error is returned unchanged.
func (a *Argon2id) DecodeHash(encodedString string) (algo Argon2id, hash []byte, err error) {
	var payload []string
	if algo, payload, err = a.decodeAlgoParams(encodedString); err != nil {
		return algo, nil, err
	}
	if len(payload) != 1 {
		return algo, nil, ErrInvalidHash
	}

	if hash, err = base64.RawStdEncoding.Strict().DecodeString(payload[0]); err != nil {
		return algo, nil, err
	}
	return algo, hash, nil
}
// HashString derives a keyLen-byte Argon2id key from in and salt using the
// receiver's cost parameters.
//
// It returns an error instead of letting argon2.IDKey panic when Iterations
// or Parallelism is zero; such values can reach this method through
// parameters decoded from a stored hash string. It also rejects a keyLen that
// does not fit the uint32 expected by argon2.IDKey rather than truncating it.
func (a *Argon2id) HashString(in string, salt []byte, keyLen uint) ([]byte, error) {
	if a.Iterations < 1 || a.Parallelism < 1 {
		return nil, errors.New("argon2id: iterations and parallelism must be at least 1")
	}
	if uint64(keyLen) > 1<<32-1 {
		return nil, errors.New("argon2id: key length does not fit in 32 bits")
	}
	return argon2.IDKey([]byte(in), salt, a.Iterations, a.Memory, a.Parallelism, uint32(keyLen)), nil
}

View File

@ -0,0 +1,103 @@
package lcrypto
import (
"crypto/rand"
"crypto/subtle"
)
// HashAlgo is the set of operations Hasher needs from a password-hashing
// algorithm. The decode methods return the concrete Argon2id parameter type.
type HashAlgo interface {
// EncodeHashAndSalt serializes a hash together with its salt.
EncodeHashAndSalt(hash []byte, salt []byte) (string, error)
// EncodeHash serializes a hash without a salt.
EncodeHash(hash []byte) (string, error)
// DecodeHashAndSalt parses a string that carries parameters, hash and salt.
DecodeHashAndSalt(encodedString string) (algo Argon2id, hash []byte, salt []byte, err error)
// DecodeHash parses a string that carries parameters and hash only.
DecodeHash(encodedString string) (algo Argon2id, hash []byte, err error)
// HashString derives a keyLen-byte hash of in using salt.
HashString(in string, salt []byte, keyLen uint) ([]byte, error)
}
// Hasher combines a HashAlgo with the key and salt lengths used when hashing.
type Hasher struct {
// Algo performs the hashing and the encoding/decoding of hash strings.
Algo HashAlgo
// KeyLength is the keyLen passed to every HashString call, including when
// re-hashing during verification.
KeyLength uint
// SaltLength is the number of random bytes generated for a new salt.
SaltLength uint
}
// CheckString reports whether pass hashes to the value stored in encoded.
// The salt and cost parameters are taken from encoded; the key length is
// h.KeyLength. Decoding and hashing errors are returned with a false result.
func (h *Hasher) CheckString(pass string, encoded string) (bool, error) {
	params, want, salt, decodeErr := h.Algo.DecodeHashAndSalt(encoded)
	if decodeErr != nil {
		return false, decodeErr
	}

	got, hashErr := params.HashString(pass, salt, h.KeyLength)
	if hashErr != nil {
		return false, hashErr
	}

	return CompareHash(want, got)
}
// CheckStringWithSalt reports whether in, hashed with the caller-supplied
// salt, matches the hash stored in encoded. The cost parameters come from
// encoded (which carries no salt); the key length is h.KeyLength.
func (h *Hasher) CheckStringWithSalt(in string, salt []byte, encoded string) (bool, error) {
	params, want, decodeErr := h.Algo.DecodeHash(encoded)
	if decodeErr != nil {
		return false, decodeErr
	}

	got, hashErr := params.HashString(in, salt, h.KeyLength)
	if hashErr != nil {
		return false, hashErr
	}

	return CompareHash(want, got)
}
// HashString hashes in with a freshly generated salt of h.SaltLength bytes
// and returns the encoded string containing parameters, hash and salt.
func (h *Hasher) HashString(in string) (string, error) {
	salt, saltErr := GenerateRandomBytes(h.SaltLength)
	if saltErr != nil {
		return "", saltErr
	}

	digest, hashErr := h.Algo.HashString(in, salt, h.KeyLength)
	if hashErr != nil {
		return "", hashErr
	}

	encoded, encodeErr := h.Algo.EncodeHashAndSalt(digest, salt)
	return encoded, encodeErr
}
// HashStringWithSalt hashes in with the caller-supplied salt and returns the
// encoded string containing parameters and hash only; the salt is not stored
// in the result.
func (h *Hasher) HashStringWithSalt(in string, salt []byte) (string, error) {
	digest, hashErr := h.Algo.HashString(in, salt, h.KeyLength)
	if hashErr != nil {
		return "", hashErr
	}

	encoded, encodeErr := h.Algo.EncodeHash(digest)
	return encoded, encodeErr
}
// Hash returns the raw hash bytes of in computed with a freshly generated
// salt. The salt is discarded, so the result cannot be recomputed from in
// later.
func (h *Hasher) Hash(in string) ([]byte, error) {
	salt, saltErr := GenerateRandomBytes(h.SaltLength)
	if saltErr != nil {
		return nil, saltErr
	}

	digest, hashErr := h.Algo.HashString(in, salt, h.KeyLength)
	return digest, hashErr
}
// CompareHash reports whether h1 and h2 hold the same bytes, using
// subtle.ConstantTimeCompare. The returned error is always nil.
func CompareHash(h1 []byte, h2 []byte) (bool, error) {
	same := subtle.ConstantTimeCompare(h1, h2) == 1
	return same, nil
}
// GenerateRandomBytes returns saltLen bytes read from crypto/rand. On a read
// error it returns nil and that error.
func GenerateRandomBytes(saltLen uint) ([]byte, error) {
	buf := make([]byte, saltLen)
	if _, err := rand.Read(buf); err != nil {
		return nil, err
	}
	return buf, nil
}

View File

@ -0,0 +1,9 @@
package logger
// Logger is the levelled, structured logging interface implemented by the
// slog, zap and zerolog adapters in this package. Every method takes a
// message followed by alternating key/value pairs.
type Logger interface {
Debug(msg string, keysAndValues ...interface{})
Info(msg string, keysAndValues ...interface{})
Warn(msg string, keysAndValues ...interface{})
Error(msg string, keysAndValues ...interface{})
// Fatal logs the message; the implementations in this package then
// terminate the process.
Fatal(msg string, keysAndValues ...interface{})
}

View File

@ -0,0 +1 @@
logger taken from https://dwarvesf.hashnode.dev/go-1-21-release-slog-with-benchmarks-zerolog-and-zap#heading-implementing-the-logger-with-zerolog-and-zap

View File

@ -0,0 +1,38 @@
package logger
import (
"log"
"log/slog"
"os"
)
// SlogLogger adapts a log/slog logger to the Logger interface.
type SlogLogger struct {
	log *slog.Logger
}

// NewSlogLogger returns a Logger that writes JSON records to standard output
// using slog's default handler options.
func NewSlogLogger() Logger {
	handler := slog.NewJSONHandler(os.Stdout, nil)
	return &SlogLogger{log: slog.New(handler)}
}

// Debug logs msg and the key/value pairs at debug level.
func (s *SlogLogger) Debug(msg string, keysAndValues ...interface{}) {
	s.log.Debug(msg, keysAndValues...)
}

// Info logs msg and the key/value pairs at info level.
func (s *SlogLogger) Info(msg string, keysAndValues ...interface{}) {
	s.log.Info(msg, keysAndValues...)
}

// Warn logs msg and the key/value pairs at warn level.
func (s *SlogLogger) Warn(msg string, keysAndValues ...interface{}) {
	s.log.Warn(msg, keysAndValues...)
}

// Error logs msg and the key/value pairs at error level.
func (s *SlogLogger) Error(msg string, keysAndValues ...interface{}) {
	s.log.Error(msg, keysAndValues...)
}

// Fatal logs msg and the key/value pairs at error level through slog, then
// hands msg to the standard library's log.Fatal, which prints it again and
// exits the process.
func (s *SlogLogger) Fatal(msg string, keysAndValues ...interface{}) {
	s.log.Error(msg, keysAndValues...)
	log.Fatal(msg)
}

View File

@ -0,0 +1,35 @@
package logger
import (
"go.uber.org/zap"
)
// ZapLogger adapts a zap SugaredLogger to the Logger interface.
type ZapLogger struct {
	log *zap.SugaredLogger
}

// NewZapLogger returns a Logger backed by zap's production configuration.
//
// It panics with a descriptive message if the production logger cannot be
// built. Previously the construction error was discarded and the nil logger
// was dereferenced, which crashed with an unexplained nil-pointer panic.
func NewZapLogger() Logger {
	logger, err := zap.NewProduction()
	if err != nil {
		panic("logger: building zap production logger: " + err.Error())
	}
	return &ZapLogger{log: logger.Sugar()}
}

// Debug logs msg with the key/value pairs at debug level.
func (zl *ZapLogger) Debug(msg string, keysAndValues ...interface{}) {
	zl.log.Debugw(msg, keysAndValues...)
}

// Info logs msg with the key/value pairs at info level.
func (zl *ZapLogger) Info(msg string, keysAndValues ...interface{}) {
	zl.log.Infow(msg, keysAndValues...)
}

// Warn logs msg with the key/value pairs at warn level.
func (zl *ZapLogger) Warn(msg string, keysAndValues ...interface{}) {
	zl.log.Warnw(msg, keysAndValues...)
}

// Error logs msg with the key/value pairs at error level.
func (zl *ZapLogger) Error(msg string, keysAndValues ...interface{}) {
	zl.log.Errorw(msg, keysAndValues...)
}

// Fatal logs msg with the key/value pairs via zap's Fatalw.
func (zl *ZapLogger) Fatal(msg string, keysAndValues ...interface{}) {
	zl.log.Fatalw(msg, keysAndValues...)
}

View File

@ -0,0 +1,36 @@
package logger
import (
"os"
"github.com/rs/zerolog"
)
// ZeroLogger adapts a zerolog.Logger to the Logger interface.
type ZeroLogger struct {
	log zerolog.Logger
}

// NewZeroLogger returns a Logger that writes timestamped zerolog events to
// standard output.
func NewZeroLogger() Logger {
	base := zerolog.New(os.Stdout)
	return &ZeroLogger{log: base.With().Timestamp().Logger()}
}

// Debug emits msg at debug level with the key/value pairs attached as fields.
func (z *ZeroLogger) Debug(msg string, keysAndValues ...interface{}) {
	event := z.log.Debug()
	event.Fields(keysAndValues).Msg(msg)
}

// Info emits msg at info level with the key/value pairs attached as fields.
func (z *ZeroLogger) Info(msg string, keysAndValues ...interface{}) {
	event := z.log.Info()
	event.Fields(keysAndValues).Msg(msg)
}

// Warn emits msg at warn level with the key/value pairs attached as fields.
func (z *ZeroLogger) Warn(msg string, keysAndValues ...interface{}) {
	event := z.log.Warn()
	event.Fields(keysAndValues).Msg(msg)
}

// Error emits msg at error level with the key/value pairs attached as fields.
func (z *ZeroLogger) Error(msg string, keysAndValues ...interface{}) {
	event := z.log.Error()
	event.Fields(keysAndValues).Msg(msg)
}

// Fatal emits msg through zerolog's Fatal event with the key/value pairs
// attached as fields.
func (z *ZeroLogger) Fatal(msg string, keysAndValues ...interface{}) {
	event := z.log.Fatal()
	event.Fields(keysAndValues).Msg(msg)
}

View File

@ -0,0 +1,49 @@
package ratelimiter
// Copied from https://www.youtube.com/watch?v=m5oyY9fgZPs
import (
"sync"
"time"
)
// FixedWindowRateLimiter allows at most limit requests per client key within
// a window that starts at the key's first request.
type FixedWindowRateLimiter struct {
	sync.RWMutex
	clients map[string]int
	limit   int
	window  time.Duration
}

// NewFixedWindowLimiter returns a limiter that admits limit requests per key
// per window.
func NewFixedWindowLimiter(limit int, window time.Duration) *FixedWindowRateLimiter {
	return &FixedWindowRateLimiter{
		clients: make(map[string]int),
		limit:   limit,
		window:  window,
	}
}

// Allow records a request for ip and reports whether it is admitted. When the
// request is rejected the second result is the window length; otherwise it
// is zero.
//
// The lookup, the decision and the increment happen under one write lock.
// Checking under a read lock and incrementing under a separate write lock
// let concurrent callers all observe a count below the limit and all be
// admitted, and let several of them start a reset goroutine for the same key.
func (rl *FixedWindowRateLimiter) Allow(ip string) (bool, time.Duration) {
	rl.Lock()
	defer rl.Unlock()

	count, exists := rl.clients[ip]
	if exists && count >= rl.limit {
		return false, rl.window
	}
	if !exists {
		// First request of a new window: schedule the counter reset.
		go rl.resetCount(ip)
	}
	rl.clients[ip] = count + 1
	return true, 0
}

// resetCount sleeps for one window and then forgets ip, so its next request
// starts a new window.
func (rl *FixedWindowRateLimiter) resetCount(ip string) {
	time.Sleep(rl.window)
	rl.Lock()
	delete(rl.clients, ip)
	rl.Unlock()
}

View File

@ -0,0 +1,13 @@
package ratelimiter
import "time"
// Limiter decides whether a request identified by a client key may proceed.
type Limiter interface {
// Allow reports whether the request is admitted. The limiters in this
// package that implement it return a wait duration with a rejection and
// zero otherwise.
Allow(ip string) (bool, time.Duration)
}
// Config carries rate-limiter settings.
type Config struct {
// RequestsPerTimeFrame is used by TokenBucketRateLimiter as the burst size.
RequestsPerTimeFrame int
// TimeFrame is used by TokenBucketRateLimiter as the interval between
// token refills.
TimeFrame time.Duration
// Enabled is not read by any limiter in this package.
// NOTE(review): presumably consulted by the HTTP middleware — confirm.
Enabled bool
}

View File

@ -0,0 +1,42 @@
package ratelimiter
import (
"sync"
"time"
)
// SlidingWindowRateLimiter allows at most limit requests per client key
// within any trailing window, tracked as per-key request timestamps.
type SlidingWindowRateLimiter struct {
	sync.RWMutex
	clients map[string][]time.Time
	limit   int
	window  time.Duration
}

// NewSlidingWindowLimiter returns a limiter that admits limit requests per
// key per trailing window.
func NewSlidingWindowLimiter(limit int, window time.Duration) *SlidingWindowRateLimiter {
	return &SlidingWindowRateLimiter{
		clients: make(map[string][]time.Time),
		limit:   limit,
		window:  window,
	}
}

// Allow records a request attempt for ip and reports whether it is admitted.
// On rejection the second result is the time until the oldest recorded
// attempt leaves the window; otherwise it is zero. Rejected attempts are
// recorded too, so they count against the limit until they age out.
//
// The write lock is taken because the method mutates the map. The previous
// RLock paired with a deferred Unlock is a fatal runtime error ("Unlock of
// unlocked RWMutex") and did not protect the map writes.
func (rl *SlidingWindowRateLimiter) Allow(ip string) (bool, time.Duration) {
	rl.Lock()
	defer rl.Unlock()

	now := time.Now()

	// Add the new attempt, then drop the ones that fell outside the window.
	hits := append(rl.clients[ip], now)
	for len(hits) > 0 && now.Sub(hits[0]) > rl.window {
		hits = hits[1:]
	}
	rl.clients[ip] = hits

	if len(hits) > rl.limit {
		return false, rl.window - now.Sub(hits[0])
	}
	return true, 0
}

View File

@ -0,0 +1,75 @@
package ratelimiter
import (
"sync"
"time"
"golang.org/x/time/rate"
)
// TokenBucketRateLimiter keeps one golang.org/x/time/rate limiter per client
// key and evicts clients that have been idle for a while.
//
// NOTE(review): Allow returns (bool, float64), so this type does not satisfy
// the Limiter interface declared in this package — confirm that is intended.
type TokenBucketRateLimiter struct {
	sync.RWMutex
	clients map[string]*Client
	config  Config
}

// NewTokenBucketRateLimiter returns a limiter configured by config and starts
// the background goroutine that evicts idle clients. That goroutine runs for
// the lifetime of the process; there is no way to stop it.
func NewTokenBucketRateLimiter(config Config) *TokenBucketRateLimiter {
	rl := &TokenBucketRateLimiter{
		clients: make(map[string]*Client),
		config:  config,
	}

	go rl.cleanupClients()

	return rl
}

// Client is the per-key state: the token bucket and the time of the key's
// most recent request.
type Client struct {
	Limiter  *rate.Limiter
	LastSeen time.Time
}

// getClient returns the Client for ip, creating it on first use, and updates
// its LastSeen time.
//
// Everything happens under the write lock. The earlier version read the map
// and wrote LastSeen without holding the lock, racing with cleanupClients,
// and two concurrent first requests could each install their own limiter.
//
// NOTE(review): rate.Every(TimeFrame) refills one token per TimeFrame with a
// burst of RequestsPerTimeFrame; if RequestsPerTimeFrame requests per
// TimeFrame is the intended sustained rate, the interval should be divided
// by RequestsPerTimeFrame — confirm.
func (rl *TokenBucketRateLimiter) getClient(ip string) *Client {
	rl.Lock()
	defer rl.Unlock()

	client, exists := rl.clients[ip]
	if !exists {
		client = &Client{
			Limiter: rate.NewLimiter(rate.Every(rl.config.TimeFrame), rl.config.RequestsPerTimeFrame),
		}
		rl.clients[ip] = client
	}
	client.LastSeen = time.Now()
	return client
}

// Allow consumes a token for ip if one is available. It returns whether the
// request is admitted and the number of tokens reported by the limiter after
// the attempt.
func (rl *TokenBucketRateLimiter) Allow(ip string) (bool, float64) {
	client := rl.getClient(ip)
	allowed := client.Limiter.Allow()
	tokens := client.Limiter.Tokens()
	return allowed, tokens
}

// cleanupClients wakes up once a minute and removes every client that has
// not been seen for more than three minutes. It never returns.
func (rl *TokenBucketRateLimiter) cleanupClients() {
	for {
		time.Sleep(time.Minute)
		rl.Lock()
		for ip, client := range rl.clients {
			if time.Since(client.LastSeen) > 3*time.Minute { // idle timeout
				delete(rl.clients, ip)
			}
		}
		rl.Unlock()
	}
}

26
backend/internal/storage/cache/cache.go vendored Normal file
View File

@ -0,0 +1,26 @@
package cache
import (
"context"
"git.ewellenr.ca/receipt_indexer/backend/internal/storage"
// auth_storage "git.ewellenr.ca/receipt_indexer/backend/internal/storage/auth"
)
// Storage groups the cache-backed stores, one per cached entity. Each store
// offers Get, Set and Delete keyed by the entity's int64 ID; Delete returns
// no error.
type Storage struct {
// Users caches storage.User values.
Users interface {
Get(ctx context.Context, id int64) (*storage.User, error)
Set(ctx context.Context, user *storage.User) error
Delete(ctx context.Context, userID int64)
}
// Receipts caches storage.Receipt values.
Receipts interface {
Get(ctx context.Context, id int64) (*storage.Receipt, error)
Set(ctx context.Context, receipt *storage.Receipt) error
Delete(ctx context.Context, id int64)
}
// ReceiptImage caches storage.Image values.
ReceiptImage interface {
Get(ctx context.Context, id int64) (*storage.Image, error)
Set(ctx context.Context, image *storage.Image) error
Delete(ctx context.Context, id int64)
}
}

View File

@ -0,0 +1,55 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"time"
// auth_storage "git.ewellenr.ca/receipt_indexer/backend/internal/storage/auth"
"git.ewellenr.ca/receipt_indexer/backend/internal/storage"
"github.com/redis/go-redis/v9"
)
// UserStore caches users in Redis as JSON under the key "user-<id>".
type UserStore struct {
	rdb *redis.Client
}

// UserExpTime is the time-to-live applied to every cached user.
const UserExpTime = time.Minute

// Get returns the cached user with the given ID. A cache miss yields
// (nil, nil). An entry whose stored value is the empty string yields a
// zero-valued user.
func (s *UserStore) Get(ctx context.Context, userID int64) (*storage.User, error) {
	raw, err := s.rdb.Get(ctx, fmt.Sprintf("user-%d", userID)).Result()
	switch {
	case err == redis.Nil:
		return nil, nil
	case err != nil:
		return nil, err
	}

	cached := new(storage.User)
	if raw == "" {
		return cached, nil
	}
	if decodeErr := json.Unmarshal([]byte(raw), cached); decodeErr != nil {
		return nil, decodeErr
	}
	return cached, nil
}

// Set stores user as JSON with a time-to-live of UserExpTime.
func (s *UserStore) Set(ctx context.Context, user *storage.User) error {
	payload, err := json.Marshal(user)
	if err != nil {
		return err
	}

	key := fmt.Sprintf("user-%d", user.ID)
	return s.rdb.Set(ctx, key, payload, UserExpTime).Err()
}

// Delete removes the cached user with the given ID. The result of the Redis
// command is not inspected.
func (s *UserStore) Delete(ctx context.Context, userID int64) {
	s.rdb.Del(ctx, fmt.Sprintf("user-%d", userID))
}

View File

@ -0,0 +1,9 @@
package storage
// Group is a named collection of users with one owner and a set of
// moderators. Users, Owner and Moderators hold user IDs.
type Group struct {
ID int64 `json:"id"`
Name string `json:"name"`
Users []int64 `json:"users"`
Owner int64 `json:"owner"`
Moderators []int64 `json:"moderators"`
}

View File

@ -0,0 +1,8 @@
package storage
// Image is the stored record for a receipt image.
type Image struct {
ID int64 `json:"id"`
// ReceiptID is a string here although Receipt.ID is an int64.
// NOTE(review): confirm which type the images.receiptid column expects.
ReceiptID string `json:"receipt_id"`
CreatedAt string `json:"created_at"`
// Path is the stored location of the image.
// NOTE(review): presumably an object-storage key — confirm.
Path string `json:"path"`
}

View File

@ -0,0 +1,22 @@
package storage
import "time"
// Receipt is a stored receipt with its parsed contents and the IDs of its
// images.
//
// The created_at and updated_at tags were missing their closing quote.
// encoding/json ignores a malformed tag, so both fields were serialized as
// "CreatedAt" and "UpdatedAt".
type Receipt struct {
	ID int64 `json:"id"`
	// Owner string `json:"username"`
	OwnerID   int64       `json:"user_id"`
	ImageIDs  []int64     `json:"image_ids"`
	Data      ReceiptData `json:"receipt_data"`
	CreatedAt time.Time   `json:"created_at"`
	UpdatedAt time.Time   `json:"updated_at"`
}

// ReceiptData is the parsed content of a receipt. Items maps an item name to
// its amount.
type ReceiptData struct {
	Date     time.Time          `json:"date"`
	Subtotal float64            `json:"subtotal"`
	Total    float64            `json:"total"`
	Tax      float64            `json:"tax"`
	Items    map[string]float64 `json:"items"`
	// Currency string `json:"currency"`
}

View File

@ -0,0 +1,57 @@
package storage
import (
"context"
"time"
"git.ewellenr.ca/receipt_indexer/backend/internal/lcrypto"
"github.com/redis/go-redis/v9"
)
// game plan. store the session token/id and a salt. then, hash them to create the csrf token. give this csrf token out to the user. when the user ends a session, it ends the session but also deletes it from the csrf store
// const csrfSaltLength = 32
// const csrfTokenLength = 128
// should be set by the hasher
// RedisCSRFStore keeps one CSRF token per session in Redis, keyed by the
// session token.
//
// NOTE(review): RedisSessionStore also uses the bare session token as its
// key; if both stores share one Redis database, AddCSRF overwrites the
// session entry — confirm the stores use separate databases or add a prefix.
// NOTE(review): the Storage.CSRF interface declares CheckCSRF, while this
// type provides ValidCSRF, so it does not satisfy that interface as written.
type RedisCSRFStore struct {
	rdb    *redis.Client
	hasher lcrypto.Hasher
	// expirationTime is converted with time.Duration(...), i.e. it is
	// interpreted as nanoseconds.
	// NOTE(review): confirm that is the intended unit.
	expirationTime uint
}

// AddCSRF derives a CSRF token by hashing sessionToken with a fresh random
// salt, stores it under sessionToken with the configured expiry and returns
// it. The token is the raw hash bytes converted to a string.
func (r *RedisCSRFStore) AddCSRF(ctx context.Context, sessionToken string) (csrftoken string, err error) {
	csrf, err := r.hasher.Hash(sessionToken)
	if err != nil {
		return "", err
	}
	csrftoken = string(csrf)

	if err = r.rdb.Set(ctx, sessionToken, csrftoken, time.Duration(r.expirationTime)).Err(); err != nil {
		return "", err
	}
	return csrftoken, nil
}

// RemoveCSRF deletes the CSRF token stored for sessionToken and returns the
// error reported by Redis. The error used to be discarded.
func (r *RedisCSRFStore) RemoveCSRF(ctx context.Context, sessionToken string) error {
	return r.rdb.Del(ctx, sessionToken).Err()
}

// ValidCSRF reports whether csrfToken matches the token stored for
// sessionToken, comparing in constant time, and refreshes the entry's expiry.
// A session without a stored token yields (false, nil), matching how
// RedisSessionStore.GetSession treats a missing key; other Redis failures
// are returned as errors.
func (r *RedisCSRFStore) ValidCSRF(ctx context.Context, sessionToken string, csrfToken string) (bool, error) {
	storedcsrfToken, err := r.rdb.Get(ctx, sessionToken).Result()
	if err == redis.Nil {
		return false, nil
	}
	if err != nil {
		return false, err
	}

	if err = r.rdb.ExpireXX(ctx, sessionToken, time.Duration(r.expirationTime)).Err(); err != nil {
		return false, err
	}

	return lcrypto.CompareHash([]byte(csrfToken), []byte(storedcsrfToken))
}

View File

@ -0,0 +1,59 @@
package storage
import (
"context"
"time"
"git.ewellenr.ca/receipt_indexer/backend/internal/lcrypto"
"github.com/redis/go-redis/v9"
)
// game plan. store the session token/id and a salt. then, hash them to create the csrf token. give this csrf token out to the user. when the user ends a session, it ends the session but also deletes it from the csrf store
// const csrfSaltLength = 32
// const csrfTokenLength = 128
// should be set by the hasher
// RedisSessionStore maps random session tokens to user IDs in Redis.
type RedisSessionStore struct {
	rdb *redis.Client
	// sessionTokenLength is the number of random bytes in a token.
	sessionTokenLength uint
	// expirationTime is converted with time.Duration(...), i.e. it is
	// interpreted as nanoseconds.
	// NOTE(review): confirm that is the intended unit.
	expirationTime uint
}

// AddSession creates a session for userid and returns its token. The token
// is sessionTokenLength random bytes converted to a string without encoding.
func (r *RedisSessionStore) AddSession(ctx context.Context, userid int64) (token string, err error) {
	raw, err := lcrypto.GenerateRandomBytes(r.sessionTokenLength)
	if err != nil {
		return "", err
	}
	token = string(raw)

	if err = r.rdb.Set(ctx, token, userid, time.Duration(r.expirationTime)).Err(); err != nil {
		return "", err
	}
	return token, nil
}

// GetSession looks up token. For a known token it refreshes the expiry and
// returns (true, userid, nil). An unknown token yields (false, -1, nil).
//
// The unknown-token case returns before the expiry refresh, so a plain miss
// no longer costs a second round trip and cannot be turned into an error by
// a failure of that refresh.
func (r *RedisSessionStore) GetSession(ctx context.Context, token string) (valid bool, userid int64, err error) {
	userid, err = r.rdb.Get(ctx, token).Int64()
	if err == redis.Nil {
		return false, -1, nil
	}
	if err != nil {
		return false, -1, err
	}

	if err = r.rdb.ExpireXX(ctx, token, time.Duration(r.expirationTime)).Err(); err != nil {
		return false, -1, err
	}
	return true, userid, nil
}

// RemoveSession deletes the session stored under token.
func (r *RedisSessionStore) RemoveSession(ctx context.Context, token string) error {
	return r.rdb.Del(ctx, token).Err()
}

View File

@ -0,0 +1,8 @@
package storage
// Role is a named permission role. Its fields mirror the id, name,
// description and level columns selected by SQLRolesStore.
type Role struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
// NOTE(review): presumably a higher Level means more privileges — confirm.
Level int `json:"level"`
}

View File

@ -0,0 +1,37 @@
package storage
import (
"context"
"database/sql"
)
// SQLGroupsStore reads and writes groups in a SQL database.
type SQLGroupsStore struct {
	db *sql.DB
}

// groupsStoreError is a string-backed error type, so the sentinel below can
// be a constant and needs no extra import.
type groupsStoreError string

// Error implements the error interface.
func (e groupsStoreError) Error() string { return string(e) }

// errGroupsNotImplemented is returned by the operations that have no
// implementation yet. Those methods previously had no return statement and
// did not compile.
const errGroupsNotImplemented groupsStoreError = "groups store: operation not implemented"

// GetByID returns the group with the given ID, populating ID, Name and
// Owner. It returns ErrNotFound when no such group exists and applies
// QueryTimeoutDuration like the other SQL stores in this package.
func (s *SQLGroupsStore) GetByID(ctx context.Context, id int64) (*Group, error) {
	query := `SELECT id, name, owner FROM groups WHERE id = $1`

	ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
	defer cancel()

	group := &Group{}
	err := s.db.QueryRowContext(ctx, query, id).Scan(&group.ID, &group.Name, &group.Owner)
	switch {
	case err == sql.ErrNoRows:
		return nil, ErrNotFound
	case err != nil:
		return nil, err
	}
	return group, nil
}

// GetUserGroups is not implemented yet and always returns an error.
func (s *SQLGroupsStore) GetUserGroups(ctx context.Context, userID int64) ([]*Group, error) {
	// TODO: retrieve the user's groups from the database.
	return nil, errGroupsNotImplemented
}

// GetUsersInGroup is not implemented yet and always returns an error.
func (s *SQLGroupsStore) GetUsersInGroup(ctx context.Context, groupId int64) ([]*User, error) {
	// TODO: retrieve the users in the group from the database.
	return nil, errGroupsNotImplemented
}

// Create is not implemented yet and always returns an error.
func (s *SQLGroupsStore) Create(ctx context.Context, group *Group) error {
	// TODO: create the group in the database.
	return errGroupsNotImplemented
}

// Delete is not implemented yet and always returns an error.
func (s *SQLGroupsStore) Delete(ctx context.Context, id int64) error {
	// TODO: delete the group from the database.
	return errGroupsNotImplemented
}

View File

@ -0,0 +1,79 @@
package storage
import (
"context"
"database/sql"
)
// SQLImagesStore reads and writes rows of the images table. Its methods use
// value receivers.
type SQLImagesStore struct {
db *sql.DB
}
// GetByID loads the image row with the given ID. Scan errors, including
// sql.ErrNoRows for a missing row, are returned unchanged.
func (s SQLImagesStore) GetByID(ctx context.Context, id int64) (*Image, error) {
	const query = `SELECT id, receiptid, created_at, path FROM images WHERE id = $1`

	var img Image
	row := s.db.QueryRowContext(ctx, query, id)
	if err := row.Scan(&img.ID, &img.ReceiptID, &img.CreatedAt, &img.Path); err != nil {
		return nil, err
	}
	return &img, nil
}
// Create inserts img (receipt ID and path) and fills in the generated ID and
// creation time on img. The query is bounded by QueryTimeoutDuration.
func (s SQLImagesStore) Create(ctx context.Context, img *Image) error {
	query := `
INSERT INTO images (receiptid, path)
VALUES ($1, $2) RETURNING id, created_at`

	ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
	defer cancel()

	row := s.db.QueryRowContext(ctx, query, img.ReceiptID, img.Path)
	return row.Scan(&img.ID, &img.CreatedAt)
}
// Delete removes the image row with the given ID. It returns ErrNotFound
// when no row was deleted. The query is bounded by QueryTimeoutDuration.
func (s SQLImagesStore) Delete(ctx context.Context, id int64) error {
	const query = `DELETE FROM images WHERE id = $1`

	ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
	defer cancel()

	result, execErr := s.db.ExecContext(ctx, query, id)
	if execErr != nil {
		return execErr
	}

	affected, countErr := result.RowsAffected()
	switch {
	case countErr != nil:
		return countErr
	case affected == 0:
		return ErrNotFound
	}
	return nil
}
// ActivateImage sets the added flag on the image with the given ID. It
// returns ErrNotFound when no such image exists.
//
// The statement previously had no WHERE clause and never used id, so one
// call flagged every row in the table.
func (s SQLImagesStore) ActivateImage(ctx context.Context, id int64) error {
	query := `UPDATE images SET added = $1 WHERE id = $2`

	ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
	defer cancel()

	res, err := s.db.ExecContext(ctx, query, true, id)
	if err != nil {
		return err
	}

	rows, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if rows == 0 {
		return ErrNotFound
	}
	return nil
}

View File

@ -0,0 +1,66 @@
package storage
import (
"context"
"database/sql"
"encoding/json"
)
// SQLReceiptsStore reads and writes rows of the receipts table. The data
// column holds the receipt's ReceiptData serialized as JSON.
//
// NOTE(review): the groupid column is read into and written from
// Receipt.OwnerID, whose JSON name is "user_id" — confirm the mapping.
type SQLReceiptsStore struct {
	db *sql.DB
}

// GetByID loads the receipt with the given ID, populating ID, OwnerID and
// Data. It returns ErrNotFound when no such receipt exists.
//
// The data column is read as raw bytes and decoded with encoding/json.
// database/sql cannot scan into the ReceiptData struct directly, so the
// former Scan(&receipt.Data) always failed.
func (s *SQLReceiptsStore) GetByID(ctx context.Context, id int64) (*Receipt, error) {
	query := `SELECT id, groupid, data FROM receipts WHERE id = $1`

	ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
	defer cancel()

	var raw []byte
	receipt := &Receipt{}
	err := s.db.QueryRowContext(ctx, query, id).Scan(&receipt.ID, &receipt.OwnerID, &raw)
	switch {
	case err == sql.ErrNoRows:
		return nil, ErrNotFound
	case err != nil:
		return nil, err
	}

	// A NULL column leaves raw empty; keep the zero ReceiptData in that case.
	if len(raw) > 0 {
		if err := json.Unmarshal(raw, &receipt.Data); err != nil {
			return nil, err
		}
	}
	return receipt, nil
}

// Create inserts receipt and fills in the generated ID and timestamps.
//
// Data is marshalled to JSON and sent as text. database/sql rejects a struct
// as a query argument, so passing receipt.Data directly always failed.
func (s *SQLReceiptsStore) Create(ctx context.Context, receipt *Receipt) error {
	query := `
		INSERT INTO receipts (groupid, data)
		VALUES ($1, $2) RETURNING id, created_at, updated_at`

	payload, err := json.Marshal(receipt.Data)
	if err != nil {
		return err
	}

	ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
	defer cancel()

	return s.db.QueryRowContext(
		ctx,
		query,
		receipt.OwnerID,
		string(payload),
	).Scan(
		&receipt.ID,
		&receipt.CreatedAt,
		&receipt.UpdatedAt,
	)
}

// Delete removes the receipt with the given ID. It returns ErrNotFound when
// no row was deleted.
func (s *SQLReceiptsStore) Delete(ctx context.Context, id int64) error {
	query := `DELETE FROM receipts WHERE id = $1`

	ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
	defer cancel()

	res, err := s.db.ExecContext(ctx, query, id)
	if err != nil {
		return err
	}

	rows, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if rows == 0 {
		return ErrNotFound
	}
	return nil
}

View File

@ -0,0 +1,39 @@
package storage
import (
"context"
"database/sql"
)
// SQLRolesStore reads rows of the roles table.
type SQLRolesStore struct {
	db *sql.DB
}

// GetByName loads the role with the given name. Scan errors, including
// sql.ErrNoRows for a missing row, are returned unchanged. The query is
// bounded by QueryTimeoutDuration.
func (s *SQLRolesStore) GetByName(ctx context.Context, name string) (*Role, error) {
	const query = `SELECT id, name, description, level FROM roles WHERE name = $1`

	ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
	defer cancel()

	var found Role
	row := s.db.QueryRowContext(ctx, query, name)
	if err := row.Scan(&found.ID, &found.Name, &found.Description, &found.Level); err != nil {
		return nil, err
	}
	return &found, nil
}

// GetById loads the role with the given ID. Scan errors, including
// sql.ErrNoRows for a missing row, are returned unchanged. The query is
// bounded by QueryTimeoutDuration.
func (s *SQLRolesStore) GetById(ctx context.Context, id int64) (*Role, error) {
	const query = `SELECT id, name, description, level FROM roles WHERE id = $1`

	ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
	defer cancel()

	var found Role
	row := s.db.QueryRowContext(ctx, query, id)
	if err := row.Scan(&found.ID, &found.Name, &found.Description, &found.Level); err != nil {
		return nil, err
	}
	return &found, nil
}

View File

@ -0,0 +1,231 @@
package storage
import (
"context"
"database/sql"
"fmt"
"time"
)
// SQLUsersStore reads and writes users in a SQL database.
type SQLUsersStore struct {
	db *sql.DB
}

// selectActiveUser is the SELECT shared by the GetBy* lookups; callers append
// the final predicate on $1. Role columns are listed explicitly, in the order
// getOne scans them, so the result no longer depends on the column order of
// the roles table as it did with "roles.*".
const selectActiveUser = `SELECT users.id, users.username, users.email, users.password, users.created_at,
	roles.id, roles.name, roles.level, roles.description
FROM users
JOIN roles ON (users.role_id = roles.id)
WHERE is_active = true AND `

// getOne runs query with arg and scans a single user with its role. It
// returns ErrNotFound when no row matches. The query is bounded by
// QueryTimeoutDuration.
func (s *SQLUsersStore) getOne(ctx context.Context, query string, arg interface{}) (*User, error) {
	ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
	defer cancel()

	user := &User{}
	err := s.db.QueryRowContext(ctx, query, arg).Scan(
		&user.ID,
		&user.Username,
		&user.Email,
		&user.Password, // User.Password is a plain string; there is no .hash field
		&user.CreatedAt,
		&user.Role.ID,
		&user.Role.Name,
		&user.Role.Level,
		&user.Role.Description,
	)
	switch {
	case err == sql.ErrNoRows:
		return nil, ErrNotFound
	case err != nil:
		return nil, err
	}
	return user, nil
}

// GetByID returns the active user with the given ID, or ErrNotFound.
func (s *SQLUsersStore) GetByID(ctx context.Context, id int64) (*User, error) {
	return s.getOne(ctx, selectActiveUser+`users.id = $1`, id)
}

// GetByEmail returns the active user with the given email, or ErrNotFound.
func (s *SQLUsersStore) GetByEmail(ctx context.Context, email string) (*User, error) {
	return s.getOne(ctx, selectActiveUser+`users.email = $1`, email)
}

// GetByUsername returns the active user with the given username, or
// ErrNotFound.
func (s *SQLUsersStore) GetByUsername(ctx context.Context, username string) (*User, error) {
	return s.getOne(ctx, selectActiveUser+`users.username = $1`, username)
}

// create inserts the user's personal group, the user row and the group
// membership inside tx, filling in user.PersonalGroup, user.ID and
// user.CreatedAt. It returns ErrDuplicateEmail or ErrDuplicateUsername when
// the corresponding unique constraint is violated.
//
// NOTE(review): the group row is inserted before the user row, so its owner
// and name use user.ID as supplied by the caller (zero for a new user) —
// confirm against the schema how the mutual reference is meant to be set up.
func (s *SQLUsersStore) create(ctx context.Context, user *User, tx *sql.Tx) error {
	ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
	defer cancel()

	query := `INSERT INTO groups (name, owner) VALUES ($1, $2) RETURNING id`
	err := tx.QueryRowContext(ctx, query, fmt.Sprintf("User-%d-Personal-Group", user.ID), user.ID).Scan(&user.PersonalGroup)
	if err != nil {
		return err
	}

	query = `INSERT INTO users (username, password, email, role_id, personalgroup) VALUES
		($1, $2, $3, (SELECT id FROM roles WHERE name = $4), $5)
		RETURNING id, created_at`

	// Users without an explicit role get the default "user" role.
	role := user.Role.Name
	if role == "" {
		role = "user"
	}

	err = tx.QueryRowContext(
		ctx,
		query,
		user.Username,
		user.Password,
		user.Email,
		role,
		user.PersonalGroup,
	).Scan(
		&user.ID,
		&user.CreatedAt,
	)
	if err != nil {
		// The driver reports unique-constraint violations only through the
		// error text, so the messages are matched literally.
		switch {
		case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`:
			return ErrDuplicateEmail
		case err.Error() == `pq: duplicate key value violates unique constraint "users_username_key"`:
			return ErrDuplicateUsername
		default:
			return err
		}
	}

	query = `INSERT INTO groupMembership (groupid, userid, moderator) VALUES ($1, $2, $3)`
	if _, err = tx.ExecContext(ctx, query, user.PersonalGroup, user.ID, true); err != nil {
		return err
	}
	return nil
}

// Create inserts user, its personal group and the membership row in a single
// transaction.
func (s *SQLUsersStore) Create(ctx context.Context, user *User) error {
	return withTx(s.db, ctx, func(tx *sql.Tx) error {
		return s.create(ctx, user, tx)
	})
}

// CreateAndInvite is not implemented yet; it does nothing and returns nil.
func (s *SQLUsersStore) CreateAndInvite(ctx context.Context, user *User, token string, exp time.Duration) error {
	return nil
}

// Activate is not implemented yet; it does nothing and returns nil.
func (s *SQLUsersStore) Activate(context.Context, string) error {
	return nil
}

// delete removes the user with the given ID inside tx. It returns
// ErrNotFound when no row was deleted.
//
// The statement runs on tx. It used to run on s.db, outside the transaction
// that Delete opens for it.
func (s *SQLUsersStore) delete(ctx context.Context, id int64, tx *sql.Tx) error {
	query := `DELETE FROM users WHERE id = $1`

	ctx, cancel := context.WithTimeout(ctx, QueryTimeoutDuration)
	defer cancel()

	res, err := tx.ExecContext(ctx, query, id)
	if err != nil {
		return err
	}

	rows, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if rows == 0 {
		return ErrNotFound
	}
	return nil
}

// Delete removes the user with the given ID in a transaction.
func (s *SQLUsersStore) Delete(ctx context.Context, id int64) error {
	return withTx(s.db, ctx, func(tx *sql.Tx) error {
		return s.delete(ctx, id, tx)
	})
}

// UpdateUserPass is not implemented yet; it does nothing and returns nil.
func (s *SQLUsersStore) UpdateUserPass(ctx context.Context, user string, oldPassword string, newPass string) error {
	return nil
}

// CheckPass is not implemented yet; it always returns (false, nil).
func (s *SQLUsersStore) CheckPass(ctx context.Context, name string, pass string) (bool, error) {
	return false, nil
}

// SigninUser is not implemented yet; it always returns (false, nil, nil).
func (s *SQLUsersStore) SigninUser(ctx context.Context, name string, pass string) (bool, *User, error) {
	return false, nil, nil
}

View File

@ -0,0 +1,25 @@
package storage
import (
"context"
"database/sql"
"time"
)
var (
// QueryTimeoutDuration is the per-query deadline the SQL stores apply via
// context.WithTimeout.
QueryTimeoutDuration = time.Second * 5
)
// withTx begins a transaction on db, runs fn with it and commits when fn
// returns nil. When fn fails the transaction is rolled back and fn's error
// is returned; the rollback's own error is discarded.
func withTx(db *sql.DB, ctx context.Context, fn func(*sql.Tx) error) error {
	tx, beginErr := db.BeginTx(ctx, nil)
	if beginErr != nil {
		return beginErr
	}

	fnErr := fn(tx)
	if fnErr == nil {
		return tx.Commit()
	}

	// Best effort: the caller needs the callback's error, not the rollback's.
	_ = tx.Rollback()
	return fnErr
}

View File

@ -0,0 +1,80 @@
package storage
import (
"context"
"database/sql"
"errors"
"time"
)
var (
// ErrNotFound is returned when a looked-up or deleted row does not exist.
ErrNotFound = errors.New("resource not found")
// ErrConflict signals that a resource already exists.
ErrConflict = errors.New("resource already exists")
// ErrDuplicateEmail is returned when creating a user violates the unique
// constraint on email.
ErrDuplicateEmail = errors.New("a user with that email already exists")
// ErrDuplicateUsername is returned when creating a user violates the unique
// constraint on username.
ErrDuplicateUsername = errors.New("a user with that username already exists")
)
// Storage bundles every persistence interface the application uses. Each
// field is an interface so that the backing implementation (SQL, Redis, ...)
// can be swapped. A field left nil panics on first use.
type Storage struct {
// Users manages user accounts.
Users interface { // store user id, username, password(hashed+salted), role?
GetByID(ctx context.Context, id int64) (*User, error)
GetByEmail(context.Context, string) (*User, error)
GetByUsername(context.Context, string) (*User, error)
Create(context.Context, *User) error // create a non-exported create function which does take in the tx
CreateAndInvite(ctx context.Context, user *User, token string, exp time.Duration) error // figure this out
Activate(context.Context, string) error // what does this do?
Delete(ctx context.Context, id int64) error
UpdateUserPass(ctx context.Context, user string, oldPassword string, newPass string) error
// CheckPass(ctx context.Context, name string, pass string) (bool, error)
// SigninUser(ctx context.Context, name string, pass string) (bool, *User, error)
// ValidCredentials(ctx context.Context, user *User, pass string) (bool, error)
}
// Sessions maps session tokens to user IDs.
Sessions interface { // store just session tokens, and their corresponding user id
AddSession(ctx context.Context, userid int64) (token string, err error)
GetSession(ctx context.Context, token string) (valid bool, userid int64, err error) // extends its expiry
RemoveSession(ctx context.Context, token string) error
// SetLifespan(ctx context.Context, token string, lf time.Time) error
}
// CSRF issues and checks per-session CSRF tokens.
CSRF interface {
AddCSRF(ctx context.Context, sessionToken string) (csrftoken string, err error)
CheckCSRF(ctx context.Context, sessionToken string, csrfToken string) (bool, error)
RemoveCSRF(ctx context.Context, sessionToken string) error
// CleanupCSRF()
}
// Roles looks up roles by name or ID.
Roles interface {
GetByName(context.Context, string) (*Role, error)
GetById(context.Context, int64) (*Role, error)
}
// Receipts manages receipt records.
Receipts interface {
GetByID(ctx context.Context, id int64) (*Receipt, error)
Create(ctx context.Context, receipt *Receipt) error
Delete(ctx context.Context, id int64) error
}
// Images manages receipt image records.
Images interface {
GetByID(context.Context, int64) (*Image, error)
Create(context.Context, *Image) error
Delete(context.Context, int64) error
ActivateImage(context.Context, int64) error // may need to change this. Consider finishing https://www.youtube.com/watch?v=pmEmQcd9_KA
// Can have it so that the client gets a presigned URL.
// For creation, do a delayed check in 10 minutes to clean up the image from AWS or check
}
// Groups manages groups and their membership.
Groups interface {
GetByID(context.Context, int64) (*Group, error)
GetUserGroups(context.Context, int64) ([]*Group, error)
GetUsersInGroup(context.Context, int64) ([]*User, error)
Create(context.Context, *Group) error
Delete(context.Context, int64) error
}
}
// NewSQLRedisMinIOStorage returns a Storage whose SQL-backed stores (Users,
// Roles, Receipts, Images and Groups) all share db.
//
// Only Users used to be wired, which left every other SQL store a nil
// interface that panics on first use. Sessions and CSRF remain nil: their
// implementations need a Redis client that this constructor does not receive.
func NewSQLRedisMinIOStorage(db *sql.DB) Storage {
	return Storage{
		Users:    &SQLUsersStore{db: db},
		Roles:    &SQLRolesStore{db: db},
		Receipts: &SQLReceiptsStore{db: db},
		Images:   SQLImagesStore{db: db},
		Groups:   &SQLGroupsStore{db: db},
	}
}

View File

@ -0,0 +1,29 @@
package storage
import (
"errors"
)
var (
// ErrUserNotFound signals that no matching user exists.
ErrUserNotFound = errors.New("user not found")
// ErrExistingUser signals that the user already exists.
ErrExistingUser = errors.New("user already exists")
// ErrPasswordIncorrect signals that a supplied password did not match.
ErrPasswordIncorrect = errors.New("password incorrect")
)
// User is an application account together with its role and group links.
type User struct {
ID int64 `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
// Password is excluded from JSON output by its "-" tag.
// NOTE(review): presumably holds the encoded password hash — confirm.
Password string `json:"-"`
CreatedAt string `json:"created_at"`
IsActive bool `json:"is_active"`
Role Role `json:"role"`
// PersonalGroup is the ID of the group created together with the user.
PersonalGroup int64 `json:"user_group"`
// Groups holds group IDs.
Groups []int64 `json:"groups"`
}
// type Password struct {
// text *string
// hash []byte
// encoded *string
// }

View File

@ -0,0 +1,17 @@
package imgtransform
import (
"bytes"
"image"
"github.com/nfnt/resize"
)
// BytesToImage decodes data with image.Decode and returns the resulting
// image. The detected format name is discarded. Only formats registered with
// the image package can be decoded.
func BytesToImage(data []byte) (image.Image, error) {
	reader := bytes.NewReader(data)
	decoded, _, err := image.Decode(reader)
	return decoded, err
}
// ResizeImage returns img resized to width x height using
// github.com/nfnt/resize with Lanczos3 interpolation.
func ResizeImage(img image.Image, width uint, height uint) image.Image {
return resize.Resize(width, height, img, resize.Lanczos3)
}

View File

@ -1,17 +0,0 @@
{
"configurations": [
{
"name": "Linux",
"includePath": [
"${workspaceFolder}/**"
],
"defines": [],
"compilerPath": "/usr/bin/gcc",
// "cStandard": "c17",
// "cppStandard": "gnu++17",
"intelliSenseMode": "linux-gcc-x64",
"configurationProvider": "ms-vscode.cmake-tools"
}
],
"version": 4
}

View File

@ -1,65 +0,0 @@
{
"C_Cpp.errorSquiggles": "enabled",
"files.associations": {
"array": "cpp",
"atomic": "cpp",
"bit": "cpp",
"*.tcc": "cpp",
"cctype": "cpp",
"chrono": "cpp",
"clocale": "cpp",
"cmath": "cpp",
"compare": "cpp",
"complex": "cpp",
"concepts": "cpp",
"condition_variable": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdint": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"deque": "cpp",
"list": "cpp",
"map": "cpp",
"set": "cpp",
"string": "cpp",
"unordered_map": "cpp",
"vector": "cpp",
"exception": "cpp",
"algorithm": "cpp",
"functional": "cpp",
"iterator": "cpp",
"memory": "cpp",
"memory_resource": "cpp",
"numeric": "cpp",
"random": "cpp",
"ratio": "cpp",
"string_view": "cpp",
"system_error": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"utility": "cpp",
"fstream": "cpp",
"initializer_list": "cpp",
"iomanip": "cpp",
"iosfwd": "cpp",
"iostream": "cpp",
"istream": "cpp",
"limits": "cpp",
"mutex": "cpp",
"new": "cpp",
"numbers": "cpp",
"ostream": "cpp",
"semaphore": "cpp",
"sstream": "cpp",
"stdexcept": "cpp",
"stop_token": "cpp",
"streambuf": "cpp",
"thread": "cpp",
"typeinfo": "cpp"
},
}

View File

@ -1,24 +0,0 @@
cmake_minimum_required(VERSION 3.22)
project(autocropper
VERSION 0.1
DESCRIPTION "Autocrops Receipt Pictures"
LANGUAGES CXX)
# Globbing: collect every .cpp file under src/
file(GLOB_RECURSE SOURCE_FILES src/*.cpp)
add_executable(CropperEx main.cpp ${SOURCE_FILES})
# add_executable(CropperEx main.cpp
# src/dog.cpp
# src/operations.cpp)
target_compile_features(CropperEx PRIVATE cxx_std_20)
find_package(OpenCV REQUIRED)
target_link_libraries(CropperEx ${OpenCV_LIBS})
target_include_directories(CropperEx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../externallibraries/stbimagehelpers
PRIVATE ${OpenCV_INCLUDE_DIRS})

View File

@ -1,721 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n",
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n"
]
}
],
"source": [
"import cv2\n",
"import myfunctions as mf\n",
"import numpy as np\n",
"import math\n",
"import scipy.stats as st"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pathlib\n",
"import time\n",
"\n",
"def removeextensionandnumeric(filename):\n",
" suffix = pathlib.Path(filename).suffix\n",
" num = filename[:-len(suffix)]\n",
" numint = int(num)\n",
" return numint\n",
" \n",
"\n",
"def testondataset(pathtodataset, function):\n",
" imagefileextensions = [\".jpg\", \".png\"]\n",
" filenames = next(os.walk(pathtodataset), (None, None, []))[2]\n",
" \n",
" filenames.sort(key=removeextensionandnumeric)\n",
" # print(filenames)\n",
" outs = []\n",
" tdiffs = []\n",
" for filename in filenames:\n",
" suffix = pathlib.Path(filename).suffix\n",
" if (suffix not in imagefileextensions):\n",
" print(\"Not a valid image \"+filename)\n",
" continue\n",
" img = cv2.imread(pathtodataset+filename)\n",
" t1 = time.time()\n",
" outs.append(function(img))\n",
" tdiffs.append(time.time() - t1)\n",
" tdiffs = np.array(tdiffs)\n",
" print(\"average time: \" + str(np.mean(tdiffs))+\"(s)\")\n",
" return outs\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def showimgs(imgs):\n",
" if (isinstance(imgs, np.ndarray)):\n",
" if (imgs.shape[0] > imgs.shape[1]):\n",
" cv2.imshow(\"test\", mf.ResizeWithAspectRatio(imgs, height=1350))\n",
" else:\n",
" cv2.imshow(\"test\", mf.ResizeWithAspectRatio(imgs, width=1000))\n",
" else:\n",
" for i, out in enumerate(imgs):\n",
" if (out.shape[0] > out.shape[1]):\n",
" cv2.imshow(\"test\"+str(i), mf.ResizeWithAspectRatio(out, height=1350))\n",
" else:\n",
" cv2.imshow(\"test\"+str(i), mf.ResizeWithAspectRatio(out, width=1000))\n",
" cv2.waitKey(0)\n",
" cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def writeimgs(directorypath, imgs):\n",
" if (isinstance(imgs, np.ndarray)):\n",
" cv2.imwrite(directorypath+\"test.png\", imgs)\n",
" else:\n",
" for i, out in enumerate(imgs):\n",
" cv2.imwrite(directorypath+\"test\"+str(i)+\".png\", out)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"img = cv2.imread('/mnt/dataset/baseimages/12.jpg')\n",
"# img = cv2.imread('/mnt/code/autocropper/test_images/IMG_7605.jpg')\n",
"testall = False"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"## NEED TO FIX THE EARLIER PARTS SO THAT IT DOESN'T HAVE THOSE BLACK SECTIONS AFTER THE ROTATION\n",
"\n",
"\n",
"def whiteoutbackground(image):\n",
" ogshape = image.shape\n",
" shrunkdim=1000\n",
" if (image.shape[1] > image.shape[0]):\n",
" shrunkimg, scaler = mf.ResizeWithAspectRatio(image, width=shrunkdim, retscale=True)\n",
" else:\n",
" shrunkimg, scaler = mf.ResizeWithAspectRatio(image, height=shrunkdim, retscale=True)\n",
" \n",
" mainimage = shrunkimg\n",
" \n",
" sdim = int(min(mainimage.shape[0], mainimage.shape[1])/5)\n",
" srkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (sdim, sdim))\n",
" skernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (sdim, sdim))\n",
" \n",
" \n",
" lab = cv2.cvtColor(mainimage, cv2.COLOR_BGR2LAB)\n",
" \n",
" imglist = []\n",
" # imglist.append(mainimage)\n",
" \n",
" labl = lab[:,:,0]\n",
" # imglist.append(labl)\n",
" # imglist.append(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))\n",
" laba = lab[:,:,1]\n",
" # imglist.append(laba)\n",
" labb = lab[:,:,2]\n",
" # imglist.append(labb)\n",
" \n",
" \n",
" # canny = cv2.Canny(labl, 0, 500)\n",
" threshl = cv2.threshold(labl, 0, 255, cv2.THRESH_OTSU)[1]\n",
" # return threshl\n",
" \n",
" \n",
" dim = int(min(mainimage.shape[0], mainimage.shape[1])/100)\n",
" # dim = 2\n",
" # dim = dotsize\n",
" kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (dim, dim))\n",
" kernelell = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dim, dim))\n",
" \n",
" paddedl = mf.padWithColour(threshl, sdim*2, sdim*2, fill=0)\n",
" # return paddedl\n",
" \n",
" \n",
" # morphedl = 255-cv2.morphologyEx(255-threshl, cv2.MORPH_OPEN, kernel, iterations=3)\n",
" morphedl = paddedl\n",
" # morphedl = cv2.morphologyEx(morphedl, cv2.MORPH_ERODE, kernel, iterations=1)\n",
" morphed1l = cv2.morphologyEx(morphedl, cv2.MORPH_ERODE, kernelell, iterations=1)\n",
"\n",
" # return morphedl\n",
" \n",
" contours, heirarchy = cv2.findContours(morphed1l, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
" biggestcontour = max(contours, key=cv2.contourArea)\n",
" \n",
" \n",
" blank = np.full(labl.shape, 255, dtype=np.uint8)\n",
" mask1 = blank.copy()\n",
" mask1 = mf.padWithColour(mask1, sdim*2, sdim*2, fill=255)\n",
" mask1 = cv2.drawContours(mask1, [biggestcontour], -1, 0, thickness=cv2.FILLED)\n",
" \n",
" \n",
" mask1 = cv2.morphologyEx(mask1, cv2.MORPH_DILATE, kernelell, iterations=2)\n",
" \n",
" \n",
" # mask1 = mask1[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
" # return mask1\n",
" \n",
" # morphed2l = mf.padWithColour(morphedl, sdim*2, sdim*2, fill=255)\n",
" morphed2l = cv2.morphologyEx(morphedl, cv2.MORPH_OPEN, kernel, iterations=1)\n",
" # morphed2l = morphed2l[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
" \n",
" # return morphed2l\n",
" # print(mask1.shape)\n",
" # print(morphed2l.shape)\n",
" morphed2l = cv2.bitwise_or(morphed2l, 255-mask1)\n",
" # return morphed2l\n",
" \n",
" morphed2l = morphed2l[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
" temp_final = cv2.bitwise_or(threshl, 255-morphed2l)\n",
" return temp_final\n",
" \n",
" canny = cv2.Canny(morphed2l, 0, 500)\n",
" # return canny\n",
"\n",
" vminlength = mainimage.shape[0]//10\n",
" vmaxgap = mainimage.shape[0]//50\n",
" vlinesP = cv2.HoughLinesP(canny, 1, np.pi / 180, 10, None, vminlength, vmaxgap)\n",
" \n",
" hminlength = mainimage.shape[1]//15\n",
" hmaxgap = mainimage.shape[1]//40\n",
" hlinesP = cv2.HoughLinesP(canny, 1, np.pi / 180, 10, None, hminlength, hmaxgap)\n",
" # print(linesP)\n",
" \n",
" vmarginlines = mf.WithinXDegrees(vlinesP, 15)\n",
" hmarginlines = mf.WithinXDegrees(hlinesP, 15, baseangle=90)\n",
" \n",
" marginlines = np.append(vmarginlines, hmarginlines, axis=0)\n",
" # marginlines = marginlines.astype(int)\n",
" # # print(marginlines)\n",
" # reshaped = np.reshape(marginlines, (-1,1, 2))\n",
" # # reshaped = cv2.convexHull(reshaped)\n",
" # # print(reshaped)\n",
" \n",
" \n",
" \n",
" colourdst = cv2.cvtColor(morphedl, cv2.COLOR_GRAY2BGR)\n",
" # out = cv2.drawContours(colourdst, [reshaped], -1, (0,255,0), thickness=3)\n",
" # return out\n",
" \n",
" \n",
" #### NEW IDEA: MERGE THE WHITEOUT BACKGROUND AND TEXT CLARIFICATION STEP BECAUSE DOING THE OTSU THRESHOLD SEEMS TO WORK PRETTY WELL AND IF I JUST WHITE OUT THE OUTER AREA (ACTUALLY WHITE)\n",
" # THEN I HAVE JUST THE TEXT\n",
" \n",
"\n",
" if marginlines is not None:\n",
" for l in marginlines:\n",
" cv2.line(colourdst, (int(l[0]), int(l[1])), (int(l[2]), int(l[3])), (0,0,255), 3, cv2.LINE_AA)\n",
" return colourdst\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" ## IDEA:\n",
" # MASK OUT THE WORDS USING OUR MASKS MADE FROM THE STUFF BELOW. THEN WHEN CANNY IS DONE TO IT, IT SHOULDN'T HAVE A WHOLE BUNCH OF SHIT IN THE CENTER. STILL NEED TO FIGURE OUT HOW TO LINK THE HOUGH LINES AROUND THE RECEIPT\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" # morphedl = 255-cv2.morphologyEx(255-threshl, cv2.MORPH_OPEN, kernel, iterations=3)\n",
" morphedl = paddedl\n",
" morphedl = cv2.morphologyEx(morphedl, cv2.MORPH_ERODE, kernel, iterations=1)\n",
" morphedl = cv2.morphologyEx(morphedl, cv2.MORPH_ERODE, kernelell, iterations=1)\n",
"\n",
" # return morphedl\n",
" \n",
" contours, heirarchy = cv2.findContours(morphedl, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
" # print(contours[0].shape)\n",
" print(contours[0])\n",
" biggestcontour = max(contours, key=cv2.contourArea)\n",
" return canny\n",
" \n",
" \n",
" blank = np.full(labl.shape, 255, dtype=np.uint8)\n",
" mask1 = blank.copy()\n",
" mask1 = mf.padWithColour(mask1, sdim*2, sdim*2, fill=255)\n",
" mask1 = cv2.drawContours(mask1, [biggestcontour], -1, 0, thickness=cv2.FILLED)\n",
" \n",
" \n",
" mask1 = mask1[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
" \n",
" \n",
" # resizemask = cv2.resize(mask1, (ogshape[1], ogshape[0]))\n",
" # return resizemask\n",
" maskc = cv2.cvtColor(mask1, cv2.COLOR_GRAY2BGR)\n",
" # print(maskc.shape)\n",
" # print(image.shape)\n",
" whitedbackground = cv2.bitwise_or(mainimage, maskc)\n",
" # return whitedbackground\n",
" \n",
" \n",
" lab2 = cv2.cvtColor(whitedbackground, cv2.COLOR_BGR2LAB)\n",
" \n",
" lab2l = lab2[:,:,0]\n",
" \n",
" \n",
" otsu2 = cv2.threshold(lab2l, 0, 255, cv2.THRESH_OTSU)[1]\n",
" \n",
" expandedmask1 = cv2.morphologyEx(mask1, cv2.MORPH_DILATE, kernel, iterations=1)\n",
" expandedmask1 = cv2.morphologyEx(expandedmask1, cv2.MORPH_DILATE, kernelell, iterations=1)\n",
" # return expandedmask1\n",
" \n",
" maskmerge = cv2.bitwise_and(otsu2, 255-expandedmask1)\n",
" return mask1\n",
" return maskmerge\n",
" \n",
" # return otsu2\n",
" \n",
" mpad = mf.padWithColour(maskmerge, sdim*2, sdim*2, fill=0)\n",
" return mpad\n",
" \n",
" #MORPHOLOGIES \n",
" morphed2 = cv2.morphologyEx(mpad, cv2.MORPH_ERODE, kernel, iterations=1)\n",
" morphed2 = cv2.morphologyEx(morphed2, cv2.MORPH_ERODE, kernelell, iterations=1)\n",
" return morphed2\n",
" \n",
" contours, heirarchy = cv2.findContours(morphed2, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
" biggestcontour = max(contours, key=cv2.contourArea)\n",
" \n",
" \n",
" mask2 = blank.copy()\n",
" mask2 = mf.padWithColour(mask2, sdim*2, sdim*2, fill=255)\n",
" mask2 = cv2.drawContours(mask2, [biggestcontour], -1, 0, thickness=cv2.FILLED)\n",
" \n",
" \n",
" mask2 = mask2[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
" \n",
" return mask2\n",
" \n",
" test = cv2.inpaint(whitedbackground, resizemask, 3, cv2.INPAINT_TELEA)\n",
" \n",
" return test\n",
" \n",
" contours, heirarchy = cv2.findContours(255-labl, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)\n",
" \n",
" imgout = cv2.drawContours(mainimage, contours, -1, (0,255,0), thickness=3)\n",
" return imgout\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def textleaver(image):\n",
" ogshape = image.shape\n",
" shrunkdim=1000\n",
" if (image.shape[1] > image.shape[0]):\n",
" shrunkimg, scaler = mf.ResizeWithAspectRatio(image, width=shrunkdim, retscale=True)\n",
" else:\n",
" shrunkimg, scaler = mf.ResizeWithAspectRatio(image, height=shrunkdim, retscale=True)\n",
" \n",
" mainimage = shrunkimg\n",
" \n",
" sdim = int(min(mainimage.shape[0], mainimage.shape[1])/5)\n",
" srkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (sdim, sdim))\n",
" skernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (sdim, sdim))\n",
" \n",
" oglab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)\n",
" lab = cv2.cvtColor(mainimage, cv2.COLOR_BGR2LAB)\n",
" \n",
" imglist = []\n",
" # imglist.append(mainimage)\n",
" \n",
" labl = lab[:,:,0]\n",
" oglabl = oglab[:,:,0]\n",
" # # imglist.append(labl)\n",
" # # imglist.append(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))\n",
" # laba = lab[:,:,1]\n",
" # # imglist.append(laba)\n",
" # labb = lab[:,:,2]\n",
" # # imglist.append(labb)\n",
" \n",
" divisor = 1.5\n",
" window = int(min(labl.shape)/divisor)\n",
" window = window if window%2 == 1 else window + 1\n",
" # canny = cv2.Canny(labl, 0, 500)\n",
" ethreshl = cv2.threshold(labl, 0, 255, cv2.THRESH_OTSU)[1]\n",
" threshl = cv2.threshold(labl, 0, 255, cv2.THRESH_OTSU)[1]\n",
" # threshl = cv2.adaptiveThreshold(labl, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, window, 35)\n",
" \n",
" \n",
" ogwindow = int(min(oglabl.shape)/divisor)\n",
" ogwindow = window if window%2 == 1 else window + 1\n",
" print(ogwindow)\n",
" ogthreshl = cv2.threshold(oglabl, 0, 255, cv2.THRESH_TRIANGLE)[1]\n",
" return ogthreshl\n",
" # ogthreshl = cv2.adaptiveThreshold(oglabl, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, ogwindow, 35)\n",
" # return threshl\n",
" \n",
" colourthresh = cv2.cvtColor(threshl, cv2.COLOR_GRAY2BGR)\n",
" \n",
" dim = int(min(mainimage.shape[0], mainimage.shape[1])/100)\n",
" # dim = 2\n",
" # dim = dotsize\n",
" dim = max(3,dim)\n",
" kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (dim, dim))\n",
" kernelell = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dim, dim))\n",
" \n",
" # paddedl = mf.padWithColour(threshl, sdim*2, sdim*2, fill=0)\n",
" paddedl = threshl\n",
" # return paddedl\n",
" \n",
" \n",
" # morphedl = 255-cv2.morphologyEx(255-threshl, cv2.MORPH_OPEN, kernel, iterations=3)\n",
" morphedl = paddedl\n",
" morphed1l = cv2.morphologyEx(morphedl, cv2.MORPH_ERODE, kernel, iterations=1)\n",
" # morphed1l = cv2.morphologyEx(morphed1l, cv2.MORPH_OPEN, kernel, iterations=1)\n",
" # morphed1l = cv2.morphologyEx(morphed1l, cv2.MORPH_OPEN, kernel, iterations=1)\n",
" # morphed1l = cv2.morphologyEx(morphedl, cv2.MORPH_ERODE, kernelell, iterations=2)\n",
" \n",
" emorphed1l = cv2.morphologyEx(ethreshl, cv2.MORPH_ERODE, kernel, iterations=1)\n",
"\n",
" # return morphedl\n",
" \n",
" contours, heirarchy = cv2.findContours(morphed1l, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
" biggestcontour = max(contours, key=cv2.contourArea)\n",
" \n",
" # temp = cv2.drawContours(colourthresh, [biggestcontour], -1, (0,255,0), thickness=1)\n",
" # return temp\n",
" \n",
" \n",
" blank = np.full(labl.shape, 255, dtype=np.uint8)\n",
" mask1 = blank.copy()\n",
" # mask1 = mf.padWithColour(mask1, sdim*2, sdim*2, fill=255)\n",
" mask1 = cv2.drawContours(mask1, [biggestcontour], -1, 0, thickness=cv2.FILLED)\n",
" ## need to change the erosion so that if the paper goes to the edge, it doesn't get eroded in (because that means the paper is right to the edge and writing may be close)\n",
" \n",
" contours, heirarchy = cv2.findContours(morphed1l, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
" biggestcontour = max(contours, key=cv2.contourArea)\n",
" \n",
" emask1 = blank.copy()\n",
" emask1 = cv2.drawContours(emask1, [biggestcontour], -1, 0, thickness=cv2.FILLED)\n",
" \n",
" mask1 = 255-cv2.morphologyEx(255-mask1, cv2.MORPH_ERODE, kernel, iterations=2)\n",
" \n",
" emask1 = 255-cv2.morphologyEx(255-emask1, cv2.MORPH_ERODE, kernel, iterations=2)\n",
" \n",
" \n",
" # mask1 = mask1[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
" # return mask1\n",
" \n",
" # morphed2l = mf.padWithColour(morphedl, sdim*2, sdim*2, fill=255)\n",
" morphed2l = cv2.morphologyEx(morphedl, cv2.MORPH_OPEN, kernel, iterations=1)\n",
" morphed2l = cv2.morphologyEx(morphedl, cv2.MORPH_ERODE, kernel, iterations=1)\n",
" # morphed2l = morphed2l[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
" \n",
" # return morphed2l\n",
" # print(mask1.shape)\n",
" # print(morphed2l.shape)\n",
" morphed2l = cv2.bitwise_or(morphed2l, 255-mask1)\n",
" # return morphed2l\n",
"\n",
" # paddedthreshl = mf.padWithColour(morphed2l, sdim*2, sdim*2, fill=255)\n",
" # temp = cv2.drawContours(colourthresh, [biggestcontour], -1, (0,255,0), thickness=1)\n",
" # return temp\n",
"\n",
"\n",
" morphed2l = cv2.morphologyEx(morphed2l, cv2.MORPH_ERODE, kernel, iterations=1)\n",
" morphed2l = cv2.morphologyEx(morphed2l, cv2.MORPH_ERODE, kernelell, iterations=1)\n",
" # return morphed2l\n",
" # morphed2l = cv2.bitwise_or(morphed2l, 255-emask1)\n",
" \n",
" # morphed2l = morphed2l[(sdim*2):-(sdim*2), (sdim*2):-(sdim*2)]\n",
" \n",
" resizedmask = cv2.resize(255-morphed2l, (ogshape[1], ogshape[0]))\n",
" temp_final = cv2.bitwise_or(ogthreshl, resizedmask)\n",
" \n",
" dim=3\n",
" kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dim, dim))\n",
" temp_final = cv2.morphologyEx(temp_final, cv2.MORPH_OPEN, kernel)\n",
" temp_final = cv2.morphologyEx(temp_final, cv2.MORPH_OPEN, kernel)\n",
" # temp_final = cv2.morphologyEx(temp_final, cv2.MORPH_CLOSE, kernel)\n",
" # temp_final = cv2.morphologyEx(temp_final, cv2.MORPH_OPEN, kernel)\n",
" return temp_final"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def cropclarifying(image):\n",
" # whitedbackground = whiteoutbackground(image)\n",
" # return whitedbackground\n",
"\n",
" # textrefined = mf.textClarifying(whitedbackground)\n",
" textrefined = textleaver(image)\n",
" return textrefined\n",
" #maybe now is when I put in the line removing function\n",
"\n",
" lineout = mf.removeLinesFromText(textrefined)\n",
"\n",
" return lineout\n",
" # implement a function that's called refine text"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def houghlineprocessing(image):\n",
" croppedanddeskewed, angle = mf.houghlinedeskewandcrop(image)\n",
" # return croppedanddeskewed\n",
" \n",
" \n",
" # postprocessed = cropclarifying(croppedanddeskewed)\n",
" postprocessed = croppedanddeskewed\n",
" # return postprocessed\n",
" # postprocessed = mf.croptoblack(postprocessed)\n",
" \n",
" # postprocessed = cv2.cvtColor(postprocessed, cv2.COLOR_GRAY2BGR)\n",
" # return postprocessed\n",
" \n",
" # final = mf.externaldeskew(postprocessed, fill=(255,255,255))\n",
" # rotangle = mf.receipttextdeskew(postprocessed, fill=(255,255,255), returnangle=True)\n",
" final = postprocessed\n",
" \n",
" \n",
" # final = mf.croptoblack(final)\n",
" \n",
" # cv2.imshow(\"postprocessed\", mf.ResizeWithAspectRatio(postprocessed, 1000))\n",
" # cv2.imshow(\"final\", mf.ResizeWithAspectRatio(final, 1000))\n",
" # cv2.waitKey(0)\n",
" # cv2.destroyAllWindows()\n",
" \n",
" return final"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# print(img.shape)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0\n"
]
}
],
"source": [
"# prepped, scaler, hp, vp = mf.squareandthenresize(img, fill=255, width=1000, returnscalerinfo=True)\n",
"outs = houghlineprocessing(img)\n",
"# outs = prepimageforhoughline(img, returnrect=True)\n",
"# print(img.shape)\n",
"# outs = houghlinedeskewandcrop(img)\n",
"# outs = outs[0]\n",
"# print(croprect)\n",
"#need to fix premorphCrop. it removes too much"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# shrunk, scaler, hp, vp = mf.squareandthenresize(img, fill=255, width=1000, returnscalerinfo=True)\n",
"# shrunk1, croprect = mf.premorphCrop(shrunk)\n",
"# print(croprect)\n",
"# print(int(30*4.032 - 0))\n",
"# # temp = img[100:, :, :]\n",
"# temp = shrunk[croprect[1]:croprect[1]+croprect[3], croprect[0]:croprect[0]+croprect[2], :]\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# cv2.imshow(\"temp\", mf.ResizeWithAspectRatio(out, height=1000))\n",
"# # cv2.imshow(\"shrunk1\", mf.ResizeWithAspectRatio(shrunk1, height=1000))\n",
"# cv2.waitKey(0)\n",
"# cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"testall = True"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"if not testall:\n",
" showimgs(outs)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# # for out in outs:\n",
"# # if (out.shape[0] > out.shape[1]):\n",
"# # cv2.imshow(\"test1\", mf.ResizeWithAspectRatio(out, height=1000))\n",
"# # else:\\\n",
"# # cv2.imshow(\"test1\", mf.ResizeWithAspectRatio(out, width=1000))\n",
"# # key = cv2.waitKey(0)\n",
"# # cv2.destroyAllWindows()\n",
"# # if (key == 107):\n",
"# # break\n",
"# if (isinstance(outs, np.ndarray)):\n",
"# if (outs.shape[0] > outs.shape[1]):\n",
"# cv2.imshow(\"test\", mf.ResizeWithAspectRatio(outs, height=1350))\n",
"# else:\n",
"# cv2.imshow(\"test\", mf.ResizeWithAspectRatio(outs, width=1000))\n",
"# else:\n",
"# for i, out in enumerate(outs):\n",
"# if (out.shape[0] > out.shape[1]):\n",
"# cv2.imshow(\"test\"+str(i), mf.ResizeWithAspectRatio(out, height=1350))\n",
"# else:\n",
"# cv2.imshow(\"test\"+str(i), mf.ResizeWithAspectRatio(out, width=1000))\n",
"# cv2.waitKey(0)\n",
"# cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9740282517223996\n",
"-2.0053522829578814\n",
"-0.9740282517223996\n",
"0.0\n",
"0.9740282517223996\n",
"-0.9740282517223996\n",
"-0.011669615052326776\n",
"2.0053522829578814\n",
"0.0\n",
"0.0\n",
"0.0\n",
"-2.979380534680281\n",
"0.0\n",
"0.0\n",
"-2.0053522829578814\n",
"-11.000789666511807\n",
"average time: 0.19967518746852875(s)\n"
]
}
],
"source": [
"if testall:\n",
" results = testondataset(\"/mnt/dataset/baseimages/\", houghlineprocessing)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# if testall:\n",
"# showimgs(results)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# print(results[0])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"if testall:\n",
" writeimgs(\"/mnt/code/autocropper/result_images/\", results)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,11 +0,0 @@
#ifndef CROPPER_H
#define CROPPER_H
#include <opencv2/opencv.hpp>
// Declaration of crop(); the definition is not part of this header.
//
// src         - input array (cv::InputArray).
// dst         - output array (cv::OutputArray) that receives the result.
// fastsearch  - defaults to true.
//               NOTE(review): presumably selects a faster search mode - confirm
//               against the definition.
// imageHeight - defaults to 700.
//               NOTE(review): presumably the working height used while
//               processing - confirm against the definition.
//
// Returns a bool. NOTE(review): presumably true on success - confirm.
bool crop(cv::InputArray src, cv::OutputArray dst, bool fastsearch = true, int imageHeight = 700);
#endif //CROPPER_H

View File

@ -1 +0,0 @@
#define DEBUG 1

View File

@ -1,43 +0,0 @@
#include <cropper.h>
#include <opencv2/opencv.hpp>
// PLAN:
// Implement selective search
// Implement Canny edge detection and then find a good rectangle
// Do L2 loss with the corners of the rectangle and choose the selective search rectangle with the lowest loss
//for testing delete later
#include <iostream>
// Manual test driver for crop().
//
// Loads the image named by argv[1], runs crop() on it, rescales the result to
// a fixed height of 800 pixels (keeping the aspect ratio), shows it in a
// window, writes it to ../testing_space/cropped.jpg and waits for a key press.
//
// Returns 0 on success and -1 when the argument is missing, the image cannot
// be loaded, or crop() leaves the output empty.
int main(int argc, char** argv) {
    if (argc < 2) {
        std::cerr << "BAD" << std::endl;
        // argv[0] may be null when argc == 0, so fall back to the target name.
        std::cerr << "Usage: " << (argc > 0 ? argv[0] : "CropperEx")
                  << " <Input image>" << std::endl;
        return -1;
    }

    cv::Mat imOut, result;
    imOut = cv::imread(argv[1]);
    if (imOut.empty()) {
        std::cout << "Could not open or find the image!\n" << std::endl;
        std::cout << "Usage: " << argv[0] << " <Input image>" << std::endl;
        return -1;
    }

    // NOTE(review): crop() returns a bool whose meaning is not visible from
    // here; it is still ignored, as before. Confirm whether false means failure.
    crop(imOut, result, true, 1000);

    // An empty result would divide by zero below (result.rows == 0) and make
    // cv::resize fail, so stop here with an error instead.
    if (result.empty()) {
        std::cerr << "crop() produced an empty image" << std::endl;
        return -1;
    }

    int imageHeight = 800;
    int newWidth = result.cols * imageHeight / result.rows;
    cv::resize(result, result, cv::Size(newWidth, imageHeight));

    cv::imshow("banana", result);
    // Report a failed write instead of dropping it silently; the preview
    // window is still shown either way.
    if (!cv::imwrite("../testing_space/cropped.jpg", result)) {
        std::cerr << "Could not write ../testing_space/cropped.jpg" << std::endl;
    }
    cv::waitKey();
    return 0;
}

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@ -1,430 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# https://docs.opencv.org/3.4/d9/db0/tutorial_hough_lines.html\n",
"# https://medium.com/@9sphere/machine-vision-recipes-deskewing-document-images-e17827894c34\n",
"# https://towardsdatascience.com/pre-processing-in-ocr-fc231c6035a7"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#initially for deskewing and cropping. moving to a doc for just cropping now that deskewing"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import numpy as np\n",
"import math\n",
"import myfunctions as mf\n",
"\n",
"import scipy.stats as st"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# def ResizeWithAspectRatio(image, width=None, height=None, inter=cv2.INTER_AREA, retscale=False):\n",
"# dim = None\n",
"# (h, w) = image.shape[:2]\n",
"\n",
"# if width is None and height is None:\n",
"# if (retscale == True):\n",
"# return (image, 1)\n",
"# return image\n",
"# if width is None:\n",
"# r = height / float(h)\n",
"# dim = (int(w * r), height)\n",
"# else:\n",
"# r = width / float(w)\n",
"# dim = (width, int(h * r))\n",
"\n",
"# if (retscale == True):\n",
"# # print(\"hi\")\n",
"# return (cv2.resize(image, dim, interpolation=inter), 1/r)\n",
"# return cv2.resize(image, dim, interpolation=inter)\n",
"\n",
"\n",
"# class SquarePad:\n",
"# def __init__(self, fill):\n",
"# self.fill = fill\n",
" \n",
"# def __call__(self, image):\n",
"# w, h = image.shape[1], image.shape[0]\n",
"# max_wh = np.max([w, h])\n",
"# hp = int((max_wh - w) / 2)\n",
"# vp = int((max_wh - h) / 2)\n",
"# padding = (hp, vp, hp, vp)\n",
"# return cv2.copyMakeBorder(image, vp, vp, hp, hp, cv2.BORDER_CONSTANT, self.fill)\n",
" \n",
" \n",
" \n",
"# def rotate(img, angle):\n",
"# rows,cols = img.shape[0], img.shape[1]\n",
"# M = cv2.getRotationMatrix2D((cols/2,rows/2),angle,1)\n",
"# dst = cv2.warpAffine(img,M,(cols,rows))\n",
"# return dst"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# def morphologyCrop(image):\n",
"# # convert to grayscale\n",
"# gray = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)\n",
"\n",
"# # threshold\n",
"# thresh = cv2.threshold(gray, 170, 255, cv2.THRESH_BINARY)[1]\n",
"\n",
"# # apply morphology\n",
"# kernel = np.ones((7,7), np.uint8)\n",
"# morph = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)\n",
"# kernel = np.ones((9,9), np.uint8)\n",
"# morph = cv2.morphologyEx(morph, cv2.MORPH_ERODE, kernel)\n",
"\n",
"# # get largest contour\n",
"# contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)\n",
"# contours = contours[0] if len(contours) == 2 else contours[1]\n",
"# area_thresh = 0\n",
"# for c in contours:\n",
"# area = cv2.contourArea(c)\n",
"# if area > area_thresh:\n",
"# area_thresh = area\n",
"# big_contour = c\n",
"\n",
"\n",
"# # get bounding box\n",
"# x,y,w,h = cv2.boundingRect(big_contour)\n",
"\n",
"# # draw filled contour on black background\n",
"# mask = np.zeros_like(gray)\n",
"# mask = cv2.merge([mask,mask,mask])\n",
"# cv2.drawContours(mask, [big_contour], -1, (255,255,255), cv2.FILLED)\n",
"\n",
"# # apply mask to input\n",
"# result1 = image.copy()\n",
"# result1 = cv2.bitwise_and(result1, mask)\n",
"\n",
"# # crop result\n",
"# result2 = result1[y:y+h, x:x+w]\n",
"# return result2"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# x = -2*np.pi/3\n",
"# print(x)\n",
"# print(np.pi/3)\n",
"# print(x % np.pi)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# def lineAngle(line):\n",
"# # print(line)\n",
"# angle = (math.atan2(line[3] - line[1], line[2] - line[0]) % np.pi) - (np.pi/2)\n",
"# return angle\n",
" \n",
"# def WithinXDegrees(lines, margin):\n",
"# # outlines = np.array([[]])\n",
"# outlines = np.empty((0, 4))\n",
"# # print(outlines.shape)\n",
"# for line in lines:\n",
"# # print(type(line))\n",
"# # print(abs(lineAngle(line[0])))\n",
"# if (np.rad2deg(abs(lineAngle(line[0]))) <= margin):\n",
"# outlines = np.append(outlines, [line[0]], axis=0)\n",
"# return outlines\n",
"\n",
"# def lineBoundingRect(lines):\n",
"# maxvals = lines.max(0)\n",
"# minvals = lines.min(0)\n",
"# boundingrect = (min(minvals[0],minvals[2]), min(minvals[1],minvals[3]), max(maxvals[0],maxvals[2]),max(maxvals[1],maxvals[3]))\n",
"# return boundingrect\n",
"# # print(lines.max(0))\n",
"# # print(type(lines))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"img = cv2.imread('./test_images/IMG_7605.jpg')\n",
"img = mf.SquarePad(fill=255)(img)\n",
"img = mf.rotate(img, 54)\n",
"img = mf.morphologyCrop(mf.ResizeWithAspectRatio(img,1000))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)\n",
"# img = cv2.threshold(img, 200, 255, cv2.THRESH_BINARY)[1]"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"cv2.imshow(\"Detected Lines (in red) - Standard Hough Line Transform\", mf.ResizeWithAspectRatio(mf.SquarePad(fill=255)(img), 1000))\n",
"cv2.waitKey(0)\n",
"cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"resizedimg = mf.ResizeWithAspectRatio(mf.SquarePad(fill=255)(img), 500)\n",
"\n",
"# cv2.imshow(\"Detected Lines (in red) - Standard Hough Line Transform\", img)\n",
"# cv2.waitKey(0)\n",
"# cv2.destroyAllWindows()\n",
"\n",
"gray = cv2.cvtColor(resizedimg ,cv2.COLOR_BGR2GRAY)\n",
"gray = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY)[1]\n",
"cdst = resizedimg.copy()\n",
"\n",
"\n",
"dst = cv2.Canny(gray, 50, 200, None, 3)\n",
"lines = cv2.HoughLines(dst, 1, np.pi/180, 150, None, 0, 0)\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"angles = np.zeros(len(lines))\n",
"if lines is not None:\n",
" for i in range(0, len(lines)):\n",
" rho = lines[i][0][0]\n",
" theta = lines[i][0][1]\n",
" a = math.cos(theta)\n",
" b = math.sin(theta)\n",
" x0 = a * rho\n",
" y0 = b * rho\n",
" unroundedpt1 = (x0 + 1000*(-b), y0 + 1000*(a))\n",
" unroundedpt2 = (x0 - 1000*(-b), y0 - 1000*(a))\n",
" pt1 = (int(unroundedpt1[0]), int(unroundedpt1[1]))\n",
" pt2 = (int(unroundedpt2[0]), int(unroundedpt2[1]))\n",
" v1_theta = math.atan2(pt1[1], pt1[0])\n",
" v2_theta = math.atan2(pt2[1], pt2[0])\n",
" # print(math.atan2(unroundedpt2[1] - unroundedpt1[1], unroundedpt2[0] - unroundedpt1[0]) % np.pi)\n",
" # print(lineAngle((unroundedpt1[0], unroundedpt1[1], unroundedpt2[0], unroundedpt2[1])))\n",
" # angles[i] = math.atan2(unroundedpt2[1] - unroundedpt1[1], unroundedpt2[0] - unroundedpt1[0]) % np.pi\n",
" angles[i] = mf.lineAngle((unroundedpt1[0], unroundedpt1[1], unroundedpt2[0], unroundedpt2[1]))\n",
" cv2.line(cdst, pt1, pt2, (0,0,255), 3, cv2.LINE_AA)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-56.7228217179515\n",
"-56.7228217179515\n"
]
}
],
"source": [
"# print(st.mode(np.around(angles, decimals=1)))\n",
"mode = st.mode(np.around(angles, decimals=2))[0]\n",
"print(np.rad2deg(mode))\n",
"# slope = math.tan(np.deg2rad(mode))\n",
"# print(slope)\n",
"# myy0 = 0\n",
"# p1 = [0,myy0]\n",
"# p2 = [0,myy0]\n",
"# while (math.dist(p1, p2) < 5000):\n",
"# p2[0] += 0.5\n",
"# p2[1] += 0.5*slope*1000\n",
"# p2[1] = int(p2[1])\n",
"# print(p2)\n",
"# cv2.line(cdst, p1, p2, (0,255,0), 3, cv2.LINE_AA)\n",
"# rotationangle = np.rad2deg(mode)-90\n",
"rotationangle = np.rad2deg(mode)\n",
"print(rotationangle)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"cv2.imshow(\"Detected Lines (in red) - Standard Hough Line Transform\", cdst)\n",
"cv2.waitKey(0)\n",
"cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# cv2.imshow(\"Detected Lines (in red) - Standard Hough Line Transform\", rotate(cdst,rotationangle))\n",
"# cv2.waitKey(0)\n",
"# cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"rotatedimg = mf.SquarePad(fill=255)(mf.rotate(img, rotationangle))\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"# cv2.imshow(\"Rotated Image\", ResizeWithAspectRatio(rotatedimg, 1000))\n",
"# cv2.waitKey(0)\n",
"# cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"resizedrotatedimg = mf.ResizeWithAspectRatio(rotatedimg, 500)\n",
"gray1 = cv2.cvtColor(resizedrotatedimg, cv2.COLOR_BGR2GRAY)\n",
"dst1 = cv2.Canny(gray1, 0, 500, None, 3)\n",
"cdstP = resizedrotatedimg.copy()\n",
"cdstPmargin = cdstP.copy()\n",
"linesP = cv2.HoughLinesP(dst1, 1, np.pi / 180, 30, None, 100, 30)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"if linesP is not None:\n",
" for i in range(0, len(linesP)):\n",
" l = linesP[i][0]\n",
" cv2.line(cdstP, (l[0], l[1]), (l[2], l[3]), (0,0,255), 3, cv2.LINE_AA)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"cv2.imshow(\"Detected Lines (in red) - Standard Hough Line Transform\", cdstP)\n",
"cv2.waitKey(0)\n",
"cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"# print(linesP)\n",
"marginlines = mf.WithinXDegrees(linesP, 2)\n",
"# print(marginlines)\n",
"if marginlines is not None:\n",
" for i in range(0, len(marginlines)):\n",
" l = marginlines[i]\n",
" cv2.line(cdstPmargin, (int(l[0]), int(l[1])), (int(l[2]), int(l[3])), (0,0,255), 3, cv2.LINE_AA)\n",
" \n",
"# boundingrectout = mf.lineBoundingRect(marginlines)\n",
"# # print(boundingrectout)\n",
"# cdstPmargin = cv2.rectangle(cdstPmargin,(int(boundingrectout[0]),int(boundingrectout[1])),(int(boundingrectout[2]),int(boundingrectout[3])),(0,255,0),2)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"cv2.imshow(\"Detected Lines (in red) - Standard Hough Line Transform\", cdstPmargin)\n",
"cv2.waitKey(0)\n",
"cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,385 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 137,
"metadata": {},
"outputs": [],
"source": [
"version=2.0\n",
"cachepath=\"../.cache/\"\n",
"savepath=\"./savespot/\""
]
},
{
"cell_type": "code",
"execution_count": 138,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import DataLoader\n",
"import torch.nn as nn\n",
"import torch.nn.functional as fn\n",
"import torch.optim as optim\n",
"import torchvision.transforms.functional as tvf\n",
"import torchvision.transforms.v2 as v2\n",
"import torchvision.models as models\n",
"import torchvision.transforms as t\n",
"\n",
"\n",
"from PIL import Image\n",
"\n",
"import datasets as ds\n",
"from tqdm.autonotebook import tqdm\n",
"\n",
"import random\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"torch.cuda.empty_cache()\n",
"\n",
"\n",
"import os\n",
"import cv2"
]
},
{
"cell_type": "code",
"execution_count": 139,
"metadata": {},
"outputs": [],
"source": [
"# array = np.load(\"./testing_space/outputarray.npy\")\n",
"# counter = np.load(\"./testing_space/counter.npy\")"
]
},
{
"cell_type": "code",
"execution_count": 140,
"metadata": {},
"outputs": [],
"source": [
"# print(array)\n",
"# print(counter)"
]
},
{
"cell_type": "code",
"execution_count": 141,
"metadata": {},
"outputs": [],
"source": [
"class RotationDeterminer(nn.Module):\n",
" def __init__(self, new=False):\n",
" super(RotationDeterminer,self).__init__()\n",
" \n",
" torch.cuda.empty_cache()\n",
" \n",
" self.device = torch.device(\"cpu\")\n",
" if torch.cuda.is_available:\n",
" self.device = torch.device(\"cuda:0\")\n",
" \n",
" \n",
" self.appliers = [v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
" v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0)], p=0.25), # maybe add fill=appliedFill\n",
" v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5,9), sigma=(0.1,2.))],p=0.25),\n",
" v2.RandomApply(transforms=[v2.RandomEqualize()],p=0.25)]\n",
" \n",
" \n",
" # self.conv = nn.Sequential(nn.Conv2d(3, 9, kernel_size=11,stride=3), # 1100 x 1100 => 201 x 201\n",
" # nn.ReLU(inplace=True),\n",
" # nn.Conv2d(9, 18, kernel_size=5,stride=1),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.MaxPool2d(kernel_size=4, stride=2),\n",
" # nn.Conv2d(18, 36, kernel_size=3,stride=2),\n",
" # nn.BatchNorm2d(36),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.Conv2d(36, 72, kernel_size=3,stride=2),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.AvgPool2d(kernel_size=5, stride=3),\n",
" # nn.Conv2d(72, 144, kernel_size=3,stride=1),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.Conv2d(144, 288, kernel_size=5,stride=1),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.MaxPool2d(kernel_size=4, stride=1),\n",
" # nn.Conv2d(288, 192, kernel_size=3,stride=1),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.Conv2d(192, 192, kernel_size=3,stride=1), # => 1\n",
" # nn.ReLU(inplace=True))\n",
" # print(\"hi\")\n",
" self.conv = models.resnet18(pretrained=new)\n",
" \n",
" self.classifier = nn.Sequential(nn.Linear(1000, 4096),\n",
" nn.ReLU(inplace=True),\n",
" nn.Linear(4096,1))\n",
" \n",
" self.lossfunc = nn.MSELoss()\n",
" \n",
" self.imageprep = v2.Compose([self.SquarePad(),v2.Resize(512),v2.Grayscale(num_output_channels=3),v2.CenterCrop(512),v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
" \n",
" \n",
" class SquarePad:\n",
" def __call__(self, image):\n",
" # print(\"hi type:\", type(image))\n",
" temp = image.size()\n",
" w = temp[-2]\n",
" h = temp[-1]\n",
" max_wh = max([w, h])\n",
" hp = int((max_wh - w) / 2)\n",
" vp = int((max_wh - h) / 2)\n",
" padding = (hp, vp, hp, vp)\n",
" return tvf.pad(image, padding, 0, 'edge')\n",
"\n",
"\n",
" \n",
"\n",
" \n",
" def forward(self, image):\n",
"\n",
" transformedimage = self.imageprep(image)\n",
" transformedimage = transformedimage.to(self.device)\n",
"\n",
" if (len(transformedimage.shape) != 4 and len(transformedimage.shape) != 3):\n",
" raise Exception(\"Sorry, Dimension of image is incorrect (\", len(transformedimage.shape),\"). Expected a 3D (single image) or 4D (batch of images) tensor\")\n",
"\n",
" if (len(transformedimage.shape) == 3):\n",
" x = transformedimage.unsqueeze(0)\n",
" else:\n",
" x = transformedimage\n",
" \n",
" x = self.conv(x)\n",
" # print(x.shape)\n",
" # x = nn.Flatten(start_dim=-1)(x)\n",
" # print(x.shape)\n",
" x = self.classifier(x)\n",
" # print(x.shape)\n",
" guessRotation = nn.Flatten(start_dim=0)(x)\n",
" \n",
" return guessRotation\n",
" \n",
" def loss(self, guess, trueAnswer):\n",
" return self.lossfunc(guess, trueAnswer)\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 142,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n"
]
}
],
"source": [
"model = RotationDeterminer(new=True)\n",
"device = torch.device(\"cpu\")\n",
"if torch.cuda.is_available:\n",
" device = torch.device(\"cuda:0\")\n",
" model = model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 143,
"metadata": {},
"outputs": [],
"source": [
"# def ResizeWithAspectRatio(image, width=None, height=None, inter=cv2.INTER_AREA):\n",
"# dim = None\n",
"# (h, w) = image.shape[:2]\n",
"\n",
"# if width is None and height is None:\n",
"# return image\n",
"# if width is None:\n",
"# r = height / float(h)\n",
"# dim = (int(w * r), height)\n",
"# else:\n",
"# r = width / float(w)\n",
"# dim = (width, int(h * r))\n",
"\n",
"# return cv2.resize(image, dim, interpolation=inter)"
]
},
{
"cell_type": "code",
"execution_count": 163,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 4032, 3024])\n",
"torch.Size([3, 4032, 3024])\n",
"0.7532281875610352\n"
]
}
],
"source": [
"working_dataset = ds.load_from_disk(cachepath + \"datasets/customrotation/\")\n",
"prepimage = v2.Compose([v2.Grayscale(num_output_channels=3),v2.Resize(512), v2.CenterCrop(512),v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
"tensorize = v2.Compose([v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
"grayscaler = v2.Grayscale(num_output_channels=3)\n",
"working_dataset.set_transform(prepimage)\n",
"counter = np.load(savepath + \"/v\"+str(version)+\"/counter.npy\")\n",
"model.load_state_dict(torch.load(savepath + \"/v\"+str(version)+\"/modelsave\" + str(counter) +\"epochs\"))"
]
},
{
"cell_type": "code",
"execution_count": 165,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 800, 723])\n",
"torch.Size([3, 800, 723])\n",
"-1.3860492706298828\n"
]
}
],
"source": [
"filereadimage = cv2.imread(\"./testing_space/cropped.jpg\", 0)\n",
"# print(type(filereadimage))\n",
"tensorizedimage = torch.unsqueeze(torch.from_numpy(filereadimage),0)\n",
"print(tensorizedimage.shape)\n",
"adjustedtensorizedimage = tensorize(grayscaler(t.ToPILImage()(tensorizedimage)))\n",
"print(adjustedtensorizedimage.shape)\n",
"rotation = model(adjustedtensorizedimage).item()\n",
"print(rotation)\n",
"rotatedimage = t.Resize(size=1000)(tvf.rotate(adjustedtensorizedimage, rotation))\n",
"# imS = mf.ResizeWithAspectRatio(filereadimage, 1000)\n",
"# imS = cv2.resize(filereadimage, (960, 540)) \n",
"open_cv_image = np.array(t.ToPILImage()(rotatedimage))\n",
"cv2.imshow(f'image', open_cv_image)\n",
"key = cv2.waitKey(0)\n",
"cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"index = 0\n",
"active_dataset = working_dataset['test']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plt.imshow(t.ToPILImage()(working_dataset['test'][3]['image']), cmap='gray', vmin=0, vmax=255)\n",
"# plt.show()\n",
"# rotationapplier = model(working_dataset['test'][3]['image']).item()\n",
"# print(rotationapplier)\n",
"# img = tvf.rotate(working_dataset['test'][3]['image'], rotationapplier)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plt.imshow(t.ToPILImage()(img), cmap='gray', vmin=0, vmax=255)\n",
"# plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# # To call the model on a bunch of the images and rotate them back\n",
"\n",
"# while(True):\n",
"# activeimage = active_dataset[index]['image']\n",
"# # img = cv2.imread(active_dataset[index]['image'], 0)\n",
"# activeimage = tvf.rotate(activeimage, model(activeimage).item())\n",
"# open_cv_image = np.array(t.ToPILImage()(activeimage))\n",
"# print(index)\n",
"# cv2.imshow(f'current image', open_cv_image)\n",
"# key = cv2.waitKey(0)\n",
"\n",
"# if key == ord('c'):\n",
"# print(\"\\tCopying this one\")\n",
"# elif key == ord('x'):\n",
"# index -= 1\n",
"# elif key == ord('v'):\n",
"# index +=1\n",
"# elif key == ord('q'):\n",
"# break\n",
"\n",
"# cv2.destroyAllWindows()\n",
"# cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# # for trying to call the model on the picture repeatedly to see if it will just get more and more straight if it's called multiple times\n",
"\n",
"# currentimage = working_dataset['test'][3]['image']\n",
"# while(True):\n",
"# rotationapplier = model(currentimage).item()\n",
"# print(rotationapplier)\n",
"# img = tvf.rotate(currentimage, rotationapplier)\n",
"# open_cv_image = np.array(t.ToPILImage()(img))\n",
"# cv2.imshow(f'current image', open_cv_image)\n",
"# key = cv2.waitKey(0)\n",
" \n",
"# if key == ord('q'):\n",
"# break\n",
"# elif key == ord('v'):\n",
"# currentimage = img\n",
"# # cv2.destroyAllWindows()\n",
"# cv2.destroyAllWindows()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,417 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# from datasets import load_dataset, Image\n",
"import datasets as ds\n",
"import PIL\n",
"import torchvision.transforms.functional as tvf\n",
"from torchvision.transforms import v2\n",
"import random\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"original_dataset = ds.load_dataset(\"aharley/rvl_cdip\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# Create our own dataset from the images of the original dataset, but make the labels the float value of the rotation. Do the random rotation on all of the training ones, but the labels for the validation and test sets should/can be 0.\n",
"trainblacklist = [5, 102664, 102667, 277943]\n",
"testblacklist = [6, 11, 14, 18, 27, 35, 37, 54, 33669] # 33669 is a corrupt image\n",
"validationblacklist = []\n",
"og_training_dataset = original_dataset['train'].select([i for i in range(len(original_dataset['train'])) if i not in trainblacklist])\n",
"og_testing_dataset = original_dataset['test'].select([i for i in range(len(original_dataset['test'])) if i not in testblacklist])\n",
"og_validation_dataset = original_dataset['validation'].select([i for i in range(len(original_dataset['validation'])) if i not in validationblacklist])\n",
"\n",
"# type(og_testing_dataset)\n",
"\n",
"# print(type(transform_picture(og_testing_dataset[0], params)))\n",
"# out = transform_picture(og_testing_dataset[0], params)\n",
"# print(out['image'])\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"319998\n"
]
}
],
"source": [
"print(len(og_training_dataset))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def has_valid_image(ex):\n",
" print(type(ex))\n",
" try:\n",
" PIL.Image.open(ex[\"image\"][\"path\"])\n",
" except Exception:\n",
" print(\"hi\")\n",
" return False\n",
" return True\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# dataset = original_dataset.cast_column(\"image\", ds.Image(decode=False))\n",
"# dataset = dataset.filter(has_valid_image)\n",
"# filtered_dataset = dataset.cast_column(\"image\", ds.Image(decode=True))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Parameter Declaration\n",
"minRotation=-180\n",
"maxRotation=180\n",
"minTranslation=0\n",
"maxTranslation=150\n",
"minScale = 0.4\n",
"maxScale = 1\n",
"minShear = 0\n",
"maxShear = 0\n",
"\n",
"minFill=255\n",
"maxFill=255\n",
"\n",
"params = {\"minRotation\":minRotation,\"maxRotation\":maxRotation,\"minTranslation\":minTranslation,\"maxTranslation\":maxTranslation,\"minScale\":minScale,\"maxScale\":maxScale,\"minShear\":minShear,\"maxShear\":maxShear,\"minFill\":minFill,\"maxFill\":maxFill}"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class SquarePad:\n",
" def __init__(self, fill):\n",
" self.fill = fill\n",
" \n",
" def __call__(self, image):\n",
" w, h = image.size\n",
" max_wh = np.max([w, h])\n",
" hp = int((max_wh - w) / 2)\n",
" vp = int((max_wh - h) / 2)\n",
" padding = (hp, vp, hp, vp)\n",
" return tvf.pad(image, padding,fill=self.fill, padding_mode='constant')\n",
"\n",
"\n",
"\n",
"\n",
"def transform_picture(image_label, parameters):\n",
" image = image_label['image']\n",
"\n",
" appliedRotation = random.uniform(parameters['minRotation'], parameters['maxRotation'])\n",
" appliedXTranslation = random.uniform(parameters['minTranslation'], parameters['maxTranslation'])\n",
" appliedYTranslation = random.uniform(parameters['minTranslation'], parameters['maxTranslation'])\n",
" appliedScale = random.uniform(parameters['minScale'], parameters['maxScale'])\n",
" appliedFill = random.uniform(parameters['minFill'], parameters['maxFill'])\n",
" appliedXShear = random.uniform(parameters['minShear'], parameters['maxShear'])\n",
" appliedYShear = random.uniform(parameters['minShear'], parameters['maxShear'])\n",
" \n",
" appliers = v2.Compose([v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
" v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0, fill=appliedFill)], p=0.25), # maybe add fill=appliedFill\n",
" v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5,9), sigma=(0.1,2.))],p=0.25),\n",
" v2.RandomApply(transforms=[v2.RandomEqualize()],p=0.25),\n",
" SquarePad(fill=appliedFill),v2.Resize(1100)])\n",
" \n",
" adjustedimage = tvf.affine(image, appliedRotation, [appliedXTranslation,appliedYTranslation], appliedScale, [appliedXShear, appliedYShear], fill=appliedFill)\n",
"\n",
" adjustedimage = appliers(adjustedimage)\n",
"\n",
" \n",
" return {'image':adjustedimage,'rotation':appliedRotation}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bd95dc0201c2419e982f8167e16db6b5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map (num_proc=4): 0%| | 0/39999 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"new_testing_dataset = og_testing_dataset.map(transform_picture, fn_kwargs={'parameters':params}, num_proc=4)\n",
"#33669 has bad EXIF data so it is ignored at load time"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "17c8b6a170ae4072b385b6d3e965d9e8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map (num_proc=4): 0%| | 0/40000 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"new_validation_dataset = og_validation_dataset.map(transform_picture, fn_kwargs={'parameters':params}, num_proc=4)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "32d7adecb479420cb8a4b3eee898ec1b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map (num_proc=4): 0%| | 0/320000 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"new_training_dataset = og_training_dataset.map(transform_picture, fn_kwargs={'parameters':params}, num_proc=4)\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def setlabelname(entry):\n",
" return {'image':entry['image'], 'rotation':entry['label']}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# new_testing_dataset = new_testing_dataset.map(setlabelname, num_proc=4, batch_size=700, batched=True, writer_batch_size=700)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# new_training_dataset = new_training_dataset.remove_columns(\"label\")\n",
"# new_testing_dataset = new_testing_dataset.remove_columns(\"label\")\n",
"# new_validation_dataset = new_validation_dataset.remove_columns(\"label\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"new_dataset = ds.DatasetDict({'train': new_training_dataset,'test': new_testing_dataset, 'validation': new_validation_dataset})\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['image', 'label', 'rotation'],\n",
" num_rows: 320000\n",
" })\n",
" test: Dataset({\n",
" features: ['image', 'label', 'rotation'],\n",
" num_rows: 39999\n",
" })\n",
" validation: Dataset({\n",
" features: ['image', 'label', 'rotation'],\n",
" num_rows: 40000\n",
" })\n",
"})\n"
]
}
],
"source": [
"print(new_dataset)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"new_dataset = new_dataset.remove_columns(\"label\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['image', 'rotation'],\n",
" num_rows: 320000\n",
" })\n",
" test: Dataset({\n",
" features: ['image', 'rotation'],\n",
" num_rows: 39999\n",
" })\n",
" validation: Dataset({\n",
" features: ['image', 'rotation'],\n",
" num_rows: 40000\n",
" })\n",
"})\n"
]
}
],
"source": [
"print(new_dataset)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a614c8b5206649f0b774dc25909bca75",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Saving the dataset (0/65 shards): 0%| | 0/320000 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bc7984d6eafe443aa4980e43350fee03",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Saving the dataset (0/9 shards): 0%| | 0/39999 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ab4cc0e859bd46599345129fd70bcb37",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Saving the dataset (0/9 shards): 0%| | 0/40000 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"new_dataset.save_to_disk(\"../.cache/huggingfaces/datasets/customrotation/\", max_shard_size=\"500MB\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,159 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# from datasets import load_dataset, Image\n",
"import datasets as ds\n",
"import PIL\n",
"import torchvision.transforms.functional as tvf\n",
"from torchvision.transforms import v2\n",
"import random\n",
"import numpy as np\n",
"\n",
"import torchvision.utils as utils\n",
"\n",
"from tqdm.autonotebook import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"original_dataset = ds.load_dataset(\"aharley/rvl_cdip\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Create our own dataset from the images of the original dataset, but make the labels the float value of the rotation. Do the random rotation on all of the training ones, but the labels for the validation and test sets should/can be 0.\n",
"trainblacklist = []\n",
"testblacklist = [33669] # index 33669 is just corrupted\n",
"validationblacklist = []\n",
"og_training_dataset = original_dataset['train'].select([i for i in range(len(original_dataset['train'])) if i not in trainblacklist])\n",
"og_testing_dataset = original_dataset['test'].select([i for i in range(len(original_dataset['test'])) if i not in testblacklist])\n",
"og_validation_dataset = original_dataset['validation'].select([i for i in range(len(original_dataset['validation'])) if i not in validationblacklist])\n",
"\n",
"tensorize = v2.Compose([v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
"\n",
"og_training_dataset.set_transform(tensorize)\n",
"og_testing_dataset.set_transform(tensorize)\n",
"og_validation_dataset.set_transform(tensorize)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "755255ae5bea49cc866c96f0d291b570",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/39999 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pbar = tqdm(og_testing_dataset)\n",
"\n",
"for i, entry in enumerate(pbar):\n",
" index = i\n",
" if (i >= 33669):\n",
" index = index + 1\n",
" utils.save_image(entry['image'], \"./datasetimages/test/\"+str(index)+\".jpg\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "88288b649a64430bb52e2ae5720e4b1f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/320000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pbar = tqdm(og_training_dataset)\n",
"\n",
"for i, entry in enumerate(pbar):\n",
" utils.save_image(entry['image'], \"./datasetimages/train/\"+str(i)+\".jpg\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ed6ce8bc3d224f278df6723fc0c41d72",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pbar = tqdm(og_validation_dataset)\n",
"\n",
"for i, entry in enumerate(pbar):\n",
" utils.save_image(entry['image'], \"./datasetimages/validation/\"+str(i)+\".jpg\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,430 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## ORIGINAL FILE FOR SELECTIVE SEGMENTATION SEARCH"
]
},
{
"cell_type": "code",
"execution_count": 350,
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import numpy as np\n",
"from queue import PriorityQueue\n",
"import myfunctions as mf\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": 351,
"metadata": {},
"outputs": [],
"source": [
"# def ResizeWithAspectRatio(image, width=None, height=None, inter=cv2.INTER_AREA):\n",
"# dim = None\n",
"# (h, w) = image.shape[:2]\n",
"\n",
"# if width is None and height is None:\n",
"# return image\n",
"# if width is None:\n",
"# r = height / float(h)\n",
"# dim = (int(w * r), height)\n",
"# else:\n",
"# r = width / float(w)\n",
"# dim = (width, int(h * r))\n",
"\n",
"# return cv2.resize(image, dim, interpolation=inter)"
]
},
{
"cell_type": "code",
"execution_count": 352,
"metadata": {},
"outputs": [],
"source": [
"import heapq as hq\n",
"\n",
"class MaxHeapObj(object):\n",
" def __init__(self, val): self.val = val\n",
" def __lt__(self, other): return self.val > other.val\n",
" def __eq__(self, other): return self.val == other.val\n",
" def __str__(self): return str(self.val)\n",
" \n",
"class MinHeap(object):\n",
" def __init__(self): self.h = []\n",
" def heappush(self, x): heapq.heappush(self.h, x)\n",
" def heappop(self): return heapq.heappop(self.h)\n",
" def __getitem__(self, i): return self.h[i]\n",
" def __len__(self): return len(self.h)\n",
" \n",
"class MaxHeap(MinHeap):\n",
" def heappush(self, x): heapq.heappush(self.h, MaxHeapObj(x))\n",
" def heappop(self): return heapq.heappop(self.h).val\n",
" def __getitem__(self, i): return self.h[i].val"
]
},
{
"cell_type": "code",
"execution_count": 353,
"metadata": {},
"outputs": [],
"source": [
"# def clip(n, lower, upper):\n",
"# return max(lower, min(n, upper))\n",
"\n",
"# def colourscaler(n, min, max):\n",
"# temp = n-min\n",
"# diff = abs(max - min)\n",
"# return clip((temp/diff)*255, 0, 255)"
]
},
{
"cell_type": "code",
"execution_count": 354,
"metadata": {},
"outputs": [],
"source": [
"# inline double clip(double n, double lower, double upper) {\n",
"# return std::max(lower, std::min(n, upper));\n",
"# };\n",
"\n",
"# inline double colourscaler(double n, double min, double max) {\n",
"# double temp = n - min;\n",
"# double diff = std::abs(max - min);\n",
"# return clip((temp / diff) * 255, 0, 255);\n",
"# };"
]
},
{
"cell_type": "code",
"execution_count": 355,
"metadata": {},
"outputs": [],
"source": [
"# ## Test this code for the masking/colour squishing. It can essentially just speed up clipping the edges.\n",
"# #!/usr/local/bin/python3\n",
"# import cv2 as cv\n",
"# import numpy as np\n",
"\n",
"# # Load the aerial image and convert to HSV colourspace\n",
"# image = cv.imread(\"aerial.png\")\n",
"# hsv=cv.cvtColor(image,cv.COLOR_BGR2HSV)\n",
"\n",
"# # Define lower and upper limits of what we call \"brown\"\n",
"# brown_lo=np.array([10,0,0])\n",
"# brown_hi=np.array([20,255,255])\n",
"\n",
"# # Mask image to only select browns\n",
"# mask=cv.inRange(hsv,brown_lo,brown_hi)\n",
"\n",
"# # Change image to red where we found brown\n",
"# image[mask>0]=(0,0,255)\n",
"\n",
"# cv.imwrite(\"result.png\",image)\n",
"\n",
"#CAN ALSO TRY USING NUMPY VECTORIZATION"
]
},
{
"cell_type": "code",
"execution_count": 356,
"metadata": {},
"outputs": [],
"source": [
"# def rotate(img, angle):\n",
"# rows,cols = img.shape[0], img.shape[1]\n",
"# M = cv2.getRotationMatrix2D((cols/2,rows/2),angle,1)\n",
"# dst = cv2.warpAffine(img,M,(cols,rows))\n",
"# return dst"
]
},
{
"cell_type": "code",
"execution_count": 357,
"metadata": {},
"outputs": [],
"source": [
"def crop(image, lower = 100, upper = 255, threshold1 = 50, threshold2 = 350):\n",
" lower = max(0,lower)\n",
" upper = min(255, upper)\n",
" gray = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)\n",
"\n",
" scaled_gray = np.zeros(gray.shape, gray.dtype)\n",
" \n",
" # for y in range(0,gray.shape[0]):\n",
" # for x in range(0,gray.shape[1]):\n",
" # scaled_gray[y][x] = colourscaler(gray[y][x], lower, upper)\n",
" scaled_gray = gray\n",
" \n",
" blurred = cv2.GaussianBlur(scaled_gray, (15,15),0)\n",
" # blurred = scaled_gray\n",
" edged = cv2.Canny(blurred, threshold1, threshold2)\n",
" # meangrayscale = cv2.mean(scaled_gray)[0]\n",
" # print(meangrayscale)\n",
" # edged = cv2.Canny(blurred, int(meangrayscale*2), int(meangrayscale*4))\n",
" return edged\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 358,
"metadata": {},
"outputs": [],
"source": [
"def selectiveSearchSegmentationImp(image):\n",
" ss = cv2.ximgproc.segmentation.createSelectiveSearchSegmentation()\n",
" ss.setBaseImage(image)\n",
" ss.switchToSelectiveSearchFast()\n",
" return ss.process()"
]
},
{
"cell_type": "code",
"execution_count": 359,
"metadata": {},
"outputs": [],
"source": [
"img = cv2.imread('./testing_space/final.jpg')"
]
},
{
"cell_type": "code",
"execution_count": 360,
"metadata": {},
"outputs": [],
"source": [
"# def rectArea(rect):\n",
"# # print(rect)\n",
"# return rect[2]*rect[3]\n",
"\n",
"# def biggestRects(n, rects):\n",
"# dict = {}\n",
"# # outrects = np.zeros(shape=(n, 4))\n",
"# for rect in rects:\n",
"# dict[tuple(rect)] = mf.rectArea(rect)\n",
"# # maxh.heappush(mf.rectArea(rect))\n",
"# # print(maxh[0])\n",
" \n",
" \n",
"# heap = [(-value, key) for key,value in dict.items()]\n",
"# largest = hq.nsmallest(n, heap)\n",
" \n",
"\n",
"# # hq.heapify(list(dict.items()))\n",
"# # for i in range(0,n):\n",
"# # outrects[i] = maxh.heappop()\n",
"# # print(outrects)\n",
"# return [key for value, key in largest]\n",
"\n",
"# def overlapRect(rects):\n",
"# leftwall = -1\n",
"# rightwall = -1\n",
"# topwall = -1\n",
"# bottomwall = -1\n",
"# for (x, y, w, h) in rects:\n",
"# if (leftwall == -1):\n",
"# leftwall = x\n",
"# rightwall = x + w\n",
"# topwall = y\n",
"# bottomwall = y + h\n",
"# continue\n",
"# leftwall = max(leftwall, x)\n",
"# rightwall = min(rightwall, x+w)\n",
"# topwall = max(topwall, y)\n",
"# bottomwall = min(bottomwall, y+h)\n",
" \n",
"# if (topwall >= bottomwall or leftwall >= rightwall):\n",
"# return (-1, -1, -1, -1)\n",
"# return (leftwall, topwall, rightwall-leftwall, bottomwall-topwall)"
]
},
{
"cell_type": "code",
"execution_count": 344,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(-1, -1, -1, -1)\n"
]
}
],
"source": [
"# rect = crop(img)\n",
"\n",
"# _, thresholded = cv2.threshold(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 200, 255, cv2.THRESH_BINARY)\n",
"\n",
"rects = selectiveSearchSegmentationImp(cv2.GaussianBlur(ResizeWithAspectRatio(img,300), (15,15),0))\n",
"# mf.rectArea(rects[0])\n",
"bigRects = mf.biggestRects(20, rects)\n",
"# print(bigRects)\n",
"\n",
"finalrect = mf.overlapRect(bigRects)\n",
"print(finalrect)\n",
"output = ResizeWithAspectRatio(img,300)\n",
"for (x, y, w, h) in [finalrect]:\n",
"\t\t# draw the region proposal bounding box on the image\n",
"\t\tcolor = [random.randint(0, 255) for j in range(0, 3)]\n",
"\t\tcv2.rectangle(output, (x, y), (x + w, y + h), color, 2)\n",
"\n",
"# edges = cv2.Canny(cv2.GaussianBlur(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), (15,15),0),255 / 4, 255)\n",
"\n",
"# plt.imshow(edges, cmap='gray', vmin=0, vmax=255)\n",
"# plt.show()\n",
"\n",
"cv2.imshow(\"banana\", output)\n",
"cv2.waitKey(0)\n",
"cv2.destroyAllWindows()\n",
"\n",
"\n",
"# print(range(0,img.shape[1]))\n",
"# for i in range(0,img.shape[1]):\n",
"# print(i)"
]
},
{
"cell_type": "code",
"execution_count": 389,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n"
]
}
],
"source": [
"temp = ResizeWithAspectRatio(crop(img, threshold1=150, threshold2=350),500)\n",
"contours, _ = cv2.findContours(temp, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
"# print(type(contours))\n",
"# max(cv2.contourArea(contours))\n",
"# areas = list(map(cv2.contourArea, contours))\n",
"# print(areas)\n",
"contourindex = np.argmax(list(map(cv2.contourArea, contours)))\n",
"temp = cv2.drawContours(temp, contours, contourindex, (255,0,0), 2)\n",
"cv2.imshow(\"banana\", temp)\n",
"cv2.waitKey(0)\n",
"cv2.destroyAllWindows()\n",
"print(contourindex)\n",
"rect = cv2.boundingRect(contours[contourindex])\n",
"color = (random.randint(0,256), random.randint(0,256), random.randint(0,256))\n",
"result = cv2.rectangle(ResizeWithAspectRatio(img,500), rect, color, 3)"
]
},
{
"cell_type": "code",
"execution_count": 362,
"metadata": {},
"outputs": [],
"source": [
"# print(contourindex)"
]
},
{
"cell_type": "code",
"execution_count": 371,
"metadata": {},
"outputs": [],
"source": [
"cv2.imshow(\"banana\", result)\n",
"cv2.waitKey(0)\n",
"cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 348,
"metadata": {},
"outputs": [],
"source": [
"# HSV = cv2.cvtColor(ResizeWithAspectRatio(img,500), cv2.COLOR_BGR2HSV)\n",
"# low = np.array([0,0,10])\n",
"# high = np.array([179,10,255])\n",
"\n",
"# mask = cv2.inRange(HSV,low,high)\n",
"\n",
"# cv2.imshow(\"banana\", mask)\n",
"# cv2.waitKey(0)\n",
"# cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 349,
"metadata": {},
"outputs": [],
"source": [
" # cv::Mat gray, scaled_gray, blurred, edged;\n",
"\n",
" # lower = std::max(lower, 0);\n",
" # upper = std::min(upper, 255);\n",
"\n",
" # cv::cvtColor(src, gray, cv::COLOR_BGR2GRAY);\n",
" # scaled_gray = cv::Mat::zeros(gray.size(), gray.type());\n",
"\n",
" # for (int y = 0; y < gray.rows; y++) {\n",
" # for (int x = 0; x < gray.cols; x++) {\n",
" # scaled_gray.at<uchar>(y, x) =\n",
" # cv::saturate_cast<uchar>(colourscaler(gray.at<uchar>(y, x), lower, upper));\n",
" # }\n",
" # }\n",
"\n",
" # cv::GaussianBlur(scaled_gray, blurred, cv::Size(15, 15), 0);\n",
" # cv::Canny(blurred, edged, threshold1, threshold2);\n",
"\n",
" # std::vector<std::vector<cv::Point>> contours;\n",
" # std::vector<cv::Vec4i> heirarchy;\n",
" # cv::Mat approx;\n",
"\n",
" # cv::findContours(edged, contours, heirarchy, cv::RETR_TREE, cv::CHAIN_APPROX_SIMPLE);\n",
"\n",
" # cv::cvtColor(gray, gray, cv::COLOR_GRAY2BGR);\n",
"\n",
" # std::sort(contours.begin(), contours.end(), [](std::vector<cv::Point> a, std::vector<cv::Point> b) {\n",
" # return cv::arcLength(a, false) > cv::arcLength(b, false); });\n",
"\n",
" # int numContours = contours.size();\n",
"\n",
"\n",
" # return cv::boundingRect(contours[0]);"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,94 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import numpy as np\n",
"\n",
"import torch\n",
"import torchvision.transforms.functional as tvf\n",
"import torchvision.transforms.v2 as v2\n",
"import torchvision.transforms as t\n",
"import myfunctions as mf\n",
"\n",
"from skimage import io\n",
"from matplotlib import pyplot as plt\n",
"import time\n",
"\n",
"import myfunctions as mf"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# read image in OpenCV's default colour (BGR) mode; no grayscale flag is passed\n",
"img = cv2.imread('./test_images/IMG_7594.jpg')"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cropped = mf.morphologyCrop(img)\n",
"# rotated = deskew(cropped)\n",
"# cropped2 = morphologyCrop(rotated)\n",
"# cropped2 = selectiveSearchCrop(rotated)\n",
"# cropped3 = cannyEdgeCrop(cropped2)\n",
"cv2.imwrite(\"./testing_space/final.jpg\", cropped)\n",
"# final = rotate(cropped2, 180) # need to implement the code to determine if a doc is upside down"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"### Deskew seems to work\n",
"# Note: check the licensing terms of the deskew package (the original note was left unfinished)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,316 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This notebook can probably be deleted or archived somewhere. It was the original code for the rowsumdeskew."
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import numpy as np\n"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"src = 255 - cv2.imread('./testing_space/cropped1.jpg',0)\n",
"scores = []\n",
"\n",
"h,w = src.shape\n",
"small_dimention = min(h,w)\n",
"src = src[:small_dimention, :small_dimention]\n",
"\n",
"out = cv2.VideoWriter('./temp/video.avi',\n",
" cv2.VideoWriter_fourcc('M','J','P','G'),\n",
" 15, (320,320))\n",
"\n",
"src = cv2.threshold(src, 100, 255, cv2.THRESH_BINARY)[1]"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"def rotate(img, angle):\n",
" rows,cols = img.shape\n",
" M = cv2.getRotationMatrix2D((cols/2,rows/2),angle,1)\n",
" dst = cv2.warpAffine(img,M,(cols,rows))\n",
" return dst\n",
"\n",
"def sum_rows(img):\n",
" # Create a list to store the row sums\n",
" row_sums = []\n",
" # Iterate through the rows\n",
" for r in range(img.shape[0]-1):\n",
" # Sum the row\n",
" row_sum = sum(sum(img[r:r+1,:]))\n",
" # Add the sum to the list\n",
" row_sums.append(row_sum)\n",
" # Normalize range to (0,255)\n",
" row_sums = (row_sums/max(row_sums)) * 255\n",
" # Return\n",
" return row_sums\n",
"\n",
"def display_data(roi, row_sums, buffer): \n",
" # Create background to draw transform on\n",
" bg = np.zeros((buffer*2, buffer*2), np.uint8) \n",
" # Iterate through the rows and draw on the background\n",
" for row in range(roi.shape[0]-1):\n",
" row_sum = row_sums[row]\n",
" bg[row:row+1, :] = row_sum\n",
" left_side = int(buffer/3)\n",
" bg[:, left_side:] = roi[:,left_side:] \n",
" cv2.imshow('bg1', bg)\n",
" k = cv2.waitKey(1)\n",
" out.write(cv2.cvtColor(cv2.resize(bg, (320,320)), cv2.COLOR_GRAY2BGR))\n",
" return k\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"count = 0\n",
"othercount = 0\n",
"goodangle = 0"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"# cv2.imshow('bg1', src)\n",
"# cv2.waitKey(0)\n",
"# cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"found optimal rotation\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n",
"found optimal rotation\n"
]
}
],
"source": [
"# Rotate the image around in a circle\n",
"angle = 0\n",
"while angle <= 360:\n",
" # Rotate the source image\n",
" img = rotate(src, angle) \n",
" # Crop a centred square of side 2*buffer, roughly 2/3 of the shorter image side (roi is filled with text)\n",
" h,w = img.shape\n",
" buffer = min(h, w) - int(min(h,w)/1.5)\n",
" roi = img[int(h/2-buffer):int(h/2+buffer), int(w/2-buffer):int(w/2+buffer)]\n",
" # Create background to draw transform on\n",
" bg = np.zeros((buffer*2, buffer*2), np.uint8)\n",
" # Compute the sums of the rows\n",
" row_sums = sum_rows(roi)\n",
" # High score --> Zebra stripes\n",
" score = np.count_nonzero(row_sums)\n",
" scores.append(score)\n",
" othercount = othercount + 1\n",
" # Image has best rotation\n",
" if score <= min(scores):\n",
" count = count + 1\n",
" # Save the rotated image\n",
" print('found optimal rotation')\n",
" best_rotation = img.copy()\n",
" goodangle = angle\n",
" k = display_data(roi, row_sums, buffer)\n",
" if k == 27: break\n",
" # Increment angle and try again\n",
" angle += .75\n",
"cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"25\n",
"481\n",
"349.5\n"
]
}
],
"source": [
"print(count)\n",
"print(othercount)\n",
"print(goodangle)\n",
"cv2.imshow('bg1', best_rotation)\n",
"cv2.waitKey(0)\n",
"cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"start\")\n",
"\n",
"# Rotate the image around in a circle\n",
"angle = 0\n",
"while angle <= 360:\n",
" # Rotate the source image\n",
" img = rotate(src, angle) \n",
" # Crop a centred square of side 2*buffer, roughly 2/3 of the shorter image side (roi is filled with text)\n",
" h,w = img.shape\n",
" buffer = min(h, w) - int(min(h,w)/1.5)\n",
" #roi = img.copy()\n",
" roi = img[int(h/2-buffer):int(h/2+buffer), int(w/2-buffer):int(w/2+buffer)]\n",
" # Create background to draw transform on\n",
" bg = np.zeros((buffer*2, buffer*2), np.uint8)\n",
" # Threshold image\n",
" _, roi = cv2.threshold(roi, 140, 255, cv2.THRESH_BINARY)\n",
" # Compute the sums of the rows\n",
" row_sums = sum_rows(roi)\n",
" # High score --> Zebra stripes\n",
" score = np.count_nonzero(row_sums)\n",
" if sum(row_sums) < 100000: scores.append(angle)\n",
" k = display_data(roi, row_sums, buffer)\n",
" if k == 27: break\n",
" # Increment angle and try again\n",
" angle += .5\n",
" print(\"loop\")\n",
"cv2.destroyAllWindows()\n",
"\n",
"print(\"endofrotate\")\n",
"\n",
"# Create images for display purposes\t\n",
"display = src.copy()\n",
"# Create an image that contains bins. \n",
"bins_image = np.zeros_like(display)\n",
"for angle in scores:\n",
" # Rotate the image and draw a line on it\n",
" display = rotate(display, angle) \n",
" cv2.line(display, (0,int(h/2)), (w,int(h/2)), 255, 1)\n",
" display = rotate(display, -angle)\n",
" # Rotate the bins image\n",
" bins_image = rotate(bins_image, angle)\n",
" # Draw a line on a temporary image\n",
" temp = np.zeros_like(bins_image)\n",
" cv2.line(temp, (0,int(h/2)), (w,int(h/2)), 50, 1)\n",
" # 'Fill' up the bins\n",
" bins_image += temp\n",
" bins_image = rotate(bins_image, -angle)\n",
" \n",
"print(\"endofbins\")\n",
"\n",
"# Find the most filled bin\n",
"for col in range(bins_image.shape[0]-1):\n",
"\tcolumn = bins_image[:, col:col+1]\n",
"\tif np.amax(column) == np.amax(bins_image): x = col\n",
"for col in range(bins_image.shape[0]-1):\n",
"\tcolumn = bins_image[:, col:col+1]\n",
"\tif np.amax(column) == np.amax(bins_image): y = col\n",
"# Draw circles showing the most filled bin\n",
"cv2.circle(display, (x,y), 560, 255, 5)\n",
"\n",
"print(\"plotting\")\n",
"\n",
"# Plot with Matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"f, axarr = plt.subplots(1,3, sharex=True)\n",
"axarr[0].imshow(src)\n",
"axarr[1].imshow(display)\n",
"axarr[2].imshow(bins_image)\n",
"axarr[0].set_title('Source Image')\n",
"axarr[1].set_title('Output')\n",
"axarr[2].set_title('Bins Image')\n",
"axarr[0].axis('off')\n",
"axarr[1].axis('off')\n",
"axarr[2].axis('off')\n",
"plt.show()\n",
"\n",
"cv2.waitKey()\n",
"cv2.destroyAllWindows()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,777 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n",
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as fn\n",
"import torch.optim as optim\n",
"import torchvision.transforms.functional as tvf\n",
"from torchvision.transforms import v2\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from PIL import Image\n",
"\n",
"import datasets as ds\n",
"from tqdm.autonotebook import tqdm\n",
"\n",
"import random\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import numpy as np\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# original_dataset = ds.load_dataset(\"aharley/rvl_cdip\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"working_dataset = ds.load_from_disk(\"../.cache/huggingfaces/datasets/customrotation/\")\n",
"prepimage = v2.Compose([v2.Grayscale(num_output_channels=3),v2.Resize(1100), v2.CenterCrop(1100),v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
"working_dataset.set_transform(prepimage)\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Parameter Declaration\n",
"minRotation=-180\n",
"maxRotation=180\n",
"minTranslation=0\n",
"maxTranslation=150\n",
"minScale = 0.4\n",
"maxScale = 1\n",
"minShear = 0\n",
"maxShear = 0\n",
"\n",
"minFill=0\n",
"maxFill=255\n",
"\n",
"params = {\"minRotation\":minRotation,\"maxRotation\":maxRotation,\"minTranslation\":minTranslation,\"maxTranslation\":maxTranslation,\"minScale\":minScale,\"maxScale\":maxScale,\"minShear\":minShear,\"maxShear\":maxShear,\"minFill\":minFill,\"maxFill\":maxFill}"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def transform_picture(image_label, parameters):\n",
" image = image_label['image']\n",
"\n",
" appliedRotation = random.uniform(parameters['minRotation'], parameters['maxRotation'])\n",
" appliedXTranslation = random.uniform(parameters['minTranslation'], parameters['maxTranslation'])\n",
" appliedYTranslation = random.uniform(parameters['minTranslation'], parameters['maxTranslation'])\n",
" appliedScale = random.uniform(parameters['minScale'], parameters['maxScale'])\n",
" appliedFill = random.uniform(parameters['minFill'], parameters['maxFill'])\n",
" appliedXShear = random.uniform(parameters['minShear'], parameters['maxShear'])\n",
" appliedYShear = random.uniform(parameters['minShear'], parameters['maxShear'])\n",
" \n",
" appliers = [v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
" v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0, fill=appliedFill)], p=0.25), # maybe add fill=appliedFill\n",
" v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5,9), sigma=(0.1,2.))],p=0.25),\n",
" v2.RandomApply(transforms=[v2.RandomEqualize()],p=0.25)]\n",
" \n",
" adjustedimage = tvf.affine(image, appliedRotation, [appliedXTranslation,appliedYTranslation], appliedScale, [appliedXShear, appliedYShear], fill=appliedFill)\n",
"\n",
" for applier in appliers:\n",
" adjustedimage = applier(adjustedimage)\n",
"\n",
" \n",
" adjustedimage = tvf.resize(adjustedimage, size=[1100,1100])\n",
" \n",
" return {'image':adjustedimage,'label':appliedRotation}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# # Create own dataset from the images of the original dataset but make the labels the float value for the rotation. do the random rotation on all of the training ones but the labels for the validation and test should/can be 0\n",
"# og_training_dataset = original_dataset['train']\n",
"# og_testing_dataset = original_dataset['test']\n",
"# og_validation_dataset = original_dataset['validation']\n",
"\n",
"# type(og_testing_dataset[0]['label'])\n",
"\n",
"# # type(transform_picture(og_testing_dataset[0], params))\n",
"# new_testing_dataset = og_testing_dataset.map(transform_picture, fn_kwargs={'parameters':params})"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# class WorkaroundDataset(torch.utils.data.Dataset):\n",
"# def __init__(self, dataset):\n",
"# self._dataset = dataset\n",
"\n",
"# def __len__(self):\n",
"# return len(self._dataset)\n",
"\n",
"# def __getitem__(self, idx):\n",
"# return v2.Compose([v2.ToImageTensor(), v2.ConvertImageDtype()])(self._dataset[idx]['image'])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# # type(image_dataset['train'][0]['image'])\n",
"# # print(image_dataset['train'][0]['image'])\n",
"# img = image_dataset['train'][2]['image']\n",
"# # img\n",
"# # print(img.size)\n",
"# crop = tvf.resize(img, size=[500])\n",
"# # crop\n",
"# # print(crop.size)\n",
"# newimg = tvf.affine(crop, 180, [0,0], 0.7, 0)\n",
"# newimg"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# appliedRotation = random.uniform(minRotation, maxRotation)\n",
"# appliedXTranslation = random.uniform(minTranslation, maxTranslation)\n",
"# appliedYTranslation = random.uniform(minTranslation, maxTranslation)\n",
"# appliedScale = random.uniform(minScale, maxScale)\n",
"# appliedFill = random.uniform(minFill, maxFill)\n",
"\n",
"\n",
"\n",
"# newimg = tvf.affine(crop, appliedRotation, [appliedXTranslation,appliedYTranslation], appliedScale, shear, fill=appliedFill)\n",
"# newimg\n",
"\n",
"# appliers = [v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
"# v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0, fill=appliedFill)], p=0.25),\n",
"# v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5,9), sigma=(0.1,2.))],p=0.25),\n",
"# v2.RandomApply(transforms=[v2.RandomEqualize()],p=0.25)]\n",
"\n",
"# for applier in appliers:\n",
"# newimg = applier(newimg)\n",
" \n",
"# # newimg\n",
"# newimg= tvf.resize(newimg, size=[1000,1000])\n",
"# newimg\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# class SquarePad:\n",
"# \tdef __call__(self, image):\n",
"# \t\tw, h = image.size\n",
"# \t\tmax_wh = np.max([w, h])\n",
"# \t\thp = int((max_wh - w) / 2)\n",
"# \t\tvp = int((max_wh - h) / 2)\n",
"# \t\tpadding = (hp, vp, hp, vp)\n",
"# \t\treturn tvf.pad(image, padding, 0, 'constant')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class RotationDeterminer(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" \n",
" torch.cuda.empty_cache()\n",
" \n",
" self.device = torch.device(\"cpu\")\n",
" if torch.cuda.is_available:\n",
" self.device = torch.device(\"cuda:0\")\n",
" \n",
" \n",
" self.appliers = [v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
" v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0)], p=0.25), # maybe add fill=appliedFill\n",
" v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5,9), sigma=(0.1,2.))],p=0.25),\n",
" v2.RandomApply(transforms=[v2.RandomEqualize()],p=0.25)]\n",
" \n",
" \n",
" self.conv = nn.Sequential(nn.Conv2d(3, 9, kernel_size=11,stride=3), # 1100 x 1100 => 201 x 201\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(9, 18, kernel_size=5,stride=1),\n",
" nn.ReLU(inplace=True),\n",
" nn.MaxPool2d(kernel_size=4, stride=2),\n",
" nn.Conv2d(18, 36, kernel_size=3,stride=2),\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(36, 72, kernel_size=3,stride=2),\n",
" nn.ReLU(inplace=True),\n",
" nn.AvgPool2d(kernel_size=5, stride=3),\n",
" nn.Conv2d(72, 144, kernel_size=3,stride=1),\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(144, 288, kernel_size=5,stride=1),\n",
" nn.ReLU(inplace=True),\n",
" nn.MaxPool2d(kernel_size=4, stride=1),\n",
" nn.Conv2d(288, 192, kernel_size=3,stride=1),\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(192, 192, kernel_size=3,stride=1), # => 1\n",
" nn.ReLU(inplace=True))\n",
" \n",
" self.classifier = nn.Sequential(nn.Dropout(),\n",
" nn.Linear(192, 2048),\n",
" nn.ReLU(inplace=True),\n",
" nn.Dropout(),\n",
" nn.Linear(2048,2048),\n",
" nn.ReLU(inplace=True),\n",
" nn.Linear(2048,1))\n",
" \n",
" self.lossfunc = nn.MSELoss()\n",
" \n",
" self.imageprep = v2.Compose([self.SquarePad(),v2.Resize(1100),v2.Grayscale(num_output_channels=3),v2.CenterCrop(1100),v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
" \n",
" \n",
" class SquarePad:\n",
" def __call__(self, image):\n",
" # print(\"hi type:\", type(image))\n",
" temp = image.size()\n",
" w = temp[-2]\n",
" h = temp[-1]\n",
" max_wh = max([w, h])\n",
" hp = int((max_wh - w) / 2)\n",
" vp = int((max_wh - h) / 2)\n",
" padding = (hp, vp, hp, vp)\n",
" return tvf.pad(image, padding, 0, 'edge')\n",
"\n",
"\n",
" \n",
"\n",
" \n",
" def forward(self, image):\n",
"\n",
" transformedimage = self.imageprep(image)\n",
" transformedimage = transformedimage.to(self.device)\n",
"\n",
" x = self.conv(transformedimage)\n",
" x = nn.Flatten(start_dim=-3)(x)\n",
" x = self.classifier(x)\n",
" guessRotation = nn.Flatten(start_dim=0)(x)\n",
" \n",
" return guessRotation\n",
" \n",
" def loss(self, guess, trueAnswer):\n",
" return self.lossfunc(guess, trueAnswer)\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# def batchmaker(entries, batchsize):\n",
"# random.shuffle(entries)\n",
"# listing = []\n",
"# for i in range(0,len(entries), batchsize):\n",
"# listing.append(entries[i:i+batchsize])\n",
"# return listing"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# print(type(v2.Compose([v2.ToImageTensor(), v2.ConvertImageDtype()])(image_dataset['train'][0]['image'])))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# a, b, x = working_dataset['train'][0]['image'].size()\n",
"# print(x)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def train(model, dataset, batchsize, num_epochs, stepsize, totalnumiters = -1):\n",
" device = torch.device(\"cpu\")\n",
" if torch.cuda.is_available:\n",
" device = torch.device(\"cuda:0\")\n",
" model = model.cuda()\n",
" optimizer = optim.Adam(model.parameters(), lr=stepsize)\n",
" \n",
" counter = totalnumiters\n",
" model = model.train()\n",
" \n",
" breakearly = True\n",
" if totalnumiters == -1:\n",
" print(\"hi\")\n",
" breakearly = False\n",
" totalnumiters = len(dataset) + 1\n",
" \n",
" for e in range(num_epochs):\n",
" \n",
" train_dataloader = DataLoader(dataset, batch_size=batchsize, shuffle=True)\n",
" \n",
" pbar = tqdm(train_dataloader)\n",
" \n",
" for i, batch in enumerate(pbar):\n",
" torch.cuda.empty_cache()\n",
" images, truerotations = batch['image'], batch['rotation']\n",
" images = images.to(device)\n",
" truerotations = truerotations.to(device)\n",
"\n",
" optimizer.zero_grad()\n",
" \n",
" guessRotation = model(images)\n",
" \n",
" truerotations = truerotations.float()\n",
" \n",
" loss = model.loss(guessRotation, truerotations)\n",
" \n",
" loss.backward()\n",
" \n",
" optimizer.step()\n",
" counter = counter - batchsize\n",
" if counter <= 0 and breakearly:\n",
" print(\"endearly\")\n",
" return\n",
"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"testimage = working_dataset['train'][10]['image']\n",
"\n",
"# testimage = v2.Compose([v2.Grayscale(num_output_channels=3),v2.ToTensor(),])(testimage)\n",
"# testimage.size()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# plt.imshow(testimage)\n",
"# plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# temp = testimage.size()\n",
"# print(temp[-3])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"model = RotationDeterminer()\n",
"device = torch.device(\"cpu\")\n",
"if torch.cuda.is_available:\n",
" device = torch.device(\"cuda:0\")\n",
" model = model.cuda()\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# output = model(testimage)\n",
"# print(output)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# train_dataloader = DataLoader(working_dataset['test'], batch_size=100, shuffle=True)\n",
"# hold = next(iter(train_dataloader))\n",
"# images1, labels1 = hold['image'], hold['rotation']\n",
"# # print(images1)\n",
"# print(labels1.size())"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"hi\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "faff1411ea0d485b9321271ebe6820db",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/12800 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d26b6872bee74eaab8be6e7cfe53b190",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/12800 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"outputarray = np.array([working_dataset['train'][10]['rotation']])\n",
"model = model.eval()\n",
"output = model(testimage)\n",
"outputarray = np.append(outputarray, output.detach().cpu().numpy())\n",
"counter = 0\n",
"\n",
"\n",
"\n",
"train(model, working_dataset['train'], 25, 2, 5e-3)\n",
"\n",
"model = model.eval()\n",
"\n",
"counter = 2 + counter\n",
"output = model(testimage)\n",
"outputarray = np.append(outputarray, output.detach().cpu().numpy())\n",
"np.save(\"./testing_space/outputarray\", outputarray)\n",
"np.save(\"./testing_space/counter\", counter)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-164.93280103082208\n",
"8.194759720936418e-05\n",
"-0.1751984804868698\n"
]
}
],
"source": [
"print(outputarray[0])\n",
"print(outputarray[1])\n",
"print(outputarray[2])"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(outputarray)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), \"./testing_space/modelsave\" + str(counter) +\" epochs\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"#load model\n",
"# model.load_state_dict(torch.load(\"./testing_space/modelsave2epochs\"))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"hi\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e144d16317094603b328e2db88a4853a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/12800 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bb6cf6e77ed34628bdfa6ed2a64ef284",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/12800 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"train(model, working_dataset['train'], 25, 2, 1e-3)\n",
"counter = 2 + counter"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# outputarray = []"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"model = model.eval()\n",
"output = model(testimage)\n",
"outputarray = np.append(outputarray, output.detach().cpu().numpy())\n",
"np.save(\"./testing_space/outputarray\", outputarray)\n",
"np.save(\"./testing_space/counter\", counter)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), \"./testing_space/modelsave\" + str(counter) +\" epochs\")"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"hi\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c270b991cacd4abc996c602748e742f7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/12800 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1515223a29da4cfea86be156155fd06e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/12800 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"train(model, working_dataset['train'], 25, 2, 1e-2)\n",
"counter = 2 + counter"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"model = model.eval()\n",
"output = model(testimage)\n",
"outputarray = np.append(outputarray, output.detach().cpu().numpy())\n",
"np.save(\"./testing_space/outputarray\", outputarray)\n",
"np.save(\"./testing_space/counter\", counter)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), \"./testing_space/modelsave\" + str(counter) +\" epochs\")"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-1.64932801e+02 8.19475972e-05 -1.75198480e-01 -2.21363053e-01\n",
" -2.17262208e-01]\n"
]
}
],
"source": [
"print(outputarray)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,645 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"version=2.0\n",
"cachepath=\"../.cache/\"\n",
"savepath=\"./savespot/\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n",
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n"
]
}
],
"source": [
"import torch\n",
"from torch.utils.data import DataLoader\n",
"import torch.nn as nn\n",
"import torch.nn.functional as fn\n",
"import torch.optim as optim\n",
"import torchvision.transforms.functional as tvf\n",
"import torchvision.transforms.v2 as v2\n",
"import torchvision.models as models\n",
"\n",
"\n",
"from PIL import Image\n",
"\n",
"import datasets as ds\n",
"from tqdm.autonotebook import tqdm\n",
"\n",
"import random\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import numpy as np\n",
"\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# models.list_models()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"torch.cuda.empty_cache()\n",
"working_dataset = ds.load_from_disk(cachepath + \"datasets/customrotation/\")\n",
"prepimage = v2.Compose([v2.Grayscale(num_output_channels=3),v2.Resize(512), v2.CenterCrop(512),v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
"working_dataset.set_transform(prepimage)\n",
"testsample = working_dataset['train'][10]\n",
"testimage = testsample['image']\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# print(models.resnet18(pretrained=True))\n",
"# temp = models.resnet18(pretrained=True)\n",
"# print(temp(testimage.unsqueeze(0)).shape)\n",
"# device = torch.device(\"cpu\")\n",
"# if torch.cuda.is_available:\n",
"# device = torch.device(\"cuda:0\")\n",
"# temp = temp.to(device)\n",
"\n",
"#to be deleted"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# print(temp(testimage).shape)\n",
"#to be deleted"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class RotationDeterminer(nn.Module):\n",
" def __init__(self, new=False):\n",
" super(RotationDeterminer,self).__init__()\n",
" \n",
" torch.cuda.empty_cache()\n",
" \n",
" self.device = torch.device(\"cpu\")\n",
" if torch.cuda.is_available():\n",
" self.device = torch.device(\"cuda:0\")\n",
" \n",
" \n",
" self.appliers = [v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
" v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0)], p=0.25), # maybe add fill=appliedFill\n",
" v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5,9), sigma=(0.1,2.))],p=0.25),\n",
" v2.RandomApply(transforms=[v2.RandomEqualize()],p=0.25)]\n",
" \n",
" \n",
" # self.conv = nn.Sequential(nn.Conv2d(3, 9, kernel_size=11,stride=3), # 1100 x 1100 => 201 x 201\n",
" # nn.ReLU(inplace=True),\n",
" # nn.Conv2d(9, 18, kernel_size=5,stride=1),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.MaxPool2d(kernel_size=4, stride=2),\n",
" # nn.Conv2d(18, 36, kernel_size=3,stride=2),\n",
" # nn.BatchNorm2d(36),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.Conv2d(36, 72, kernel_size=3,stride=2),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.AvgPool2d(kernel_size=5, stride=3),\n",
" # nn.Conv2d(72, 144, kernel_size=3,stride=1),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.Conv2d(144, 288, kernel_size=5,stride=1),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.MaxPool2d(kernel_size=4, stride=1),\n",
" # nn.Conv2d(288, 192, kernel_size=3,stride=1),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.Conv2d(192, 192, kernel_size=3,stride=1), # => 1\n",
" # nn.ReLU(inplace=True))\n",
" # print(\"hi\")\n",
" self.conv = models.resnet18(pretrained=new)\n",
" \n",
" self.classifier = nn.Sequential(nn.Linear(1000, 4096),\n",
" nn.ReLU(inplace=True),\n",
" nn.Linear(4096,1))\n",
" \n",
" self.lossfunc = nn.MSELoss()\n",
" \n",
" self.imageprep = v2.Compose([self.SquarePad(),v2.Resize(512),v2.Grayscale(num_output_channels=3),v2.CenterCrop(512),v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
" \n",
" \n",
" class SquarePad:\n",
" def __call__(self, image):\n",
" # print(\"hi type:\", type(image))\n",
" temp = image.size()\n",
" w = temp[-1]\n",
" h = temp[-2]\n",
" max_wh = max([w, h])\n",
" hp = int((max_wh - w) / 2)\n",
" vp = int((max_wh - h) / 2)\n",
" padding = (hp, vp, hp, vp)\n",
" return tvf.pad(image, padding, 0, 'edge')\n",
"\n",
"\n",
" \n",
"\n",
" \n",
" def forward(self, image):\n",
"\n",
" transformedimage = self.imageprep(image)\n",
" transformedimage = transformedimage.to(self.device)\n",
"\n",
" if (len(transformedimage.shape) != 4 and len(transformedimage.shape) != 3):\n",
" raise Exception(\"Sorry, Dimension of image is incorrect (\", len(transformedimage.shape),\"). Expected a 3D (single image) or 4D (batch of images) tensor\")\n",
"\n",
" if (len(transformedimage.shape) == 3):\n",
" x = transformedimage.unsqueeze(0)\n",
" else:\n",
" x = transformedimage\n",
" \n",
" x = self.conv(x)\n",
" # print(x.shape)\n",
" # x = nn.Flatten(start_dim=-1)(x)\n",
" # print(x.shape)\n",
" x = self.classifier(x)\n",
" # print(x.shape)\n",
" guessRotation = nn.Flatten(start_dim=0)(x)\n",
" \n",
" return guessRotation\n",
" \n",
" def loss(self, guess, trueAnswer):\n",
" return self.lossfunc(guess, trueAnswer)\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def train(model, dataset, batchsize, num_epochs, stepsize, totalnumiters = -1):\n",
" device = torch.device(\"cpu\")\n",
" if torch.cuda.is_available():\n",
" device = torch.device(\"cuda:0\")\n",
" model = model.cuda()\n",
" optimizer = optim.Adam(model.parameters(), lr=stepsize)\n",
" \n",
" counter = totalnumiters\n",
" model = model.train()\n",
" \n",
" breakearly = True\n",
" if totalnumiters == -1:\n",
" print(\"hi\")\n",
" breakearly = False\n",
" totalnumiters = len(dataset) + 1\n",
" \n",
" for e in range(num_epochs):\n",
" \n",
" train_dataloader = DataLoader(dataset, batch_size=batchsize, shuffle=True)\n",
" \n",
" pbar = tqdm(train_dataloader)\n",
" \n",
" for i, batch in enumerate(pbar):\n",
" torch.cuda.empty_cache()\n",
" images, truerotations = batch['image'], batch['rotation']\n",
" images = images.to(device)\n",
" truerotations = truerotations.to(device)\n",
"\n",
" optimizer.zero_grad()\n",
" \n",
" guessRotation = model(images)\n",
" \n",
" truerotations = truerotations.float()\n",
" \n",
" loss = model.loss(guessRotation, truerotations)\n",
" \n",
" loss.backward()\n",
" \n",
" optimizer.step()\n",
" counter = counter - batchsize\n",
" if counter <= 0 and breakearly:\n",
" print(\"endearly\")\n",
" return\n",
"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def measure(model, dataset):\n",
" total=0\n",
" within30=0\n",
" within15=0\n",
" within10=0\n",
" within5=0\n",
" within1=0\n",
" withintenth=0\n",
" model = model.eval()\n",
" pbar = tqdm(dataset)\n",
" for i, sample in enumerate(pbar):\n",
" if (i % 100 == 0):\n",
" torch.cuda.empty_cache()\n",
" images, truerotations = sample['image'], sample['rotation']\n",
" output = model(images)\n",
" outputvalue = output.item()\n",
" total = total + 1\n",
" if (abs(outputvalue - truerotations) < 0.1):\n",
" withintenth = withintenth + 1\n",
" within1 = within1 + 1\n",
" within5 = within5 + 1\n",
" within10 = within10 + 1\n",
" within15 = within15 + 1\n",
" within30 = within30 + 1\n",
" elif (abs(outputvalue - truerotations) < 1):\n",
" within1 = within1 + 1\n",
" within5 = within5 + 1\n",
" within10 = within10 + 1\n",
" within15 = within15 + 1\n",
" within30 = within30 + 1\n",
" elif (abs(outputvalue - truerotations) < 5):\n",
" within5 = within5 + 1\n",
" within10 = within10 + 1\n",
" within15 = within15 + 1\n",
" within30 = within30 + 1\n",
" elif (abs(outputvalue - truerotations) < 10):\n",
" within10 = within10 + 1\n",
" within15 = within15 + 1\n",
" within30 = within30 + 1\n",
" elif (abs(outputvalue - truerotations) < 15):\n",
" within15 = within15 + 1\n",
" within30 = within30 + 1\n",
" elif (abs(outputvalue - truerotations) < 30):\n",
" within30 = within30 + 1\n",
" # print(\"Hi\")\n",
" return {\"Within 30 Degrees\": within30/total, \"Within 15 Degrees\": within15/total, \"Within 10 Degrees\": within10/total, \"Within 5 Degrees\": within5/total, \"Within 1 Degree\": within1/total, \"Within 0.1 Degree\": withintenth/total}"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n"
]
}
],
"source": [
"model = RotationDeterminer(new=True)\n",
"device = torch.device(\"cpu\")\n",
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda:0\")\n",
" model = model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# # used when starting a new model training\n",
"# counter = 0\n",
"# outputarray = np.array([])\n",
"# tempdict = {\"Epochs Done\": counter}\n",
"# tempdict.update(measure(model, working_dataset['validation']))\n",
"# outputarray = np.append(outputarray, tempdict)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# load values\n",
"counter = np.load(savepath + \"/v\"+str(version)+\"/counter.npy\")\n",
"model.load_state_dict(torch.load(savepath + \"/v\"+str(version)+\"/modelsave\" + str(counter) +\"epochs\"))\n",
"outputarray = np.load(savepath + \"/v\"+str(version)+\"/outputarray.npy\", allow_pickle=True)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# # used to rollback the model one training loop\n",
"# counter = 6\n",
"# outputarray = #removed the 7th element, will go from the 6th epoch"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"hi\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3048e1546e12444193f99b15781768d9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/12800 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1a36a12b123e4b24bf00a8eeec2e396a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/12800 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# train\n",
"numepochs = 2\n",
"batchsize = 25\n",
"stepsize = 1e-3\n",
"train(model, working_dataset['train'], batchsize, numepochs, stepsize)\n",
"# model = model.eval()\n",
"# output = model(testimage)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# print(output)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6b99a98f480745c4a375bf1e713708ed",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# outputarray = np.append(outputarray, output.detach().cpu().numpy())\n",
"counter = numepochs + counter\n",
"tempdict = {\"Epochs Done\": counter}\n",
"tempdict.update(measure(model, working_dataset['validation']))\n",
"outputarray = np.append(outputarray, tempdict)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# save values\n",
"torch.save(model.state_dict(), savepath + \"/v\"+str(version)+\"/modelsave\" + str(counter) +\"epochs\")\n",
"np.save(savepath + \"/v\"+str(version)+\"/outputarray\", outputarray)\n",
"np.save(savepath + \"/v\"+str(version)+\"/counter\", counter)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'Epochs Done': 0, 'Within 30 Degrees': 0.162575, 'Within 15 Degrees': 0.080575, 'Within 10 Degrees': 0.053725, 'Within 5 Degrees': 0.027125, 'Within 1 Degree': 0.00545, 'Within 0.1 Degree': 0.00075}\n",
" {'Epochs Done': 1, 'Within 30 Degrees': 0.7764, 'Within 15 Degrees': 0.65105, 'Within 10 Degrees': 0.538875, 'Within 5 Degrees': 0.322625, 'Within 1 Degree': 0.070375, 'Within 0.1 Degree': 0.00805}\n",
" {'Epochs Done': 5, 'Within 30 Degrees': 0.891675, 'Within 15 Degrees': 0.8042, 'Within 10 Degrees': 0.673275, 'Within 5 Degrees': 0.415725, 'Within 1 Degree': 0.092375, 'Within 0.1 Degree': 0.009275}\n",
" {'Epochs Done': 8, 'Within 30 Degrees': 0.928125, 'Within 15 Degrees': 0.881625, 'Within 10 Degrees': 0.7686, 'Within 5 Degrees': 0.4791, 'Within 1 Degree': 0.102925, 'Within 0.1 Degree': 0.009975}\n",
" {'Epochs Done': 11, 'Within 30 Degrees': 0.9417, 'Within 15 Degrees': 0.91265, 'Within 10 Degrees': 0.86655, 'Within 5 Degrees': 0.633125, 'Within 1 Degree': 0.14265, 'Within 0.1 Degree': 0.01495}\n",
" {'Epochs Done': 13, 'Within 30 Degrees': 0.941575, 'Within 15 Degrees': 0.917375, 'Within 10 Degrees': 0.889125, 'Within 5 Degrees': 0.735525, 'Within 1 Degree': 0.1992, 'Within 0.1 Degree': 0.019875}]\n"
]
}
],
"source": [
"print(outputarray)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d339e6a22ccf4812bdad90dd3d546c68",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/39999 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"{'Within 30 Degrees': 0.9433985849646241,\n",
" 'Within 15 Degrees': 0.9174979374484362,\n",
" 'Within 10 Degrees': 0.889422235555889,\n",
" 'Within 5 Degrees': 0.737118427960699,\n",
" 'Within 1 Degree': 0.1995799894997375,\n",
" 'Within 0.1 Degree': 0.020050501262531564}"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"measure(model, working_dataset['test'])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# first epoch 25 batchsize, 1e-3 stepsize # GOOD PROGRESS SO FAR\n",
"# epoch 2-11 25 batchsize, 1e-3 stepsize"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# model1 = RotationDeterminer(new=False)\n",
"# device = torch.device(\"cpu\")\n",
"# if torch.cuda.is_available:\n",
"# device = torch.device(\"cuda:0\")\n",
"# model1 = model1.to(device)\n",
"# measurementarray=np.array([])\n",
"# for i in range(counter+1):\n",
"# print(i)\n",
"# if (i == 0 or i == 1 or i == 5 or i == 8 or i == 11):\n",
"# tempdict = {\"Epochs Done\": i}\n",
"# model1.load_state_dict(torch.load(savepath + \"/v\"+str(version)+\"/modelsave\" + str(i) +\"epochs\"))\n",
"# tempdict.update(measure(model1, working_dataset['validation']))\n",
"# measurementarray = np.append(measurementarray, tempdict)\n",
" \n",
"# print(\"hi\")"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# print(measurementarray)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# np.save(savepath + \"/v\"+str(version)+\"/outputarray\", measurementarray)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# measurementarraycopy = measurementarray\n",
"# tempdict = {\"Epochs Done\": 1}\n",
"# tempdict.update(measurementarraycopy[0])"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# print(tempdict)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,144 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"version=2.0\n",
"cachepath=\"../.cache/\"\n",
"savepath=\"./savespot/\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n",
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n",
" warnings.warn(_BETA_TRANSFORMS_WARNING)\n"
]
}
],
"source": [
"import torch\n",
"from torch.utils.data import DataLoader\n",
"import torch.nn as nn\n",
"import torch.nn.functional as fn\n",
"import torch.optim as optim\n",
"import torchvision.transforms.functional as tvf\n",
"import torchvision.transforms.v2 as v2\n",
"import torchvision.models as models\n",
"\n",
"\n",
"from PIL import Image\n",
"\n",
"import datasets as ds\n",
"from tqdm.autonotebook import tqdm\n",
"\n",
"import random\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"import cv2\n",
"import numpy as np\n",
"import myfunctions as mf\n",
"\n",
"torch.cuda.empty_cache()\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# array = np.load(\"./testing_space/outputarray.npy\")\n",
"# counter = np.load(\"./testing_space/counter.npy\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# print(array)\n",
"# print(counter)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"img = cv2.imread('./test_images/IMG_7605.jpg')\n",
"# img = mf.ResizeWithAspectRatio(img, 1000)\n",
"# img = mf.ResizeWithAspectRatio(mf.SquarePad(fill=255)(img),1000)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"rotatedimg = mf.houghlinedeskewandcrop(img)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# out = mf.morphologyCrop(img)\n",
"# out = cv2.cvtColor(out, cv2.COLOR_BGR2GRAY)\n",
"# out = cv2.threshold(out, 200, 255, cv2.THRESH_BINARY)[1]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"cv2.imshow(\"result1\", rotatedimg)\n",
"# cv2.imshow(\"result2\", result2)\n",
"cv2.waitKey(0)\n",
"cv2.destroyAllWindows()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,345 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ORIGINAL DOCUMENT FOR MORPHOLOGY CROP can maybe be deleted"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import numpy as np\n",
"\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"import torch.nn as nn\n",
"import torch.nn.functional as fn\n",
"import torch.optim as optim\n",
"import torchvision.transforms.functional as tvf\n",
"import torchvision.transforms.v2 as v2\n",
"import torchvision.models as models\n",
"import torchvision.transforms as t\n",
"\n",
"import myfunctions as mf\n",
"\n",
"from PIL import Image"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# read image (cv2.imread loads 3-channel BGR by default; it is converted to grayscale in a later cell)\n",
"img = cv2.imread('./test_images/IMG_7640.jpg')"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"# def ResizeWithAspectRatio(image, width=None, height=None, inter=cv2.INTER_AREA):\n",
"# dim = None\n",
"# (h, w) = image.shape[:2]\n",
"\n",
"# if width is None and height is None:\n",
"# return image\n",
"# if width is None:\n",
"# r = height / float(h)\n",
"# dim = (int(w * r), height)\n",
"# else:\n",
"# r = width / float(w)\n",
"# dim = (width, int(h * r))\n",
"\n",
"# return cv2.resize(image, dim, interpolation=inter)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"# convert to grayscale\n",
"gray = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)\n",
"\n",
"# threshold\n",
"thresh = cv2.threshold(gray, 190, 255, cv2.THRESH_BINARY)[1]\n",
"\n",
"# apply morphology\n",
"kernel = np.ones((7,7), np.uint8)\n",
"morph = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)\n",
"kernel = np.ones((9,9), np.uint8)\n",
"morph = cv2.morphologyEx(morph, cv2.MORPH_ERODE, kernel)\n",
"\n",
"# get largest contour\n",
"contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)\n",
"contours = contours[0] if len(contours) == 2 else contours[1]\n",
"area_thresh = 0\n",
"for c in contours:\n",
" area = cv2.contourArea(c)\n",
" if area > area_thresh:\n",
" area_thresh = area\n",
" big_contour = c\n",
"\n",
"\n",
"# get bounding box\n",
"x,y,w,h = cv2.boundingRect(big_contour)\n",
"\n",
"# draw filled contour on black background\n",
"mask = np.zeros_like(gray)\n",
"mask = cv2.merge([mask,mask,mask])\n",
"cv2.drawContours(mask, [big_contour], -1, (255,255,255), cv2.FILLED)\n",
"\n",
"# apply mask to input\n",
"result1 = img.copy()\n",
"result1 = cv2.bitwise_and(result1, mask)\n",
"\n",
"# crop result\n",
"result2 = result1[y:y+h, x:x+w]\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# view result\n",
"# cv2.imshow(\"threshold\", thresh)\n",
"# cv2.imshow(\"morph\", morph)\n",
"# cv2.imshow(\"mask\", mask)\n",
"# cv2.imshow(\"result1\", result1)\n",
"resizedresult2 = mf.ResizeWithAspectRatio(result2, 1000)\n",
"cv2.imwrite(\"./testing_space/cropped1.jpg\", resizedresult2)\n",
"cv2.imshow(\"result2\", resizedresult2)\n",
"cv2.waitKey(0)\n",
"cv2.destroyAllWindows()\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"class RotationDeterminer(nn.Module):\n",
" def __init__(self, new=False):\n",
" super(RotationDeterminer,self).__init__()\n",
" \n",
" torch.cuda.empty_cache()\n",
" \n",
" self.device = torch.device(\"cpu\")\n",
" if torch.cuda.is_available():\n",
" self.device = torch.device(\"cuda:0\")\n",
" \n",
" \n",
" self.appliers = [v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
" v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0)], p=0.25), # maybe add fill=appliedFill\n",
" v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5,9), sigma=(0.1,2.))],p=0.25),\n",
" v2.RandomApply(transforms=[v2.RandomEqualize()],p=0.25)]\n",
" \n",
" \n",
" # self.conv = nn.Sequential(nn.Conv2d(3, 9, kernel_size=11,stride=3), # 1100 x 1100 => 201 x 201\n",
" # nn.ReLU(inplace=True),\n",
" # nn.Conv2d(9, 18, kernel_size=5,stride=1),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.MaxPool2d(kernel_size=4, stride=2),\n",
" # nn.Conv2d(18, 36, kernel_size=3,stride=2),\n",
" # nn.BatchNorm2d(36),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.Conv2d(36, 72, kernel_size=3,stride=2),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.AvgPool2d(kernel_size=5, stride=3),\n",
" # nn.Conv2d(72, 144, kernel_size=3,stride=1),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.Conv2d(144, 288, kernel_size=5,stride=1),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.MaxPool2d(kernel_size=4, stride=1),\n",
" # nn.Conv2d(288, 192, kernel_size=3,stride=1),\n",
" # nn.ReLU(inplace=True),\n",
" # nn.Conv2d(192, 192, kernel_size=3,stride=1), # => 1\n",
" # nn.ReLU(inplace=True))\n",
" # print(\"hi\")\n",
" self.conv = models.resnet18(pretrained=new)\n",
" \n",
" self.classifier = nn.Sequential(nn.Linear(1000, 4096),\n",
" nn.ReLU(inplace=True),\n",
" nn.Linear(4096,1))\n",
" \n",
" self.lossfunc = nn.MSELoss()\n",
" \n",
" self.imageprep = v2.Compose([self.SquarePad(),v2.Resize(512),v2.Grayscale(num_output_channels=3),v2.CenterCrop(512),v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
" \n",
" \n",
" class SquarePad:\n",
" def __call__(self, image):\n",
" # print(\"hi type:\", type(image))\n",
" temp = image.size()\n",
" w = temp[-1]\n",
" h = temp[-2]\n",
" max_wh = max([w, h])\n",
" hp = int((max_wh - w) / 2)\n",
" vp = int((max_wh - h) / 2)\n",
" padding = (hp, vp, hp, vp)\n",
" return tvf.pad(image, padding, 0, 'edge')\n",
"\n",
"\n",
" \n",
"\n",
" \n",
" def forward(self, image):\n",
"\n",
" transformedimage = self.imageprep(image)\n",
" transformedimage = transformedimage.to(self.device)\n",
"\n",
" if (len(transformedimage.shape) != 4 and len(transformedimage.shape) != 3):\n",
" raise Exception(\"Sorry, Dimension of image is incorrect (\", len(transformedimage.shape),\"). Expected a 3D (single image) or 4D (batch of images) tensor\")\n",
"\n",
" if (len(transformedimage.shape) == 3):\n",
" x = transformedimage.unsqueeze(0)\n",
" else:\n",
" x = transformedimage\n",
" \n",
" x = self.conv(x)\n",
" # print(x.shape)\n",
" # x = nn.Flatten(start_dim=-1)(x)\n",
" # print(x.shape)\n",
" x = self.classifier(x)\n",
" # print(x.shape)\n",
" guessRotation = nn.Flatten(start_dim=0)(x)\n",
" \n",
" return guessRotation\n",
" \n",
" def loss(self, guess, trueAnswer):\n",
" return self.lossfunc(guess, trueAnswer)\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n"
]
}
],
"source": [
"model = RotationDeterminer(new=True)\n",
"device = torch.device(\"cpu\")\n",
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda:0\")\n",
" model = model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 1174, 1000])\n",
"torch.Size([3, 1174, 1000])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"-0.1470905989408493\n"
]
}
],
"source": [
"tensorize = v2.Compose([v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
"grayscaler = v2.Grayscale(num_output_channels=3)\n",
"\n",
"imagetobeprocessed = cv2.cvtColor(resizedresult2,cv2.COLOR_BGR2GRAY)\n",
"\n",
"\n",
"tensorizedimage = torch.unsqueeze(torch.from_numpy(imagetobeprocessed),0)\n",
"print(tensorizedimage.shape)\n",
"adjustedtensorizedimage = tensorize(grayscaler(t.ToPILImage()(tensorizedimage)))\n",
"print(adjustedtensorizedimage.shape)\n",
"rotation = model(adjustedtensorizedimage).item()\n",
"print(rotation)\n",
"rotatedimage = t.Resize(size=1000)(tvf.rotate(adjustedtensorizedimage, rotation))\n",
"# imS = mf.ResizeWithAspectRatio(filereadimage, 1000)\n",
"# imS = cv2.resize(filereadimage, (960, 540)) \n",
"open_cv_image = np.array(t.ToPILImage()(rotatedimage))\n",
"cv2.imshow(f'image', open_cv_image)\n",
"key = cv2.waitKey(0)\n",
"cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# # save result\n",
"# cv2.imwrite(\"paper_thresh.jpg\", thresh)\n",
"# cv2.imwrite(\"paper_morph.jpg\", morph)\n",
"# cv2.imwrite(\"paper_mask.jpg\", mask)\n",
"# cv2.imwrite(\"paper_result1.jpg\", result1)\n",
"# cv2.imwrite(\"paper_result2.jpg\", result2)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,387 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 772,
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import numpy as np\n",
"\n",
"import myfunctions as mf\n",
"\n",
"\n",
"import scipy.stats as st\n",
"import math"
]
},
{
"cell_type": "code",
"execution_count": 773,
"metadata": {},
"outputs": [],
"source": [
"# read image (cv2.imread returns a 3-channel BGR image by default; it is converted to grayscale later)\n",
"img = cv2.imread('./test_images/IMG_7605.jpg')\n",
"# img = mf.ResizeWithAspectRatio(img,1000)\n",
"# img = mf.rotate(img, 54)"
]
},
{
"cell_type": "code",
"execution_count": 774,
"metadata": {},
"outputs": [],
"source": [
"prepped = mf.ResizeWithAspectRatio(mf.SquarePad(fill=255)(img),1000)\n",
"prepped = mf.premorphCrop(prepped)\n",
"prepped = mf.ResizeWithAspectRatio(mf.SquarePad(fill=255)(prepped),1000)\n",
"# kernel = np.ones((5,5), np.uint8)\n",
"# prepped = cv2.dilate(prepped, kernel, iterations=1)\n",
"gray1 = cv2.cvtColor(prepped, cv2.COLOR_BGR2GRAY)\n",
"dst1 = cv2.Canny(gray1, 0, 500, None, 3)\n",
"\n",
"kernel = np.ones((5,5), np.uint8)\n",
"out = cv2.morphologyEx(dst1, cv2.MORPH_DILATE, kernel)\n",
"out = cv2.blur(out, (5,5))\n",
"kernel = np.ones((6,6), np.uint8)\n",
"dst1 = cv2.morphologyEx(out, cv2.MORPH_ERODE, kernel)\n",
"\n",
"dst1 = cv2.Canny(dst1, 0, 500, None, 3)\n",
"\n",
"cdstP = prepped.copy()\n",
"cdstPmargin = cdstP.copy()\n",
"basecdstP = cdstP.copy()\n",
"linesP = cv2.HoughLinesP(dst1, 1, np.pi / 180, 30, None, 90, 30)"
]
},
{
"cell_type": "code",
"execution_count": 779,
"metadata": {},
"outputs": [],
"source": [
"# # testing = dst1.copy()\n",
"# # kernel = np.ones((5,5), np.uint8)\n",
"# # out = cv2.morphologyEx(testing, cv2.MORPH_DILATE, kernel)\n",
"# # out = cv2.blur(out, (5,5))\n",
"# # kernel = np.ones((3,3), np.uint8)\n",
"# # out = cv2.morphologyEx(out, cv2.MORPH_ERODE, kernel)\n",
"cv2.imshow(\"result1\", dst1)\n",
"cv2.waitKey(0)\n",
"cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 758,
"metadata": {},
"outputs": [],
"source": [
"angles = np.zeros(len(linesP))\n",
"if linesP is not None:\n",
" for i in range(0, len(linesP)):\n",
" l = linesP[i][0]\n",
" angles[i] = mf.lineAngle(l)\n",
" cv2.line(cdstP, (l[0], l[1]), (l[2], l[3]), (0,0,255), 3, cv2.LINE_AA)"
]
},
{
"cell_type": "code",
"execution_count": 759,
"metadata": {},
"outputs": [],
"source": [
"# cv2.imshow(\"result1\", cdstP)\n",
"# cv2.waitKey(0)\n",
"# cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 760,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-3.093972093706445\n"
]
}
],
"source": [
"mode = st.mode(np.around(angles, decimals=3))[0]\n",
"rotationangle = np.rad2deg(mode)\n",
"print(rotationangle)"
]
},
{
"cell_type": "code",
"execution_count": 761,
"metadata": {},
"outputs": [],
"source": [
"rotatedcdstP = mf.rotate(basecdstP, rotationangle)"
]
},
{
"cell_type": "code",
"execution_count": 762,
"metadata": {},
"outputs": [],
"source": [
"vmarginlines = mf.WithinXDegrees(linesP, 7, baseangle=rotationangle)\n",
"hmarginlines = mf.WithinXDegrees(linesP, 7, baseangle=90+rotationangle)\n",
"vrect = mf.lineBoundingRect(vmarginlines,asRect=False, returnint=True)\n",
"hmarginlines = mf.lineswithinrange(hmarginlines, (vrect[0], vrect[1]), (vrect[2],vrect[3]), x=True, y=False)\n",
"\n",
"\n",
"if (hmarginlines != []):\n",
" marginlines = np.append(vmarginlines, hmarginlines, axis=0)\n",
"else:\n",
" marginlines = vmarginlines\n",
" \n",
"rect = mf.lineBoundingRect(marginlines,asRect=False, returnint=True)\n",
"cdstP = cv2.rectangle(cdstP, (rect[0],rect[1]), (rect[2],rect[3]), (0,255,0), 3)"
]
},
{
"cell_type": "code",
"execution_count": 763,
"metadata": {},
"outputs": [],
"source": [
"cv2.imshow(\"result1\", cdstP)\n",
"cv2.waitKey(0)\n",
"cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 764,
"metadata": {},
"outputs": [],
"source": [
"##### TODO: score the lines so that the correct orientation (horizontal vs. vertical) is picked, and so that the cropping rectangle is moved/transformed together with the rotation"
]
},
{
"cell_type": "code",
"execution_count": 780,
"metadata": {},
"outputs": [],
"source": [
"def rotatePoint(img, pt, angle, returnint=True):\n",
" rotateaxisx = img.shape[0]/2\n",
" rotateaxisy = img.shape[1]/2\n",
" tempx = pt[0] - rotateaxisx\n",
" tempy = pt[1] - rotateaxisy\n",
" rotatedx = tempx*math.cos(np.deg2rad(-angle)) - tempy*math.sin(np.deg2rad(-angle))\n",
" rotatedy = tempx*math.sin(np.deg2rad(-angle)) + tempy*math.cos(np.deg2rad(-angle))\n",
" finalx = rotatedx + rotateaxisx\n",
" finaly = rotatedy + rotateaxisy\n",
" if (returnint):\n",
" finalx = int(finalx)\n",
" finaly = int(finaly)\n",
" return (finalx, finaly)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 766,
"metadata": {},
"outputs": [],
"source": [
"def rotateRect(img, rect, angle, returnint=True, asRect=False):\n",
" if (asRect):\n",
" pt1 = rotatePoint(img, (rect[0],rect[1]), angle, returnint)\n",
" pt2 = rotatePoint(img, (rect[0]+rect[2],rect[1]+rect[3]), angle, returnint)\n",
" return (pt1[0], pt1[1], pt2[0]-pt1[0], pt2[1]-pt1[1])\n",
" else:\n",
" pt1 = rotatePoint(img, (rect[0],rect[1]), angle, returnint)\n",
" pt2 = rotatePoint(img, (rect[2],rect[3]), angle, returnint)\n",
" return (pt1[0], pt1[1], pt2[0], pt2[1])\n",
"\n",
"def rotateLine(img, line, angle, returnint=True):\n",
" pt1 = rotatePoint(img, (line[0],line[1]), angle, returnint)\n",
" pt2 = rotatePoint(img, (line[2],line[3]), angle, returnint)\n",
" return (pt1[0], pt1[1], pt2[0], pt2[1])\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 767,
"metadata": {},
"outputs": [],
"source": [
"# print(linesP.shape)\n",
"rotatedlines = [rotateLine(rotatedcdstP, line[0], rotationangle) for line in linesP]\n",
"rotatedlines = np.reshape(rotatedlines, (len(rotatedlines),1,4))\n",
"# rotatedlines = linesP\n",
"# print(rotatedlines.shape)"
]
},
{
"cell_type": "code",
"execution_count": 768,
"metadata": {},
"outputs": [],
"source": [
"vmarginlines = mf.WithinXDegrees(rotatedlines, 7)\n",
"hmarginlines = mf.WithinXDegrees(rotatedlines, 7, baseangle=90)\n",
"vrect = mf.lineBoundingRect(vmarginlines,asRect=False, returnint=True)\n",
"hmarginlines = mf.lineswithinrange(hmarginlines, (vrect[0], vrect[1]), (vrect[2],vrect[3]), x=True, y=False)\n",
"\n",
"if (hmarginlines != []):\n",
" marginlines = np.append(vmarginlines, hmarginlines, axis=0)\n",
"else:\n",
" marginlines = vmarginlines\n",
" \n",
"rect = mf.lineBoundingRect(marginlines,asRect=False, returnint=True)\n",
"# rect = vrect\n",
"rotatedcdstP = cv2.rectangle(rotatedcdstP, (rect[0],rect[1]), (rect[2],rect[3]), (0,255,0), 3)"
]
},
{
"cell_type": "code",
"execution_count": 769,
"metadata": {},
"outputs": [],
"source": [
"if rotatedlines is not None:\n",
" for i in range(0, len(rotatedlines)):\n",
" l = rotatedlines[i][0]\n",
" cv2.line(rotatedcdstP, (l[0], l[1]), (l[2], l[3]), (0,0,255), 3, cv2.LINE_AA)"
]
},
{
"cell_type": "code",
"execution_count": 771,
"metadata": {},
"outputs": [],
"source": [
"cv2.imshow(\"result1\", rotatedcdstP)\n",
"# cv2.imshow(\"result1\", cdstP)\n",
"cv2.waitKey(0)\n",
"cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 394,
"metadata": {},
"outputs": [],
"source": [
"vmarginlines = mf.WithinXDegrees(linesP, 7)\n",
"hmarginlines = mf.WithinXDegrees(linesP, 7, baseangle=90)\n",
"vrect = mf.lineBoundingRect(vmarginlines,asRect=False, returnint=True)\n",
"hmarginlines = mf.lineswithinrange(hmarginlines, (vrect[0], vrect[1]), (vrect[2],vrect[3]), x=True, y=False)\n",
"\n",
"\n",
"if (hmarginlines != []):\n",
" marginlines = np.append(vmarginlines, hmarginlines, axis=0)\n",
"else:\n",
" marginlines = vmarginlines\n",
"\n",
"rect = mf.lineBoundingRect(marginlines,asRect=False, returnint=True)\n",
"cdstP = cv2.rectangle(cdstP, (rect[0],rect[1]), (rect[2],rect[3]), (0,255,0), 3)\n",
"\n",
"\n",
"# rotatedrect = rotateRect(cdstP, rect, -rotationangle)\n",
"\n",
"# rotatedcdstP = cv2.rectangle(rotatedcdstP, (rotatedrect[0],rotatedrect[1]), (rotatedrect[2],rotatedrect[3]), (0,255,0), 3)"
]
},
{
"cell_type": "code",
"execution_count": 395,
"metadata": {},
"outputs": [],
"source": [
"###figure out how to rotate rectangle"
]
},
{
"cell_type": "code",
"execution_count": 396,
"metadata": {},
"outputs": [],
"source": [
"cv2.imshow(\"result1\", cdstP)\n",
"cv2.waitKey(0)\n",
"cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 397,
"metadata": {},
"outputs": [],
"source": [
"# vmarginlines = mf.WithinXDegrees(linesP, 7)\n",
"# hmarginlines = mf.WithinXDegrees(linesP, 7, baseangle=90)\n",
"# vrect = mf.lineBoundingRect(vmarginlines,asRect=False, returnint=True)\n",
"# hmarginlines = mf.lineswithinrange(hmarginlines, (vrect[0], vrect[1]), (vrect[2],vrect[3]), x=True, y=False)\n",
"# # print(hmarginlines)\n",
"# if (hmarginlines != []):\n",
"# marginlines = np.append(vmarginlines, hmarginlines, axis=0)\n",
"# else:\n",
"# marginlines = vmarginlines\n",
"\n",
"# # print(marginlines)\n",
"# rect = mf.lineBoundingRect(marginlines,asRect=False, returnint=True)\n",
"# # print(rect)\n",
"# cdstP = cv2.rectangle(cdstP, (rect[0],rect[1]), (rect[2],rect[3]), (0,255,0), 3)\n",
"# # print(cdstP.shape)\n",
"# # cropped = cdstP[rect[1]:rect[3], rect[0]:rect[2],:]\n",
"\n",
"# if marginlines is not None:\n",
"# for i in range(0, len(marginlines)):\n",
"# l = marginlines[i]\n",
"# cv2.line(cdstP, (int(l[0]), int(l[1])), (int(l[2]), int(l[3])), (255,0,0), 3, cv2.LINE_AA)"
]
},
{
"cell_type": "code",
"execution_count": 398,
"metadata": {},
"outputs": [],
"source": [
"# # view result\n",
"# # cv2.imshow(\"threshold\", thresh)\n",
"# # cv2.imshow(\"morph\", morph)\n",
"# # cv2.imshow(\"mask\", mask)\n",
"# cv2.imshow(\"result1\", mf.ResizeWithAspectRatio(cdstP,height=1000))\n",
"# # cv2.imshow(\"result2\", cropped)\n",
"# cv2.waitKey(0)\n",
"# cv2.destroyAllWindows()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,165 +0,0 @@
#include "cropper.h"
#include <opencv2/ximgproc/segmentation.hpp>
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
using namespace cv::ximgproc::segmentation;
// Returns the corner of `rect` at (x, y).
inline cv::Point topLeft(cv::Rect rect) {
    return rect.tl();
}
// Returns the corner of `rect` at (x, y + height).
inline cv::Point bottomLeft(cv::Rect rect) {
    cv::Point corner = rect.tl();
    corner.y += rect.height;
    return corner;
}
// Returns the corner of `rect` at (x + width, y).
inline cv::Point topRight(cv::Rect rect) {
    const int rightEdge = rect.x + rect.width;
    return cv::Point(rightEdge, rect.y);
}
// Returns the corner of `rect` at (x + width, y + height).
inline cv::Point bottomRight(cv::Rect rect) {
    return rect.br();
}
// Returns the Euclidean distance between `p1` and `p2`.
inline double distanceBetweenPoints(cv::Point p1, cv::Point p2) {
    // The differences are taken in integer arithmetic first, then widened.
    const double dx = p1.x - p2.x;
    const double dy = p1.y - p2.y;
    return std::sqrt(dx * dx + dy * dy);
}
// Scales `r` in place by the ratio originalheight / currentheight, i.e. from
// the coordinates of an image `currentheight` pixels tall to those of an image
// `originalheight` pixels tall.
//
// The ratio is computed in floating point. Integer division would truncate it
// (3000 / 800 would become 3 instead of 3.75, and any ratio below one would
// become 0 and collapse the rectangle). A non-positive `currentheight` leaves
// `r` unchanged rather than dividing by zero.
inline void scaleRect(cv::Rect& r, int originalheight, int currentheight) {
    if (currentheight <= 0) {
        return;
    }
    const double scalingFactor = static_cast<double>(originalheight) / currentheight;

    // Scale both corners and derive the size from them, so that rounding the
    // origin and the size independently cannot push the far edge one pixel
    // beyond where the scaled far corner belongs.
    const int left = cvRound(r.x * scalingFactor);
    const int top = cvRound(r.y * scalingFactor);
    const int right = cvRound((r.x + r.width) * scalingFactor);
    const int bottom = cvRound((r.y + r.height) * scalingFactor);

    r.x = left;
    r.y = top;
    r.width = right - left;
    r.height = bottom - top;
}
// Returns the mean Euclidean distance between the four corresponding corners
// of `r1` and `r2` (top-left, bottom-left, top-right, bottom-right). Despite
// the name, the distances are not squared.
double MSELossRect(cv::Rect r1, cv::Rect r2) {
    const double topLeftGap = distanceBetweenPoints(topLeft(r1), topLeft(r2));
    const double bottomLeftGap = distanceBetweenPoints(bottomLeft(r1), bottomLeft(r2));
    const double topRightGap = distanceBetweenPoints(topRight(r1), topRight(r2));
    const double bottomRightGap = distanceBetweenPoints(bottomRight(r1), bottomRight(r2));
    return (topLeftGap + bottomLeftGap + topRightGap + bottomRightGap) / 4.0;
}
// Runs OpenCV selective-search segmentation on `src` and returns every
// rectangle the search produces.
//
// `fast` selects the fast search strategy; otherwise the quality strategy is
// used. `imageHeight` is accepted but is not read by this function.
std::vector<cv::Rect> selectiveSearchSegmentationActor(cv::InputArray src, bool fast = true, int imageHeight = 800) {
    // Note: these two calls change process-wide OpenCV settings.
    cv::setUseOptimized(true);
    cv::setNumThreads(4);

    cv::Mat baseImage = src.getMat();
    cv::Ptr<cv::ximgproc::segmentation::SelectiveSearchSegmentation> searcher =
        createSelectiveSearchSegmentation();
    searcher->setBaseImage(baseImage);

    if (!fast) {
        searcher->switchToSelectiveSearchQuality();
    } else {
        searcher->switchToSelectiveSearchFast();
    }

    std::vector<cv::Rect> proposals;
    searcher->process(proposals);
    return proposals;
}
// Limits `n` to the interval [lower, upper]: values above `upper` become
// `upper`, values below `lower` become `lower`.
inline double clip(double n, double lower, double upper) {
    const double capped = std::min(n, upper);
    return std::max(lower, capped);
}
// Rescales `n` to the 0-255 range: the offset of `n` from `min` is divided by
// the width |max - min|, multiplied by 255 and clamped to [0, 255].
//
// A zero-width range (min == max) is handled explicitly instead of dividing by
// zero: values above `min` map to 255, everything else maps to 0.
inline double colourscaler(double n, double min, double max) {
    const double offset = n - min;
    const double span = std::abs(max - min);
    if (span == 0.0) {
        return offset > 0.0 ? 255.0 : 0.0;
    }
    return clip((offset / span) * 255, 0, 255);
}
// Returns the bounding rectangle of the longest contour found by Canny edge
// detection on `src`.
//
// Steps: convert `src` from BGR to grayscale, stretch the intensities between
// `lower` and `upper` (clamped to 0..255) onto 0..255, blur with a 15x15
// Gaussian, run Canny with `threshold1` / `threshold2`, then pick the contour
// with the greatest open arc length.
//
// Returns an empty cv::Rect when no contour is detected, rather than indexing
// into an empty contour list.
cv::Rect cannyEdgeRectangle(cv::InputArray src, int lower = 100, int upper = 255, double threshold1 = 50, double threshold2 = 350) {
    lower = std::max(lower, 0);
    upper = std::min(upper, 255);

    cv::Mat gray;
    cv::cvtColor(src, gray, cv::COLOR_BGR2GRAY);

    // Per-pixel contrast stretch of the grayscale image.
    cv::Mat scaled_gray = cv::Mat::zeros(gray.size(), gray.type());
    for (int y = 0; y < gray.rows; y++) {
        for (int x = 0; x < gray.cols; x++) {
            scaled_gray.at<uchar>(y, x) =
                cv::saturate_cast<uchar>(colourscaler(gray.at<uchar>(y, x), lower, upper));
        }
    }

    cv::Mat blurred, edged;
    cv::GaussianBlur(scaled_gray, blurred, cv::Size(15, 15), 0);
    cv::Canny(blurred, edged, threshold1, threshold2);

    std::vector<std::vector<cv::Point>> contours;
    std::vector<cv::Vec4i> hierarchy;
    cv::findContours(edged, contours, hierarchy, cv::RETR_TREE, cv::CHAIN_APPROX_SIMPLE);

    if (contours.empty()) {
        return cv::Rect();
    }

    // Only the longest contour is needed, so a single linear scan replaces the
    // full sort; the comparator takes references to avoid copying contours.
    const auto longest = std::max_element(
        contours.begin(), contours.end(),
        [](const std::vector<cv::Point>& a, const std::vector<cv::Point>& b) {
            return cv::arcLength(a, false) < cv::arcLength(b, false);
        });
    return cv::boundingRect(*longest);
}
// Crops `src` to the detected region of interest and writes the result to
// `dst`.
//
// The image is resized to `imageHeight` rows (keeping the aspect ratio), a
// Canny-based rectangle and the selective-search rectangles are computed on
// the resized copy, and the selective-search rectangle whose corners are
// closest to the Canny rectangle is chosen. The larger (by area) of that
// rectangle and the Canny rectangle is scaled back to the original image and
// used for the crop. `fastsearch` selects the fast selective-search mode.
//
// Returns true when a non-empty crop was written to `dst`; returns false (and
// leaves `dst` untouched) when `src` is empty, `imageHeight` is not positive,
// or no usable rectangle lies inside the image.
bool crop(cv::InputArray src, cv::OutputArray dst, bool fastsearch, int imageHeight) { //add other params or maybe overload or something
    cv::Mat original = src.getMat();
    if (original.empty() || imageHeight <= 0) {
        return false;
    }

    // Keep at least one column so that extremely tall images can be resized.
    const int newWidth = std::max(1, original.cols * imageHeight / original.rows);
    cv::Mat temp;
    cv::resize(original, temp, cv::Size(newWidth, imageHeight));

    cv::Rect cannyRect = cannyEdgeRectangle(temp, 100, 255, 255 / 4, 255);
    std::vector<cv::Rect> rects = selectiveSearchSegmentationActor(temp, fastsearch);

    // Find the proposal closest to the Canny rectangle; stays -1 when the
    // search produced no proposals at all.
    int indexOfMin = -1;
    double currentMin = std::numeric_limits<double>::max();
    const int lengthOfRects = static_cast<int>(rects.size());
    for (int i = 0; i < lengthOfRects; i++) {
        const double loss = MSELossRect(rects[i], cannyRect);
        if (loss < currentMin) {
            indexOfMin = i;
            currentMin = loss;
        }
    }

    cv::Rect finalRect = cannyRect;
    if (indexOfMin >= 0 && rects[indexOfMin].area() > cannyRect.area()) {
        finalRect = rects[indexOfMin];
    }

    scaleRect(finalRect, original.rows, temp.rows);

    // Constrain the rectangle to the image so that taking the ROI cannot throw.
    finalRect &= cv::Rect(0, 0, original.cols, original.rows);
    if (finalRect.area() <= 0) {
        return false;
    }

    original(finalRect).copyTo(dst);
    return true;
}

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.0 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.1 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.3 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.3 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 139 KiB

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,62 +0,0 @@
# Build script for the image-manipulation helper libraries.
# Defines two shared libraries, `rect` and `line`, each built from one source
# file, linked against OpenCV, and written to the lib/ folder of this directory.
cmake_minimum_required(VERSION 3.22)
project(imagemanipulation_libraries
VERSION 0.1
DESCRIPTION "Libraries for image preprocessing"
LANGUAGES CXX)
# Defines the CMAKE_INSTALL_* variables; only the commented-out install() rule
# near the end of this file refers to them.
include(GNUInstallDirs)
# OpenCV is mandatory: configuration stops if it cannot be found.
find_package(OpenCV REQUIRED)
# RECTANGLE
# Shared library built from src/rectangle.cpp, compiled as C++20.
add_library(rect SHARED src/rectangle.cpp)
target_compile_features(rect PRIVATE cxx_std_20)
# set_target_properties(rect PROPERTIES VERSION ${PROJECT_VERSION}) # git can't deal with the symlinks for some reason
# set_target_properties(rect PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/rect_lib.h)
# Emit the built library into <this directory>/lib instead of the build tree.
set_target_properties(rect PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/lib)
target_link_libraries(rect ${OpenCV_LIBS})
# Headers are PRIVATE: they are used to build the target but are not
# propagated to targets that link against it.
target_include_directories(rect
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include
PRIVATE ${OpenCV_INCLUDE_DIRS})
# LINE
# Shared library built from src/line.cpp; configured the same way as `rect`.
add_library(line SHARED src/line.cpp)
target_compile_features(line PRIVATE cxx_std_20)
# set_target_properties(line PROPERTIES VERSION ${PROJECT_VERSION}) # git can't deal with the symlinks for some reason
# set_target_properties(line PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/line_lib.h)
set_target_properties(line PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/lib)
target_link_libraries(line ${OpenCV_LIBS})
target_include_directories(line
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include
PRIVATE ${OpenCV_INCLUDE_DIRS})
# Disabled: install rule for `rect`, and include directories for a `CropperEx`
# target that is not defined in this file.
# install(TARGETS rect
# LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
# PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
# find_package(OpenCV REQUIRED)
# target_include_directories(CropperEx
# PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include
# PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../externallibraries/stbimagehelpers
# PRIVATE ${OpenCV_INCLUDE_DIRS})

View File

@ -1,20 +0,0 @@
#ifndef LINE_H
#define LINE_H
// Placeholder type: no data members or member functions are declared yet.
class Line {
public:
    // (no public members yet)
private:
    // (no private members yet)
};
#endif //LINE_H

Some files were not shown because too many files have changed in this diff Show More