// sandbox.cpp
// Single-file vehicle sandbox (software raytracer with NEON), SDL2 window/input, Bullet Physics, ENTT ECS.
// Features:
//  - Vehicle (box) simulated by Bullet; rendered as triangle mesh
//  - Spheres simulated by Bullet; rendered analytically by ray-sphere
//  - Checkerboard ground plane
//  - Third-person camera following the vehicle
//  - WASD controls apply forces/torque to the vehicle
//  - Software raytracer (no BVH) accelerated with AArch64 NEON batch-of-4 ray kernels
//  - Basic single-light Lambert diffuse shading (per-hit shading computed after intersections)
// Build example (aarch64 Linux with NEON, adjust include/library paths):
// g++ -std=c++17 -O3 -march=armv8-a+simd -I/path/to/entt -I/path/to/bullet/include sandbox.cpp -lSDL2 -lBulletDynamics -lBulletCollision -lLinearMath -pthread -o sandbox
//
// Notes:
//  - Requires ENTT (https://github.com/skypjack/entt), Bullet (https://github.com/bulletphysics/bullet3) and SDL2.
//  - Run on an aarch64 machine with NEON support.

#include <SDL2/SDL.h>
#include <entt/entt.hpp>
#include <btBulletDynamicsCommon.h>

#include <arm_neon.h>

#include <vector>
#include <array>
#include <chrono>
#include <cmath>
#include <cstring>
#include <iostream>

// ---------- Math helpers ----------
// Minimal 3-component float vector used for geometry on the scalar (non-NEON) side.
// Default construction yields the zero vector.
struct Vec3 {
    float x, y, z;

    inline Vec3() : Vec3(0.0f, 0.0f, 0.0f) {}
    inline Vec3(float X, float Y, float Z) : x{X}, y{Y}, z{Z} {}

    // Component-wise addition / subtraction.
    inline Vec3 operator+(const Vec3& rhs) const { return {x + rhs.x, y + rhs.y, z + rhs.z}; }
    inline Vec3 operator-(const Vec3& rhs) const { return {x - rhs.x, y - rhs.y, z - rhs.z}; }

    // Uniform scaling by a scalar.
    inline Vec3 operator*(float k) const { return {x * k, y * k, z * k}; }
    inline Vec3 operator/(float k) const { return {x / k, y / k, z / k}; }
};

// Dot product of two vectors.
inline float dot(const Vec3 &lhs, const Vec3 &rhs) {
    return lhs.x*rhs.x + lhs.y*rhs.y + lhs.z*rhs.z;
}

// Right-handed cross product lhs x rhs.
inline Vec3 cross(const Vec3 &lhs, const Vec3 &rhs) {
    return Vec3(lhs.y*rhs.z - lhs.z*rhs.y,
                lhs.z*rhs.x - lhs.x*rhs.z,
                lhs.x*rhs.y - lhs.y*rhs.x);
}

// Euclidean length.
inline float len(const Vec3 &v) { return std::sqrt(dot(v, v)); }

// Unit vector in the direction of v; returns the zero vector when |v| is not
// strictly positive (zero length or NaN).
inline Vec3 normalize(const Vec3 &v) {
    const float length = len(v);
    if (!(length > 0)) {
        return Vec3(0, 0, 0);
    }
    return v * (1.0f / length);
}

// ---------- ENTT components ----------
// Links an entity to its Bullet rigid body. The pointer is non-owning: this component
// never deletes the body. Defaults to nullptr so a default-constructed component is
// never left with an indeterminate pointer.
struct BulletBody { btRigidBody* body = nullptr; };

// Renderable geometry for an entity: triangles in the entity's local space plus a
// flat base colour. Colour components are presumably in [0, 1] (they are scaled by
// the shading factor and clamped when written to the framebuffer). They default to
// 0 instead of being left uninitialized.
struct RenderMesh {
    std::vector<std::array<Vec3,3>> triangles;
    float r = 0.0f;
    float g = 0.0f;
    float b = 0.0f;
};

// ---------- Camera ----------
// Look-at camera: eye position, the point being looked at, an approximate world-up
// hint used to build the view basis, and the vertical field of view in degrees.
struct Camera {
    Vec3 pos, target, up;
    float fov;
};

// ---------- Ray batching ----------
// Ray4 packs four rays in structure-of-arrays form: lane i of every member belongs
// to ray i. Each ray is origin (ox, oy, oz) + t * direction (dx, dy, dz), and the
// intersection kernels only accept hits with tmin <= t <= tmax.
struct Ray4 {
    float32x4_t ox;
    float32x4_t oy;
    float32x4_t oz;
    float32x4_t dx;
    float32x4_t dy;
    float32x4_t dz;
    float32x4_t tmin;
    float32x4_t tmax;
};

// Hit4 holds, per lane of a Ray4 batch, the closest hit recorded so far.
struct Hit4 {
    float32x4_t t;                                     // distance along the ray to that hit
    uint32x4_t hit;                                    // all-ones in lanes that hit something, 0 otherwise
    float32x4_t nx; float32x4_t ny; float32x4_t nz;    // surface normal at the hit
    float32x4_t rr; float32x4_t rg; float32x4_t rb;    // base colors
};

// Broadcast helpers: replicate one scalar into all four lanes of a NEON register.
static inline float32x4_t splat_f32(float v) { return vmovq_n_f32(v); }
static inline uint32x4_t splat_u32(uint32_t v) { return vmovq_n_u32(v); }

// ---------- Intersection kernels (vectorized batch-of-4) ----------

// Ray-sphere intersection for a batch of 4 rays (vectorized).
//
// Tests the sphere with center (cx, cy, cz) and radius r against every lane of R.
// For each lane whose selected root t satisfies R.tmin <= t <= R.tmax and is closer
// than out.t, this updates out.t, the unit normal (pointing from the center to the hit
// point), the base colour (matr, matg, matb), and sets the lane in out.hit. All other
// lanes of `out` are left untouched. This includes their normals, which may belong
// to a primitive tested earlier.
// The quadratic omits the a = dot(d, d) term, so ray directions are assumed to be
// unit length (main normalizes them). NOTE(review): confirm for any other caller.
static inline void intersect_sphere_batch(const Ray4& R, float cx, float cy, float cz, float r, Hit4& out, float matr, float matg, float matb) {
    // oc = ray origin relative to the sphere center.
    const float32x4_t ocx = vsubq_f32(R.ox, splat_f32(cx));
    const float32x4_t ocy = vsubq_f32(R.oy, splat_f32(cy));
    const float32x4_t ocz = vsubq_f32(R.oz, splat_f32(cz));

    // Half-b form of |oc + t*d|^2 = r^2 with |d| = 1:  t = -b +/- sqrt(b^2 - c).
    const float32x4_t b = vaddq_f32(vaddq_f32(vmulq_f32(ocx, R.dx), vmulq_f32(ocy, R.dy)), vmulq_f32(ocz, R.dz));
    const float32x4_t c = vsubq_f32(vaddq_f32(vaddq_f32(vmulq_f32(ocx,ocx), vmulq_f32(ocy,ocy)), vmulq_f32(ocz,ocz)), splat_f32(r*r));
    const float32x4_t disc = vsubq_f32(vmulq_f32(b,b), c);

    const uint32x4_t mask = vcgeq_f32(disc, splat_f32(0.0f));
    if (vmaxvq_u32(mask) == 0)  // no lane's ray line touches the sphere
        return;

    // Lanes with disc < 0 yield NaN here; they are excluded through `mask` below.
    const float32x4_t sdisc = vsqrtq_f32(disc);
    const float32x4_t negb = vnegq_f32(b);
    const float32x4_t t0 = vsubq_f32(negb, sdisc);  // near root
    const float32x4_t t1 = vaddq_f32(negb, sdisc);  // far root

    // Prefer the near root if it lies in [tmin, tmax]. Otherwise use the far root
    // (e.g. the origin is inside the sphere). Otherwise INF, meaning "no hit".
    const float INF = 1e30f;
    float32x4_t tsel = splat_f32(INF);
    uint32x4_t ok0 = vandq_u32(mask, vcgeq_f32(t0, R.tmin));
    ok0 = vandq_u32(ok0, vcleq_f32(t0, R.tmax));
    uint32x4_t ok1 = vandq_u32(mask, vcgeq_f32(t1, R.tmin));
    ok1 = vandq_u32(ok1, vcleq_f32(t1, R.tmax));
    tsel = vbslq_f32(ok0, t0, tsel);
    const uint32x4_t use1 = vandq_u32(vmvnq_u32(ok0), ok1);
    tsel = vbslq_f32(use1, t1, tsel);

    // Only lanes where this sphere is a real hit AND nearer than the current best.
    const uint32x4_t nearer = vcltq_f32(tsel, out.t);
    const uint32x4_t finiteMask = vcltq_f32(tsel, splat_f32(INF));
    const uint32x4_t finalMask = vandq_u32(nearer, finiteMask);
    if (vmaxvq_u32(finalMask) == 0)
        return;

    out.t = vbslq_f32(finalMask, tsel, out.t);

    // Unnormalized normal = hit point - center = oc + d * t.
    const float32x4_t px = vaddq_f32(ocx, vmulq_f32(R.dx, tsel));
    const float32x4_t py = vaddq_f32(ocy, vmulq_f32(R.dy, tsel));
    const float32x4_t pz = vaddq_f32(ocz, vmulq_f32(R.dz, tsel));
    const float32x4_t len2 = vaddq_f32(vaddq_f32(vmulq_f32(px,px), vmulq_f32(py,py)), vmulq_f32(pz,pz));
    // 1/sqrt(len2): hardware estimate refined by one Newton-Raphson step
    // (vrsqrtsq_f32(a, b) returns (3 - a*b) / 2).
    float32x4_t invlen = vrsqrteq_f32(len2);
    invlen = vmulq_f32(vrsqrtsq_f32(vmulq_f32(len2, invlen), invlen), invlen);

    // Blend with finalMask: lanes that did not pick this sphere keep the normal of
    // their existing nearest hit. Lanes with tsel == INF would otherwise receive
    // NaN/garbage here.
    out.nx = vbslq_f32(finalMask, vmulq_f32(px, invlen), out.nx);
    out.ny = vbslq_f32(finalMask, vmulq_f32(py, invlen), out.ny);
    out.nz = vbslq_f32(finalMask, vmulq_f32(pz, invlen), out.nz);
    out.rr = vbslq_f32(finalMask, splat_f32(matr), out.rr);
    out.rg = vbslq_f32(finalMask, splat_f32(matg), out.rg);
    out.rb = vbslq_f32(finalMask, splat_f32(matb), out.rb);
    out.hit = vorrq_u32(out.hit, finalMask);
}

// Ray-triangle intersection (Moller-Trumbore) for a batch of 4 rays.
//
// For each lane of R that hits triangle (v0, v1, v2) with R.tmin <= t <= R.tmax and
// t closer than out.t, this writes t, the triangle's unit normal, the colour
// (colr, colg, colb), and sets the lane in out.hit. Other lanes are left untouched.
// The normal is flipped per lane so that it faces against the incoming ray
// (two-sided). Shading therefore does not depend on the triangle's vertex winding.
static inline void intersect_triangle_batch(const Ray4& R,
                                           const Vec3& v0, const Vec3& v1, const Vec3& v2,
                                           const float colr, const float colg, const float colb,
                                           Hit4& out)
{
    // Triangle edges relative to v0.
    const Vec3 e1 = v1 - v0;
    const Vec3 e2 = v2 - v0;

    // Broadcast the triangle (v0 plus two edges) into every lane.
    const float32x4_t v0x = splat_f32(v0.x), v0y = splat_f32(v0.y), v0z = splat_f32(v0.z);
    const float32x4_t e1x = splat_f32(e1.x), e1y = splat_f32(e1.y), e1z = splat_f32(e1.z);
    const float32x4_t e2x = splat_f32(e2.x), e2y = splat_f32(e2.y), e2z = splat_f32(e2.z);

    // h = D x e2 (per lane).
    const float32x4_t hx = vsubq_f32(vmulq_f32(R.dy, e2z), vmulq_f32(R.dz, e2y));
    const float32x4_t hy = vsubq_f32(vmulq_f32(R.dz, e2x), vmulq_f32(R.dx, e2z));
    const float32x4_t hz = vsubq_f32(vmulq_f32(R.dx, e2y), vmulq_f32(R.dy, e2x));

    // Determinant a = e1 . h. When |a| is ~0, the ray is parallel to the triangle's plane.
    const float32x4_t a = vaddq_f32(vaddq_f32(vmulq_f32(e1x, hx), vmulq_f32(e1y, hy)), vmulq_f32(e1z, hz));

    const float EPS_F = 1e-6f; // |determinant| below this is treated as "parallel"
    const uint32x4_t parallelMask = vcltq_f32(vabsq_f32(a), splat_f32(EPS_F));

    // f ~= 1/a: hardware reciprocal estimate refined by one Newton-Raphson step.
    // vrecpsq_f32(a, f) returns 2 - a*f, so f' = f * (2 - a*f). Lanes with a ~ 0 may
    // produce inf/NaN here; they are rejected through parallelMask.
    float32x4_t f = vrecpeq_f32(a);
    f = vmulq_f32(f, vrecpsq_f32(a, f));

    // s = O - v0 (per lane).
    const float32x4_t sx = vsubq_f32(R.ox, v0x);
    const float32x4_t sy = vsubq_f32(R.oy, v0y);
    const float32x4_t sz = vsubq_f32(R.oz, v0z);

    // Barycentric u = f * (s . h); reject u outside [0, 1].
    const float32x4_t u = vmulq_f32(f, vaddq_f32(vaddq_f32(vmulq_f32(sx, hx), vmulq_f32(sy, hy)), vmulq_f32(sz, hz)));
    const uint32x4_t uOutsideMask = vorrq_u32(vcltq_f32(u, splat_f32(0.0f)), vcgtq_f32(u, splat_f32(1.0f)));

    // q = s x e1 (per lane).
    const float32x4_t qx = vsubq_f32(vmulq_f32(sy, e1z), vmulq_f32(sz, e1y));
    const float32x4_t qy = vsubq_f32(vmulq_f32(sz, e1x), vmulq_f32(sx, e1z));
    const float32x4_t qz = vsubq_f32(vmulq_f32(sx, e1y), vmulq_f32(sy, e1x));

    // Barycentric v = f * (D . q); reject v outside [0, 1] and u + v > 1.
    const float32x4_t v = vmulq_f32(f, vaddq_f32(vaddq_f32(vmulq_f32(R.dx, qx), vmulq_f32(R.dy, qy)), vmulq_f32(R.dz, qz)));
    const uint32x4_t vOutsideMask = vorrq_u32(vcltq_f32(v, splat_f32(0.0f)), vcgtq_f32(v, splat_f32(1.0f)));
    const uint32x4_t uvGreater1 = vcgtq_f32(vaddq_f32(u, v), splat_f32(1.0f));

    // Ray parameter t = f * (e2 . q), which must lie in [tmin, tmax].
    const float32x4_t t = vmulq_f32(f, vaddq_f32(vaddq_f32(vmulq_f32(e2x, qx), vmulq_f32(e2y, qy)), vmulq_f32(e2z, qz)));
    const uint32x4_t tGeMin = vcgeq_f32(t, R.tmin);
    const uint32x4_t tLeMax = vcleq_f32(t, R.tmax);

    // A lane is valid if it is not parallel, passes all barycentric tests, and t is in range.
    uint32x4_t valid = vmvnq_u32(parallelMask);
    valid = vandq_u32(valid, vmvnq_u32(uOutsideMask));
    valid = vandq_u32(valid, vmvnq_u32(vOutsideMask));
    valid = vandq_u32(valid, vmvnq_u32(uvGreater1));
    valid = vandq_u32(valid, tGeMin);
    valid = vandq_u32(valid, tLeMax);

    const uint32x4_t result = vandq_u32(valid, vcltq_f32(t, out.t));
    if (vmaxvq_u32(result) == 0)  // no lane found a nearer hit on this triangle
        return;

    out.t = vbslq_f32(result, t, out.t);

    // Geometric normal from the winding (e1 x e2), flipped in lanes where it points
    // along the ray (n . D > 0). The stored normal then always faces the viewer.
    const Vec3 n = normalize(cross(e1, e2));
    const float32x4_t nx = splat_f32(n.x), ny = splat_f32(n.y), nz = splat_f32(n.z);
    const float32x4_t nDotD = vaddq_f32(vaddq_f32(vmulq_f32(nx, R.dx), vmulq_f32(ny, R.dy)), vmulq_f32(nz, R.dz));
    const uint32x4_t backFacing = vcgtq_f32(nDotD, splat_f32(0.0f));
    const float32x4_t side = vbslq_f32(backFacing, splat_f32(-1.0f), splat_f32(1.0f));

    out.nx = vbslq_f32(result, vmulq_f32(nx, side), out.nx);
    out.ny = vbslq_f32(result, vmulq_f32(ny, side), out.ny);
    out.nz = vbslq_f32(result, vmulq_f32(nz, side), out.nz);
    out.rr = vbslq_f32(result, splat_f32(colr), out.rr);
    out.rg = vbslq_f32(result, splat_f32(colg), out.rg);
    out.rb = vbslq_f32(result, splat_f32(colb), out.rb);
    out.hit = vorrq_u32(out.hit, result);
}

// ---------- Pixel store ----------
// Writes one opaque pixel at pixels[idx] packed as 0xAARRGGBB, with alpha forced to 0xFF.
// Each channel is scaled from [0, 1] to [0, 255], clamped into that range and
// truncated toward zero. A NaN channel ends up as 0.
static inline void store_pixel(uint32_t* pixels, int idx, float r, float g, float b) {
    auto channel = [](float value) -> uint32_t {
        const float clamped = fminf(255.0f, fmaxf(0.0f, value * 255.0f));
        return static_cast<uint8_t>(clamped);
    };
    const uint32_t alpha = 0xFFu << 24;
    pixels[idx] = alpha | (channel(r) << 16) | (channel(g) << 8) | channel(b);
}

// ---------- Globals ----------
// Framebuffer / window size in pixels (also the raytraced resolution).
constexpr int WIDTH = 800;
constexpr int HEIGHT = 600;

static void applyTransformToMesh(const RenderMesh& mesh, const btTransform& tr, std::vector<std::array<Vec3,3>>& outTris) {
    outTris.clear();
    outTris.reserve(mesh.triangles.size());
    btScalar m[16];
    tr.getOpenGLMatrix(m);
    for (auto &tri : mesh.triangles) {
        std::array<Vec3,3> t;
        for (int i=0;i<3;i++){
            Vec3 v = tri[i];
            float x = v.x*m[0] + v.y*m[4] + v.z*m[8] + m[12];
            float y = v.x*m[1] + v.y*m[5] + v.z*m[9] + m[13];
            float z = v.x*m[2] + v.y*m[6] + v.z*m[10] + m[14];
            t[i] = Vec3(x,y,z);
        }
        outTris.push_back(t);
    }
}

// ---------- Main ----------
int main(int argc, char** argv) {
    if (SDL_Init(SDL_INIT_VIDEO) != 0) {
        std::cerr << "SDL_Init failed: " << SDL_GetError() << "\n";
        return -1;
    }
    SDL_Window* window = SDL_CreateWindow("Vehicle Sandbox (NEON raytracer with simple shading)",
                                          SDL_WINDOWPOS_CENTERED, SDL_WINDOWPOS_CENTERED,
                                          WIDTH, HEIGHT, SDL_WINDOW_SHOWN);
    if (!window) { std::cerr<<"CreateWindow failed\n"; return -1; }
    SDL_Renderer* renderer = SDL_CreateRenderer(window, -1, SDL_RENDERER_SOFTWARE);
    if (!renderer) { std::cerr<<"CreateRenderer failed\n"; return -1; }
    SDL_Texture* texture = SDL_CreateTexture(renderer, SDL_PIXELFORMAT_ARGB8888, SDL_TEXTUREACCESS_STREAMING, WIDTH, HEIGHT);
    if (!texture) { std::cerr<<"CreateTexture failed\n"; return -1; }

    // Bullet setup
    btDefaultCollisionConfiguration* collisionConfiguration = new btDefaultCollisionConfiguration();
    btCollisionDispatcher* dispatcher = new btCollisionDispatcher(collisionConfiguration);
    btDbvtBroadphase* overlappingPairCache = new btDbvtBroadphase();
    btSequentialImpulseConstraintSolver* solver = new btSequentialImpulseConstraintSolver;
    btDiscreteDynamicsWorld* dynamicsWorld = new btDiscreteDynamicsWorld(dispatcher, overlappingPairCache, solver, collisionConfiguration);
    dynamicsWorld->setGravity(btVector3(0,-9.81f,0));

    entt::registry reg;

    // Ground plane
    {
        btCollisionShape* groundShape = new btStaticPlaneShape(btVector3(0,1,0), 0);
        btDefaultMotionState* groundMotion = new btDefaultMotionState(btTransform(btQuaternion(0,0,0,1), btVector3(0,0,0)));
        btRigidBody::btRigidBodyConstructionInfo groundRigidBodyCI(0, groundMotion, groundShape, btVector3(0,0,0));
        btRigidBody* groundBody = new btRigidBody(groundRigidBodyCI);
        dynamicsWorld->addRigidBody(groundBody);
        auto e = reg.create();
        reg.emplace<BulletBody>(e, BulletBody{groundBody});
    }

    // Vehicle mesh + body
    entt::entity vehicle_entity;
    RenderMesh vehicleMesh;
    {
        float hx=1.0f, hy=0.5f, hz=2.0f;
        std::array<Vec3,8> v = {
            Vec3(-hx, -hy, -hz), Vec3(hx, -hy, -hz), Vec3(hx, hy, -hz), Vec3(-hx, hy, -hz),
            Vec3(-hx, -hy, hz),  Vec3(hx, -hy, hz),  Vec3(hx, hy, hz),  Vec3(-hx, hy, hz)
        };
        int idxs[] = {
            0,1,2, 0,2,3,
            4,6,5, 4,7,6,
            0,4,5, 0,5,1,
            3,2,6, 3,6,7,
            1,5,6, 1,6,2,
            0,3,7, 0,7,4
        };
        for (int i=0;i<12;i++){
            std::array<Vec3,3> tri = { v[idxs[i*3+0]], v[idxs[i*3+1]], v[idxs[i*3+2]] };
            vehicleMesh.triangles.push_back(tri);
        }
        vehicleMesh.r=0.25f; vehicleMesh.g=0.25f; vehicleMesh.b=0.85f;

        btCollisionShape* boxShape = new btBoxShape(btVector3(hx, hy, hz));
        btTransform startTransform; startTransform.setIdentity(); startTransform.setOrigin(btVector3(0,2,0));
        btScalar mass = 800.0f;
        btVector3 localInertia(0,0,0);
        boxShape->calculateLocalInertia(mass, localInertia);
        btDefaultMotionState* motionState = new btDefaultMotionState(startTransform);
        btRigidBody::btRigidBodyConstructionInfo rbInfo(mass, motionState, boxShape, localInertia);
        btRigidBody* body = new btRigidBody(rbInfo);
        body->setActivationState(DISABLE_DEACTIVATION);
        dynamicsWorld->addRigidBody(body);

        vehicle_entity = reg.create();
        reg.emplace<BulletBody>(vehicle_entity, BulletBody{body});
        reg.emplace<RenderMesh>(vehicle_entity, vehicleMesh);
    }

    // Spheres
    std::vector<entt::entity> spheres;
    {
        const int NUM = 6;
        for (int i=0;i<NUM;i++){
            float radius = 0.5f + (i%3)*0.15f;
            btCollisionShape* sphereShape = new btSphereShape(radius);
            btTransform start; start.setIdentity();
            start.setOrigin(btVector3(-6.0f + i*2.4f, 4.0f + i*0.2f, -1.0f + (i%2)*2.0f));
            btScalar mass = 5.0f + i;
            btVector3 inertia(0,0,0);
            sphereShape->calculateLocalInertia(mass, inertia);
            btDefaultMotionState* ms = new btDefaultMotionState(start);
            btRigidBody::btRigidBodyConstructionInfo info(mass, ms, sphereShape, inertia);
            btRigidBody* body = new btRigidBody(info);
            dynamicsWorld->addRigidBody(body);

            auto e = reg.create();
            reg.emplace<BulletBody>(e, BulletBody{body});
            RenderMesh rm;
            rm.r = 0.85f; rm.g = 0.35f; rm.b = 0.35f;
            reg.emplace<RenderMesh>(e, rm);
            spheres.push_back(e);
        }
    }

    // Camera
    Camera cam;
    cam.pos = Vec3(0,4,-8);
    cam.target = Vec3(0,0,0);
    cam.up = Vec3(0,1,0);
    cam.fov = 60.0f;

    // Light (single directional light approximating distant light)
    Vec3 lightDir = normalize(Vec3(-0.5f, -1.0f, -0.3f)); // direction from surface to light (pointing towards light)
    float ambient = 0.15f;
    float diffuseMul = 0.85f;

    uint32_t* pixels = new uint32_t[WIDTH * HEIGHT];

    auto last = std::chrono::high_resolution_clock::now();
    const auto start = last;
    bool running = true;
    bool keyW=false,keyA=false,keyS=false,keyD=false;
    uint32_t frames = 0;

    while (running) {
        SDL_Event ev;
        while (SDL_PollEvent(&ev)) {
            if (ev.type == SDL_QUIT) running=false;
            if (ev.type == SDL_KEYDOWN || ev.type == SDL_KEYUP) {
                bool down = (ev.type==SDL_KEYDOWN);
                switch (ev.key.keysym.sym) {
                    case SDLK_ESCAPE: running=false; break;
                    case SDLK_w: keyW=down; break;
                    case SDLK_s: keyS=down; break;
                    case SDLK_a: keyA=down; break;
                    case SDLK_d: keyD=down; break;
                }
            }
        }

        auto now = std::chrono::high_resolution_clock::now();
        float dt = std::chrono::duration<float>(now - last).count();
        if (dt>0.05f) dt = 0.05f;
        last = now;

        // Controls: apply forces/torque to vehicle
        {
            auto &bb = reg.get<BulletBody>(vehicle_entity);
            btRigidBody* vehicleBody = bb.body;
            btTransform tr; vehicleBody->getMotionState()->getWorldTransform(tr);
            btVector3 forward = tr.getBasis() * btVector3(0,0,1);
            btVector3 force(0,0,0);
            float driveForce = 6000.0f;
            if (keyW) force += forward * (driveForce);
            if (keyS) force += forward * (-driveForce*0.6f);
            vehicleBody->applyCentralForce(force);
            float steerTorque = 10000.0f;
            if (keyA) vehicleBody->applyTorque(btVector3(0, steerTorque, 0));
            if (keyD) vehicleBody->applyTorque(btVector3(0, -steerTorque, 0));
        }

        dynamicsWorld->stepSimulation(dt, 10);

        // Camera follow vehicle
        {
            auto &bb = reg.get<BulletBody>(vehicle_entity);
            btTransform tr; bb.body->getMotionState()->getWorldTransform(tr);
            btVector3 pos = tr.getOrigin();
            btVector3 back = tr.getBasis() * btVector3(0,0,-1);
            btVector3 camTarget = pos;
            btVector3 camPosVec = pos + btVector3(0,1.5,0) + back * 6.0;
            cam.pos = Vec3(camPosVec.x(), camPosVec.y(), camPosVec.z());
            cam.target = Vec3(camTarget.x(), camTarget.y(), camTarget.z());
            cam.up = Vec3(0,1,0);
        }

        // Collect transformed vehicle triangles
        std::vector<std::array<Vec3,3>> vehicleTrisWorld;
        {
            auto &rm = reg.get<RenderMesh>(vehicle_entity);
            auto &bb = reg.get<BulletBody>(vehicle_entity);
            btTransform tr; 
            bb.body->getMotionState()->getWorldTransform(tr);
            applyTransformToMesh(rm, tr, vehicleTrisWorld);
        }

        // Gather sphere info
        struct SpherePrim { Vec3 center; float r; float cr,cg,cb; };
        std::vector<SpherePrim> spherePrims;
        for (auto e : spheres) {
            auto &bb = reg.get<BulletBody>(e);
            btTransform tr; 
            bb.body->getMotionState()->getWorldTransform(tr);
            btVector3 origin = tr.getOrigin();
            btCollisionShape* shape = bb.body->getCollisionShape();
            float radius = 0.5f;
            if (shape->getShapeType() == SPHERE_SHAPE_PROXYTYPE) {
                btSphereShape* s = static_cast<btSphereShape*>(shape);
                radius = s->getRadius();
            }
            auto &rm = reg.get<RenderMesh>(e);
            spherePrims.push_back({ Vec3(origin.x(), origin.y(), origin.z()), radius, rm.r, rm.g, rm.b });
        }

        // Camera basis
        Vec3 fwd = normalize(cam.target - cam.pos);
        Vec3 right = normalize(cross(fwd, cam.up));
        Vec3 up = normalize(cross(right, fwd));
        float aspect = float(WIDTH)/float(HEIGHT);
        float scale = tanf((cam.fov * 0.5f) * (3.14159265f/180.0f));

        // Raytrace pixels in batches of 4 horizontally
        for (int y=0;y<HEIGHT;y++){
            float sy = (1 - 2 * ((y + 0.5f) / (float)HEIGHT)) * scale;
            for (int x=0;x<WIDTH; x+=4) {
                Ray4 R;
                // create arrays for each component with the batch size
                float ox_arr[4], oy_arr[4], oz_arr[4];
                float dx_arr[4], dy_arr[4], dz_arr[4];
                float tmin_arr[4], tmax_arr[4];
                for (int k=0;k<4;k++){
                    int px = x+k;
                    if (px>=WIDTH) {
                         px = WIDTH-1;
                    }
                    float sx_pix = (2 * ((px + 0.5f) / (float)WIDTH) - 1) * aspect * scale;
                    float rx = sx_pix;
                    float ry = sy;
                    // compute the normalized direction based off the camera direction
                    Vec3 dir = normalize(fwd + right*rx + up*ry);
                    // initialize to the camera position
                    ox_arr[k] = cam.pos.x;
                    oy_arr[k] = cam.pos.y;
                    oz_arr[k] = cam.pos.z;
                    // initialize to the normalized direction
                    dx_arr[k] = dir.x;
                    dy_arr[k] = dir.y;
                    dz_arr[k] = dir.z;
                    tmin_arr[k] = 0.001f;
                    tmax_arr[k] = 1e30f;
                }
                // copy each of the values from the arrays into the ray vector
                R.ox = vld1q_f32(ox_arr);
                R.oy = vld1q_f32(oy_arr);
                R.oz = vld1q_f32(oz_arr);
                R.dx = vld1q_f32(dx_arr);
                R.dy = vld1q_f32(dy_arr);
                R.dz = vld1q_f32(dz_arr);
                R.tmin = vld1q_f32(tmin_arr);
                R.tmax = vld1q_f32(tmax_arr);

                // initialize a hit4 vector which stores the results of the hit
                Hit4 hit;
                hit.t = splat_f32(1e30f);
                hit.hit = splat_u32(0);
                hit.nx = hit.ny = hit.nz = splat_f32(0.0f);
                hit.rr = hit.rg = hit.rb = splat_f32(0.0f);

                // Spheres
                for (auto &sp : spherePrims) {
                    intersect_sphere_batch(R, sp.center.x, sp.center.y, sp.center.z, sp.r, hit, sp.cr, sp.cg, sp.cb);
                }

                // Vehicle triangles
                for (auto &tri : vehicleTrisWorld) {
                    intersect_triangle_batch(R, tri[0], tri[1], tri[2], vehicleMesh.r, vehicleMesh.g, vehicleMesh.b, hit);
                }

                // Ground plane (y=0)
                float32x4_t tplane = vdivq_f32(vnegq_f32(R.oy), R.dy);
                uint32x4_t planeValid = vandq_u32(vcgeq_f32(tplane, R.tmin), vcleq_f32(tplane, R.tmax));
                uint32x4_t closer = vcltq_f32(tplane, hit.t);
                uint32x4_t planeMask = vandq_u32(planeValid, closer);
                if (vgetq_lane_u32(planeMask,0) || vgetq_lane_u32(planeMask,1) || vgetq_lane_u32(planeMask,2) || vgetq_lane_u32(planeMask,3)) {
                    float32x4_t px = vaddq_f32(R.ox, vmulq_f32(R.dx, tplane));
                    float32x4_t pz = vaddq_f32(R.oz, vmulq_f32(R.dz, tplane));
                    float px_s[4], pz_s[4];
                    vst1q_f32(px_s, px);
                    vst1q_f32(pz_s, pz);
                    float cr[4], cg[4], cb[4];
                    for (int k=0;k<4;k++){
                        int xi = (int)floorf(px_s[k]);
                        int zi = (int)floorf(pz_s[k]);
                        bool white = ((xi + zi) & 1) == 0;
                        if (white) { cr[k]=0.85f; cg[k]=0.85f; cb[k]=0.8f; }
                        else      { cr[k]=0.2f; cg[k]=0.2f; cb[k]=0.25f; }
                    }
                    float32x4_t crv = vld1q_f32(cr);
                    float32x4_t cgv = vld1q_f32(cg);
                    float32x4_t cbv = vld1q_f32(cb);
                    hit.t = vbslq_f32(planeMask, tplane, hit.t);
                    hit.nx = vbslq_f32(planeMask, splat_f32(0.0f), hit.nx);
                    hit.ny = vbslq_f32(planeMask, splat_f32(1.0f), hit.ny);
                    hit.nz = vbslq_f32(planeMask, splat_f32(0.0f), hit.nz);
                    hit.rr = vbslq_f32(planeMask, crv, hit.rr);
                    hit.rg = vbslq_f32(planeMask, cgv, hit.rg);
                    hit.rb = vbslq_f32(planeMask, cbv, hit.rb);
                    hit.hit = vorrq_u32(hit.hit, planeMask);
                }

                // At this point we have per-lane: hit.hit, hit.t, hit.nx/ny/nz, hit.rr/rg/rb (base colors)
                // Compute simple Lambert shading per lane: color = base * (ambient + diffuseMul * max(0, dot(n, lightDir)))
                float nxv[4], nyv[4], nzv[4];
                float br[4], bg[4], bbv[4];
                uint32_t hits[4];
                vst1q_f32(nxv, hit.nx);
                vst1q_f32(nyv, hit.ny);
                vst1q_f32(nzv, hit.nz);
                vst1q_f32(br, hit.rr);
                vst1q_f32(bg, hit.rg);
                vst1q_f32(bbv, hit.rb);
                vst1q_u32(hits, hit.hit);

                for (int k=0;k<4;k++){
                    int px = x+k; if (px>=WIDTH) continue;
                    int idx = y*WIDTH + px;
                    if (hits[k]) {
                        // normal vector may be zero for missed lanes; clamp
                        Vec3 n(nxv[k], nyv[k], nzv[k]);
                        float nlen = len(n);
                        if (nlen > 0.0f) {
                            n = n / nlen;
                        } else {
                            n = Vec3(0,1,0); // fallback normal
                        }
                        float ldot = dot(n, lightDir);
                        if (ldot < 0.0f) ldot = 0.0f;
                        float shade = ambient + diffuseMul * ldot;
                        float finalr = br[k] * shade;
                        float finalg = bg[k] * shade;
                        float finalb = bbv[k] * shade;
                        // simple tone mapping clamp
                        store_pixel(pixels, idx, finalr, finalg, finalb);
                    } else {
                        // sky
                        store_pixel(pixels, idx, 1.0f, 0.2f, 0.99f);
                    }
                }
            }
        }

        SDL_UpdateTexture(texture, NULL, pixels, WIDTH * sizeof(uint32_t));
        SDL_RenderClear(renderer);
        SDL_RenderCopy(renderer, texture, NULL, NULL);
        SDL_RenderPresent(renderer);
        frames++;
    }
    const auto end = std::chrono::high_resolution_clock::now();

    const std::chrono::duration<double, std::milli> elapsed = end - start;
    std::cout << "frames rendered: " << frames << "; elapsed time: " << elapsed.count() << std::endl;
    const double frame_rate = static_cast<double>(frames) / ( elapsed.count() / 1000.0 );
    std::cout << "average frame rate: " << frame_rate << std::endl;

    delete[] pixels;
    SDL_DestroyTexture(texture);
    SDL_DestroyRenderer(renderer);
    SDL_DestroyWindow(window);
    SDL_Quit();

    // Cleanup Bullet
    int numCollisionObjects = dynamicsWorld->getNumCollisionObjects();
    for (int i = numCollisionObjects - 1; i >= 0; i--) {
        btCollisionObject* obj = dynamicsWorld->getCollisionObjectArray()[i];
        btRigidBody* body = btRigidBody::upcast(obj);
        if (body && body->getMotionState()) {
            delete body->getMotionState();
        }
        dynamicsWorld->removeCollisionObject(obj);
        delete obj;
    }
    delete dynamicsWorld;
    delete solver;
    delete overlappingPairCache;
    delete dispatcher;
    delete collisionConfiguration;

    return 0;
}
