package main
import (
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"runtime"
"strconv"
"sync"
)
// Wire formats for the /multiply endpoint.
type (
	// multiplyRequest is the JSON body a client POSTs to /multiply: the two
	// operand matrices, each given as a slice of rows, for the product A×B.
	multiplyRequest struct {
		A [][]float64 `json:"A"`
		B [][]float64 `json:"B"`
	}

	// multiplyResponse is the JSON body returned by /multiply: the product
	// matrix and the worker count the handler reports alongside it.
	multiplyResponse struct {
		Result  [][]float64 `json:"result"`
		Workers int         `json:"workers"`
	}
)
func main() {
runtime.GOMAXPROCS(runtime.NumCPU())
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `{"message":"go server is running","cpus":%d}`, runtime.NumCPU())
})
mux.HandleFunc("/multiply", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req multiplyRequest
dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
if err := dec.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("invalid json: %v", err), http.StatusBadRequest)
return
}
if len(req.A) == 0 || len(req.B) == 0 || len(req.A[0]) == 0 || len(req.B[0]) == 0 {
http.Error(w, "invalid matrices A or B", http.StatusBadRequest)
return
}
nColsA := len(req.A[0])
nRowsB := len(req.B)
if nColsA != nRowsB {
http.Error(w, "number of columns of A must equal number of rows of B", http.StatusBadRequest)
return
}
numWorkers := runtime.NumCPU()
// numWorkers := 1
if q := r.URL.Query().Get("workers"); q != "" {
if v, err := strconv.Atoi(q); err == nil && v > 0 {
numWorkers = v
}
}
result := multiplyParallel(req.A, req.B, numWorkers)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(multiplyResponse{Result: result, Workers: numWorkers})
})
mux.HandleFunc("/cores", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `{"cores":%d}`, runtime.NumCPU())
})
port := os.Getenv("PORT")
if port == "" {
port = "3000"
}
log.Printf("Go server running on port %s (cpus=%d)\\n", port, runtime.NumCPU())
if err := http.ListenAndServe(":"+port, withCORS(mux)); err != nil {
log.Fatal(err)
}
}
// withCORS wraps next with permissive CORS handling. Every response allows
// any origin, the Content-Type request header, and the POST, GET and OPTIONS
// methods. OPTIONS requests (CORS preflights) are answered directly with
// 204 No Content and are never forwarded to next.
func withCORS(next http.Handler) http.Handler {
	corsHeaders := [...][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Headers", "Content-Type"},
		{"Access-Control-Allow-Methods", "POST, GET, OPTIONS"},
	}

	serve := func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		for _, kv := range corsHeaders {
			h.Set(kv[0], kv[1])
		}
		switch r.Method {
		case http.MethodOptions:
			w.WriteHeader(http.StatusNoContent)
		default:
			next.ServeHTTP(w, r)
		}
	}
	return http.HandlerFunc(serve)
}
// multiplyParallel returns the matrix product A×B, computing rows of the
// result concurrently on a pool of worker goroutines.
//
// The pool size is numWorkers, clamped to at least 1 and at most len(A):
// each worker handles whole rows, so more workers than rows would only sit
// idle, and the upper clamp keeps a huge numWorkers from spawning an
// unbounded number of goroutines.
//
// Shape requirements: A must be m×n and B must be n×p, with every row of
// each matrix the same length. The shapes are checked on the caller's
// goroutine before any worker starts, and a violation panics there with a
// descriptive message. Without this check, a short row would cause an
// index-out-of-range panic inside a worker goroutine, which no caller can
// recover and which terminates the whole process. An empty A yields an empty
// (non-nil) result.
//
// Each row is logged as a worker picks it up. The "core" number in that log
// line is the worker goroutine's index, not a physical CPU.
func multiplyParallel(A, B [][]float64, numWorkers int) [][]float64 {
	m := len(A)
	result := make([][]float64, m)
	if m == 0 {
		return result
	}

	n := len(A[0])
	for i, row := range A {
		if len(row) != n {
			panic(fmt.Sprintf("multiplyParallel: row %d of A has %d columns, want %d", i, len(row), n))
		}
	}
	if len(B) != n {
		panic(fmt.Sprintf("multiplyParallel: B has %d rows, want %d (the column count of A)", len(B), n))
	}
	p := 0
	if n > 0 {
		p = len(B[0])
	}
	for k, row := range B {
		if len(row) != p {
			panic(fmt.Sprintf("multiplyParallel: row %d of B has %d columns, want %d", k, len(row), p))
		}
	}

	for i := range result {
		result[i] = make([]float64, p)
	}

	workers := numWorkers
	if workers < 1 {
		workers = 1
	}
	if workers > m {
		workers = m
	}

	// Queue every row index up front. The buffer holds all m indices, so
	// filling it never blocks, and closing it lets the workers' range loops
	// end once the queue drains.
	rows := make(chan int, m)
	for i := 0; i < m; i++ {
		rows <- i
	}
	close(rows)

	var wg sync.WaitGroup
	wg.Add(workers)
	for id := 1; id <= workers; id++ {
		go func(id int) {
			defer wg.Done()
			for i := range rows {
				log.Printf("core %d/%d processing row %d", id, workers, i)
				// Each worker writes only result[i] for the rows it receives,
				// so no two goroutines touch the same row.
				out := result[i]
				// i-k-j loop order: walk B row by row (contiguous memory)
				// instead of down its columns. Each out[j] still accumulates
				// its terms in increasing k, starting from zero.
				for k, aik := range A[i] {
					for j, bkj := range B[k] {
						out[j] += aik * bkj
					}
				}
			}
		}(id)
	}
	wg.Wait()
	return result
}