package main

import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"
	"runtime"
	"strconv"
	"sync"
)

// multiplyRequest is the JSON body decoded by the /multiply handler.
// Each matrix is a slice of rows, so M[i][j] is the element at row i,
// column j.
type multiplyRequest struct {
	// A is the left operand of the product; read from the JSON key "A".
	A [][]float64 `json:"A"`
	// B is the right operand of the product; read from the JSON key "B".
	B [][]float64 `json:"B"`
}

// multiplyResponse is the JSON body written by the /multiply handler on
// success.
type multiplyResponse struct {
	// Result is the matrix returned by multiplyParallel, as a slice of rows.
	Result  [][]float64 `json:"result"`
	// Workers is the worker count the handler passed to multiplyParallel.
	Workers int         `json:"workers"`
}

func main() {
	runtime.GOMAXPROCS(runtime.NumCPU())

	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprintf(w, `{"message":"go server is running","cpus":%d}`, runtime.NumCPU())
	})

	mux.HandleFunc("/multiply", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
			return
		}

		var req multiplyRequest
		dec := json.NewDecoder(r.Body)
		dec.DisallowUnknownFields()
		if err := dec.Decode(&req); err != nil {
			http.Error(w, fmt.Sprintf("invalid json: %v", err), http.StatusBadRequest)
			return
		}

		if len(req.A) == 0 || len(req.B) == 0 || len(req.A[0]) == 0 || len(req.B[0]) == 0 {
			http.Error(w, "invalid matrices A or B", http.StatusBadRequest)
			return
		}
		nColsA := len(req.A[0])
		nRowsB := len(req.B)
		if nColsA != nRowsB {
			http.Error(w, "number of columns of A must equal number of rows of B", http.StatusBadRequest)
			return
		}

		numWorkers := runtime.NumCPU()
		// numWorkers := 1
		if q := r.URL.Query().Get("workers"); q != "" {
			if v, err := strconv.Atoi(q); err == nil && v > 0 {
				numWorkers = v
			}
		}

		result := multiplyParallel(req.A, req.B, numWorkers)

		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(multiplyResponse{Result: result, Workers: numWorkers})
	})

	mux.HandleFunc("/cores", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprintf(w, `{"cores":%d}`, runtime.NumCPU())
	})

	port := os.Getenv("PORT")
	if port == "" {
		port = "3000"
	}
	log.Printf("Go server running on port %s (cpus=%d)\\n", port, runtime.NumCPU())
	if err := http.ListenAndServe(":"+port, withCORS(mux)); err != nil {
		log.Fatal(err)
	}
}

// withCORS wraps next so that every response carries permissive CORS headers
// (any origin; the Content-Type request header; POST, GET and OPTIONS).
// OPTIONS requests are answered directly with 204 No Content and never reach
// next; all other requests are forwarded to it.
func withCORS(next http.Handler) http.Handler {
	corsHeaders := [...][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Headers", "Content-Type"},
		{"Access-Control-Allow-Methods", "POST, GET, OPTIONS"},
	}

	handle := func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		for _, kv := range corsHeaders {
			hdr.Set(kv[0], kv[1])
		}

		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		// Preflight: the headers above are the whole answer.
		w.WriteHeader(http.StatusNoContent)
	}
	return http.HandlerFunc(handle)
}

// multiplyParallel returns the matrix product A×B, computing the rows of the
// result concurrently.
//
// A is read as an m×n matrix with m = len(A) and n = len(A[0]); B is read as
// an n×p matrix with p = len(B[0]). Row indices are handed to a pool of
// goroutines over a channel, and each result row is written by exactly one
// worker, so the result needs no locking.
//
// numWorkers is the requested pool size. Values below 1 are treated as 1 and
// values above m are reduced to m, because a worker without a row to process
// would only cost a goroutine.
//
// An empty A yields an empty, non-nil result. Shapes are checked before any
// goroutine starts: if a row of A has fewer than n elements, B has fewer
// than n rows, or one of B's first n rows has fewer than p elements, the
// function panics in the caller's goroutine, where the caller can still
// recover, rather than inside a worker.
func multiplyParallel(A, B [][]float64, numWorkers int) [][]float64 {
	m := len(A)
	if m == 0 {
		return [][]float64{}
	}
	n := len(A[0])
	if len(B) < n {
		panic(fmt.Sprintf("multiplyParallel: A has %d columns but B has only %d rows", n, len(B)))
	}
	p := 0
	if len(B) > 0 {
		p = len(B[0])
	}
	for i, row := range A {
		if len(row) < n {
			panic(fmt.Sprintf("multiplyParallel: row %d of A has %d elements, want %d", i, len(row), n))
		}
	}
	for k := 0; k < n; k++ {
		if len(B[k]) < p {
			panic(fmt.Sprintf("multiplyParallel: row %d of B has %d elements, want %d", k, len(B[k]), p))
		}
	}

	result := make([][]float64, m)
	for i := range result {
		result[i] = make([]float64, p)
	}

	workerCount := numWorkers
	if workerCount < 1 {
		workerCount = 1
	}
	if workerCount > m {
		workerCount = m
	}

	// Buffered to m so the producer loop below never blocks on a slow worker.
	jobs := make(chan int, m)
	var wg sync.WaitGroup

	wg.Add(workerCount)
	for w := 0; w < workerCount; w++ {
		workerID := w
		go func() {
			defer wg.Done()
			for i := range jobs {
				log.Printf("core %d/%d processing row %d", workerID+1, workerCount, i)
				row := A[i]
				out := result[i]
				for j := 0; j < p; j++ {
					var sum float64
					for k := 0; k < n; k++ {
						sum += row[k] * B[k][j]
					}
					out[j] = sum
				}
			}
		}()
	}

	for i := 0; i < m; i++ {
		jobs <- i
	}
	// Closing the channel ends each worker's range loop once it is drained.
	close(jobs)
	wg.Wait()

	return result
}