// SPDX-FileCopyrightText: Copyright 2022 Citra Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later

#pragma once

#include <string_view>

namespace HostShaders {

constexpr std::string_view TILING_COMP = R"shader_src(
// SPDX-FileCopyrightText: Copyright 2025 shadPS4 Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later

#version 450 core

#extension GL_GOOGLE_include_directive : require
#extension GL_EXT_scalar_block_layout : require
#extension GL_EXT_shader_explicit_arithmetic_types : require

// Work-group size: 64 invocations along X. main() uses gl_GlobalInvocationID.x
// as the index of the linear element handled by each invocation.
layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

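// The macros listed below are referenced throughout this shader but are never
// defined in it; they have to be provided externally before compilation.
// #define BITS_PER_PIXEL
// #define NUM_SAMPLES
// #define MICRO_TILE_MODE
// #define ARRAY_MODE
// #define MICRO_TILE_THICKNESS
// #define PIPE_CONFIG
// #define BANK_WIDTH
// #define BANK_HEIGHT
// #define NUM_BANKS
// #define NUM_BANK_BITS
// #define TILE_SPLIT_BYTES
// #define MACRO_TILE_ASPECT
// Optional: when IS_TILER is defined, main() copies linear -> tiled instead of
// tiled -> linear.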

// Size in bytes of one element, derived from the externally supplied
// BITS_PER_PIXEL.
#define BYTES_PER_PIXEL (BITS_PER_PIXEL / 8)

// Element type of the tiled_data / linear_data arrays. It is picked so that a
// single array element is BITS_PER_PIXEL bits wide, which lets main() move one
// element per invocation with a plain assignment.
#if BITS_PER_PIXEL == 8
#define BLOCK_TYPE uint8_t
#elif BITS_PER_PIXEL == 16
#define BLOCK_TYPE uint16_t
#elif BITS_PER_PIXEL == 32
#define BLOCK_TYPE uint32_t
#elif BITS_PER_PIXEL == 64
#define BLOCK_TYPE u32vec2
#elif BITS_PER_PIXEL == 96
#define BLOCK_TYPE u32vec3
#else
// Every other value (128 is the one the rest of the shader handles) falls back
// to a four-component 32-bit vector.
#define BLOCK_TYPE u32vec4
#endif

// Pipe configuration identifiers that PIPE_CONFIG is compared against.
// They are defined ahead of the `#if PIPE_CONFIG == ...` test below so that the
// comparison never evaluates an undefined identifier.
#define ADDR_SURF_P2 0
#define ADDR_SURF_P8_32x32_8x16 10
#define ADDR_SURF_P8_32x32_16x16 12

// Number of pipes and the number of address bits needed to select one.
// Only the P2 configuration yields two pipes; every other value is treated as
// an eight-pipe configuration.
#if PIPE_CONFIG == ADDR_SURF_P2
#define NUM_PIPES 2
#define NUM_PIPE_BITS 1
#else
#define NUM_PIPES 8
#define NUM_PIPE_BITS 3
#endif

// Micro tile geometry: 8x8 pixels, MICRO_TILE_THICKNESS slices deep.
#define MICRO_TILE_WIDTH 8
#define MICRO_TILE_HEIGHT 8
#define MICRO_TILE_PIXELS (MICRO_TILE_WIDTH * MICRO_TILE_HEIGHT)
#define MICRO_TILE_BITS (MICRO_TILE_PIXELS * MICRO_TILE_THICKNESS * BITS_PER_PIXEL * NUM_SAMPLES)
#define MICRO_TILE_BYTES (MICRO_TILE_BITS / 8)

// Number of low address bits kept below the pipe bits when the final
// macro-tiled address is assembled.
#define NUM_PIPE_INTERLEAVE_BITS 8

// Values MICRO_TILE_MODE is compared against.
#define ADDR_SURF_DISPLAY_MICRO_TILING 0
#define ADDR_SURF_THIN_MICRO_TILING 1
#define ADDR_SURF_DEPTH_MICRO_TILING 2
#define ADDR_SURF_ROTATED_MICRO_TILING 3

// Values ARRAY_MODE is compared against.
#define ARRAY_LINEAR_GENERAL 0
#define ARRAY_LINEAR_ALIGNED 1
#define ARRAY_1D_TILED_THIN1 2
#define ARRAY_1D_TILED_THICK 3
#define ARRAY_2D_TILED_THIN1 4
#define ARRAY_PRT_TILED_THIN1 5
#define ARRAY_PRT_2D_TILED_THIN1 6
#define ARRAY_2D_TILED_THICK 7
#define ARRAY_2D_TILED_XTHICK 8
#define ARRAY_PRT_TILED_THICK 9
#define ARRAY_PRT_2D_TILED_THICK 10
#define ARRAY_PRT_3D_TILED_THIN1 11
#define ARRAY_3D_TILED_THIN1 12
#define ARRAY_3D_TILED_THICK 13
#define ARRAY_3D_TILED_XTHICK 14
#define ARRAY_PRT_3D_TILED_THICK 15

// Bit count -> byte count, rounding up.
#define BITS_PER_BYTE 8
#define BITS_TO_BYTES(x) (((x) + (BITS_PER_BYTE-1)) / BITS_PER_BYTE)

// Extracts bit `b` of `v` as 0 or 1.
#define _BIT(v, b) bitfieldExtract((v), (b), 1)

// Per-mip-level layout description supplied through the TilingInfo uniform.
struct MipInfo {
    // Size of the level in bytes (GetMipLevel divides it by BYTES_PER_PIXEL to
    // get an element count).
    uint size;
    // Row length of the level in elements; main() uses it to split the element
    // index into x/y.
    uint pitch;
    // Number of rows per slice; main() uses it to split the element index into
    // y/slice.
    uint height;
    // Byte offset added to the computed tiled address of elements in this level.
    uint offset;
};

// Tiled-layout storage. Read by main() by default, written instead when
// IS_TILER is defined.
layout (set = 0, binding = 0, scalar) buffer InputBuf {
    BLOCK_TYPE tiled_data[];
};

// Linear-layout storage, indexed by gl_GlobalInvocationID.x. Written by main()
// by default, read instead when IS_TILER is defined.
layout (set = 0, binding = 1, scalar) buffer OutputBuf {
    BLOCK_TYPE linear_data[];
};

// Run-time parameters of the surface being processed.
layout (set = 0, binding = 2, scalar) uniform TilingInfo {
    // Added to the slice rotation and XOR-ed into the bank index in
    // ComputeBankFromCoord.
    uint bank_swizzle;
    // Not referenced anywhere in this shader.
    uint num_slices;
    // Number of valid entries in `mips`; bounds the search in GetMipLevel.
    uint num_mips;
    // Layout of each mip level, indexed by level.
    MipInfo mips[16];
} info;

// Returns the index of pixel (x, y, z) inside its micro tile.
//
// Only the three low bits of each coordinate take part. The index is an
// interleaving of those bits whose order is selected at compile time by
// MICRO_TILE_MODE and BITS_PER_PIXEL (and MICRO_TILE_THICKNESS for bit 8).
// Combinations that match no branch leave the corresponding bits at zero.
uint32_t ComputePixelIndexWithinMicroTile(uint32_t x, uint32_t y, uint32_t z) {
    uint32_t x0 = _BIT(x, 0);
    uint32_t x1 = _BIT(x, 1);
    uint32_t x2 = _BIT(x, 2);
    uint32_t y0 = _BIT(y, 0);
    uint32_t y1 = _BIT(y, 1);
    uint32_t y2 = _BIT(y, 2);
    uint32_t z0 = _BIT(z, 0);
    uint32_t z1 = _BIT(z, 1);
    uint32_t z2 = _BIT(z, 2);

    // Each assignment below lists the source bits from index bit 0 upwards.
    uint32_t index = 0;

#if MICRO_TILE_MODE == ADDR_SURF_DISPLAY_MICRO_TILING
    #if BITS_PER_PIXEL == 8
        index = x0 | (x1 << 1) | (x2 << 2) | (y1 << 3) | (y0 << 4) | (y2 << 5);
    #elif BITS_PER_PIXEL == 16
        index = x0 | (x1 << 1) | (x2 << 2) | (y0 << 3) | (y1 << 4) | (y2 << 5);
    #elif BITS_PER_PIXEL == 32
        index = x0 | (x1 << 1) | (y0 << 2) | (x2 << 3) | (y1 << 4) | (y2 << 5);
    #elif BITS_PER_PIXEL == 64
        index = x0 | (y0 << 1) | (x1 << 2) | (x2 << 3) | (y1 << 4) | (y2 << 5);
    #elif BITS_PER_PIXEL == 128
        index = y0 | (x0 << 1) | (x1 << 2) | (x2 << 3) | (y1 << 4) | (y2 << 5);
    #endif
#elif MICRO_TILE_MODE == ADDR_SURF_THIN_MICRO_TILING || MICRO_TILE_MODE == ADDR_SURF_DEPTH_MICRO_TILING
        index = x0 | (y0 << 1) | (x1 << 2) | (y1 << 3) | (x2 << 4) | (y2 << 5);
#else
    // Remaining micro tile modes also interleave the slice (z) bits.
    #if BITS_PER_PIXEL == 8 || BITS_PER_PIXEL == 16
        index = x0 | (y0 << 1) | (x1 << 2) | (y1 << 3) | (z0 << 4) | (z1 << 5);
    #elif BITS_PER_PIXEL == 32
        index = x0 | (y0 << 1) | (x1 << 2) | (z0 << 3) | (y1 << 4) | (z1 << 5);
    #elif BITS_PER_PIXEL == 64 || BITS_PER_PIXEL == 128
        index = x0 | (y0 << 1) | (z0 << 2) | (x1 << 3) | (y1 << 4) | (z1 << 5);
    #endif
        index |= (x2 << 6) | (y2 << 7);

    #if MICRO_TILE_THICKNESS == 8
        index |= (z2 << 8);
    #endif
#endif

    return index;
}

#if ARRAY_MODE == ARRAY_1D_TILED_THIN1 || ARRAY_MODE == ARRAY_1D_TILED_THICK
// Returns the byte offset of element (x, y, slice, sample_index) in a surface
// that is laid out as a plain grid of micro tiles.
//
// pitch and height are the surface dimensions in elements. The result is the
// sum of three parts: the start of the slice group, the start of the micro tile
// inside that group, and the position of the element inside the micro tile.
uint32_t ComputeSurfaceAddrFromCoordMicroTiled(uint32_t x, uint32_t y, uint32_t slice, uint32_t pitch, uint32_t height, uint32_t sample_index) {
    // Start of the group of MICRO_TILE_THICKNESS slices that contains `slice`.
    uint32_t bytes_per_slice_group = BITS_TO_BYTES(pitch * height * MICRO_TILE_THICKNESS * BITS_PER_PIXEL * NUM_SAMPLES);
    uint32_t slice_group_base = (slice / MICRO_TILE_THICKNESS) * bytes_per_slice_group;

    // Start of the micro tile; tiles are stored row by row.
    uint32_t tiles_per_row = pitch / MICRO_TILE_WIDTH;
    uint32_t tile_column = x / MICRO_TILE_WIDTH;
    uint32_t tile_row = y / MICRO_TILE_HEIGHT;
    uint32_t tile_base = (tile_row * tiles_per_row + tile_column) * MICRO_TILE_BYTES;

    // Position inside the micro tile, accumulated in bits.
    uint32_t index_in_tile = ComputePixelIndexWithinMicroTile(x, y, slice);
#if MICRO_TILE_MODE == ADDR_SURF_DEPTH_MICRO_TILING
    // Depth mode: the samples of one pixel are stored next to each other.
    uint32_t sample_bits = sample_index * BITS_PER_PIXEL;
    uint32_t pixel_bits = index_in_tile * BITS_PER_PIXEL * NUM_SAMPLES;
#else
    // Other modes: each sample occupies its own equal share of the tile.
    uint32_t sample_bits = sample_index * (MICRO_TILE_BYTES * 8 / NUM_SAMPLES);
    uint32_t pixel_bits = index_in_tile * BITS_PER_PIXEL;
#endif
    uint32_t bytes_in_tile = (sample_bits + pixel_bits) / 8;

    return slice_group_base + tile_base + bytes_in_tile;
}
#else
// Returns the pipe index for the micro tile containing (x, y).
//
// The index is built by XOR-ing low bits of the micro tile coordinates in a
// pattern selected by PIPE_CONFIG. For the 3D array modes it is additionally
// XOR-ed with a swizzle that depends on the slice.
uint32_t ComputePipeFromCoord(uint32_t x, uint32_t y, uint32_t slice) {
    // Micro tile coordinates; bit n of these corresponds to bit n+3 of x / y.
    uint32_t tile_x = x / MICRO_TILE_WIDTH;
    uint32_t tile_y = y / MICRO_TILE_HEIGHT;

    uint32_t x3 = _BIT(tile_x, 0);
    uint32_t x4 = _BIT(tile_x, 1);
    uint32_t x5 = _BIT(tile_x, 2);
    uint32_t y3 = _BIT(tile_y, 0);
    uint32_t y4 = _BIT(tile_y, 1);
    uint32_t y5 = _BIT(tile_y, 2);

    // Stays zero when PIPE_CONFIG matches none of the known configurations.
    uint32_t pipe = 0;
#if PIPE_CONFIG == ADDR_SURF_P2
    pipe = x3 ^ y3;
#elif PIPE_CONFIG == ADDR_SURF_P8_32x32_8x16
    pipe = (x4 ^ y3 ^ x5) | ((x3 ^ y4) << 1) | ((x5 ^ y5) << 2);
#elif PIPE_CONFIG == ADDR_SURF_P8_32x32_16x16
    pipe = (x3 ^ y3 ^ x4) | ((x4 ^ y4) << 1) | ((x5 ^ y5) << 2);
#endif

    uint32_t swizzle = 0;
#if ARRAY_MODE == ARRAY_3D_TILED_THIN1 || ARRAY_MODE == ARRAY_3D_TILED_THICK || ARRAY_MODE == ARRAY_3D_TILED_XTHICK
    swizzle += max(1, NUM_PIPES / 2 - 1) * (slice / MICRO_TILE_THICKNESS);
#endif
    swizzle &= (NUM_PIPES - 1);

    return pipe ^ swizzle;
}

// Returns the bank index for the element at (x, y, slice).
//
// The base index XORs low bits of the bank-granular tile coordinates in a
// pattern selected by NUM_BANKS. It is then XOR-ed with
// (info.bank_swizzle + slice rotation) and with the tile-split rotation, and
// finally masked to NUM_BANKS - 1. Both rotations are zero unless ARRAY_MODE
// selects them.
uint32_t ComputeBankFromCoord(uint32_t x, uint32_t y, uint32_t slice, uint32_t tile_split_slice) {
    // Coordinates in units of one bank's footprint.
    uint32_t bank_x = x / MICRO_TILE_WIDTH / (BANK_WIDTH * NUM_PIPES);
    uint32_t bank_y = y / MICRO_TILE_HEIGHT / BANK_HEIGHT;

    uint32_t x3 = _BIT(bank_x, 0);
    uint32_t x4 = _BIT(bank_x, 1);
    uint32_t x5 = _BIT(bank_x, 2);
    uint32_t x6 = _BIT(bank_x, 3);
    uint32_t y3 = _BIT(bank_y, 0);
    uint32_t y4 = _BIT(bank_y, 1);
    uint32_t y5 = _BIT(bank_y, 2);
    uint32_t y6 = _BIT(bank_y, 3);

    // Stays zero when NUM_BANKS is none of 16, 8, 4 or 2.
    uint32_t bank = 0;
#if NUM_BANKS == 16
    bank = (x3 ^ y6) | ((x4 ^ y5 ^ y6) << 1) | ((x5 ^ y4) << 2) | ((x6 ^ y3) << 3);
#elif NUM_BANKS == 8
    bank = (x3 ^ y5) | ((x4 ^ y4 ^ y5) << 1) | ((x5 ^ y3) << 2);
#elif NUM_BANKS == 4
    bank = (x3 ^ y4) | ((x4 ^ y3) << 1);
#elif NUM_BANKS == 2
    bank = x3 ^ y3;
#endif

    uint32_t slice_rotation = 0;
#if ARRAY_MODE == ARRAY_2D_TILED_THIN1 || ARRAY_MODE == ARRAY_2D_TILED_THICK || ARRAY_MODE == ARRAY_2D_TILED_XTHICK
    slice_rotation = ((NUM_BANKS / 2) - 1) * (slice / MICRO_TILE_THICKNESS);
#elif ARRAY_MODE == ARRAY_3D_TILED_THIN1 || ARRAY_MODE == ARRAY_3D_TILED_THICK || ARRAY_MODE == ARRAY_3D_TILED_XTHICK
    slice_rotation = max(1u, (NUM_PIPES / 2) - 1) * (slice / MICRO_TILE_THICKNESS) / NUM_PIPES;
#endif

    uint32_t tile_split_rotation = 0;
#if ARRAY_MODE == ARRAY_2D_TILED_THIN1 || ARRAY_MODE == ARRAY_3D_TILED_THIN1 || \
    ARRAY_MODE == ARRAY_PRT_2D_TILED_THIN1 || ARRAY_MODE == ARRAY_PRT_3D_TILED_THIN1
    tile_split_rotation = ((NUM_BANKS / 2) + 1) * tile_split_slice;
#endif

    bank ^= info.bank_swizzle + slice_rotation;
    bank ^= tile_split_rotation;
    bank &= (NUM_BANKS - 1);

    return bank;
}

// Returns the byte address of element (x, y, slice, sample_index) in a
// macro-tiled surface.
//
// pitch and height are the surface dimensions in elements. The function first
// computes a byte offset (slice + macro tile + micro tile + element) and then
// splices the pipe and bank indices into it: the low NUM_PIPE_INTERLEAVE_BITS
// bits of the offset stay in place, followed by NUM_PIPE_BITS pipe bits,
// NUM_BANK_BITS bank bits, and the remaining offset bits above those.
uint32_t ComputeSurfaceAddrFromCoordMacroTiled(uint32_t x, uint32_t y, uint32_t slice, uint32_t pitch, uint32_t height, uint32_t sample_index) {
    uint32_t pixel_index = ComputePixelIndexWithinMicroTile(x, y, slice);

    // Position of the element inside its micro tile, accumulated in bits.
    uint32_t sample_offset;
    uint32_t pixel_offset;
#if MICRO_TILE_MODE == ADDR_SURF_DEPTH_MICRO_TILING
    sample_offset = sample_index * BITS_PER_PIXEL;
    pixel_offset = pixel_index * BITS_PER_PIXEL * NUM_SAMPLES;
#else
    sample_offset = sample_index * (MICRO_TILE_BITS / NUM_SAMPLES);
    pixel_offset = pixel_index * BITS_PER_PIXEL;
#endif

    uint32_t element_offset = (pixel_offset + sample_offset) / 8;

    // Effective micro tile size. A local is used here rather than redefining
    // the MICRO_TILE_BYTES macro, so the macro keeps its meaning for any code
    // that follows this function.
    uint32_t micro_tile_bytes = MICRO_TILE_BYTES;
    uint32_t slices_per_tile = 1;
    uint32_t tile_split_slice = 0;
#if MICRO_TILE_BYTES > TILE_SPLIT_BYTES && MICRO_TILE_THICKNESS == 1
    // The micro tile is larger than the split size: it is stored as several
    // pieces of TILE_SPLIT_BYTES each, and every piece is addressed like a
    // separate slice.
    slices_per_tile = MICRO_TILE_BYTES / TILE_SPLIT_BYTES;
    tile_split_slice = element_offset / TILE_SPLIT_BYTES;
    element_offset %= TILE_SPLIT_BYTES;
    micro_tile_bytes = TILE_SPLIT_BYTES;
#endif

    uint32_t macro_tile_pitch = (MICRO_TILE_WIDTH * BANK_WIDTH * NUM_PIPES) * MACRO_TILE_ASPECT;
    uint32_t macro_tile_height = (MICRO_TILE_HEIGHT * BANK_HEIGHT * NUM_BANKS) / MACRO_TILE_ASPECT;

    // Bytes one macro tile occupies per pipe/bank combination.
    uint32_t macro_tile_bytes = micro_tile_bytes *
                               (macro_tile_pitch / MICRO_TILE_WIDTH) *
                               (macro_tile_height / MICRO_TILE_HEIGHT) / (NUM_PIPES * NUM_BANKS);

    uint32_t macro_tiles_per_row = pitch / macro_tile_pitch;

    uint32_t macro_tile_index_x = x / macro_tile_pitch;
    uint32_t macro_tile_index_y = y / macro_tile_height;
    uint32_t macro_tile_offset =
        ((macro_tile_index_y * macro_tiles_per_row) + macro_tile_index_x) * macro_tile_bytes;
    uint32_t macro_tiles_per_slice = macro_tiles_per_row * (height / macro_tile_height);

    uint32_t slice_bytes = macro_tiles_per_slice * macro_tile_bytes;
    uint32_t slice_offset =
        slice_bytes * (tile_split_slice + slices_per_tile * (slice / MICRO_TILE_THICKNESS));

    // Micro tile position inside one bank.
    uint32_t tile_row_index = (y / MICRO_TILE_HEIGHT) % BANK_HEIGHT;
    uint32_t tile_column_index = ((x / MICRO_TILE_WIDTH) / NUM_PIPES) % BANK_WIDTH;
    uint32_t tile_index = (tile_row_index * BANK_WIDTH) + tile_column_index;
    uint32_t tile_offset = tile_index * micro_tile_bytes;

    uint32_t total_offset = slice_offset + macro_tile_offset + element_offset + tile_offset;

#if ARRAY_MODE == ARRAY_PRT_TILED_THIN1 || ARRAY_MODE == ARRAY_PRT_TILED_THICK || \
    ARRAY_MODE == ARRAY_PRT_2D_TILED_THIN1 || ARRAY_MODE == ARRAY_PRT_2D_TILED_THICK || \
    ARRAY_MODE == ARRAY_PRT_3D_TILED_THIN1 || ARRAY_MODE == ARRAY_PRT_3D_TILED_THICK
    // PRT modes derive pipe and bank from the position inside the macro tile.
    x %= macro_tile_pitch;
    y %= macro_tile_height;
#endif

    uint32_t pipe = ComputePipeFromCoord(x, y, slice);
    uint32_t bank = ComputeBankFromCoord(x, y, slice, tile_split_slice);

    uint32_t pipe_interleave_mask = (1 << NUM_PIPE_INTERLEAVE_BITS) - 1;
    uint32_t pipe_interleave_offset = total_offset & pipe_interleave_mask;
    uint32_t offset = total_offset >> NUM_PIPE_INTERLEAVE_BITS;

    // Assemble: [ offset | bank | pipe | pipe interleave ].
    uint32_t addr = pipe_interleave_offset;
    uint32_t pipe_bits = pipe << NUM_PIPE_INTERLEAVE_BITS;
    uint32_t bank_bits = bank << (NUM_PIPE_INTERLEAVE_BITS + NUM_PIPE_BITS);
    uint32_t offset_bits = offset << (NUM_PIPE_INTERLEAVE_BITS + NUM_PIPE_BITS + NUM_BANK_BITS);

    addr |= pipe_bits;
    addr |= bank_bits;
    addr |= offset_bits;

    return addr;
}
#endif

// Finds the mip level that contains element `texel` and rewrites `texel` to be
// relative to the start of that level.
//
// Levels are assumed to be laid out back to back in the linear buffer, each
// occupying info.mips[level].size bytes. The returned level is always a valid
// index (< info.num_mips, or 0 when num_mips is 0): an element index past the
// end of the last level is attributed to the last level, with `texel` left
// holding the index relative to that level's start, instead of stepping onto
// an unused or out-of-range `mips` entry.
uint GetMipLevel(inout uint texel) {
    uint mip = 0;
    uint mip_size = info.mips[mip].size / BYTES_PER_PIXEL;
    // `mip + 1u < num_mips` guarantees mips[mip] is valid after the increment.
    while (texel >= mip_size && mip + 1u < info.num_mips) {
        texel -= mip_size;
        ++mip;
        mip_size = info.mips[mip].size / BYTES_PER_PIXEL;
    }
    return mip;
}

// Entry point: each invocation copies one element between the linear buffer
// and its tiled location. The direction is tiled -> linear unless IS_TILER is
// defined, in which case it is linear -> tiled.
void main() {
    uint linear_index = gl_GlobalInvocationID.x;

    // Locate the mip level; `texel` becomes the index relative to that level.
    uint texel = linear_index;
    uint mip = GetMipLevel(texel);

    uint mip_pitch = info.mips[mip].pitch;
    uint mip_height = info.mips[mip].height;

    // Split the level-relative index into slice / row / column.
    uint slice = texel / (mip_pitch * mip_height);
    uint row = (texel / mip_pitch) % mip_height;
    uint column = texel % mip_pitch;

    // Byte address in the tiled buffer: level base plus offset inside the level.
    uint tiled_offset = info.mips[mip].offset;
#if ARRAY_MODE == ARRAY_1D_TILED_THIN1 || ARRAY_MODE == ARRAY_1D_TILED_THICK
    tiled_offset += ComputeSurfaceAddrFromCoordMicroTiled(column, row, slice, mip_pitch, mip_height, 0);
#else
    tiled_offset += ComputeSurfaceAddrFromCoordMacroTiled(column, row, slice, mip_pitch, mip_height, 0);
#endif

    // The buffers are arrays of BLOCK_TYPE, so convert bytes to an element index.
    uint tiled_index = tiled_offset / BYTES_PER_PIXEL;

#ifdef IS_TILER
    tiled_data[tiled_index] = linear_data[linear_index];
#else
    linear_data[linear_index] = tiled_data[tiled_index];
#endif
}

)shader_src";

} // namespace HostShaders
