1use std::marker::PhantomData;
2use std::num::NonZeroU64;
3use std::ops::RangeBounds;
4
/// Largest number of bytes staged by a single `StagingBelt::write_buffer` call in
/// `Buffer::write` (100 KiB); larger uploads are split into chunks of at most this size.
pub const MAX_WRITE_SIZE: usize = 100 << 10;
6
7const MAX_WRITE_SIZE_U64: NonZeroU64 =
8 NonZeroU64::new(MAX_WRITE_SIZE as u64).expect("MAX_WRITE_SIZE must be non-zero");
9
/// A typed wrapper around a `wgpu::Buffer` meant to hold values of `T`.
///
/// `label` and `usage` are kept so the underlying buffer can be recreated with the
/// same descriptor when it has to grow.
#[derive(Debug)]
pub struct Buffer<T> {
    /// Debug label passed to `wgpu` whenever the buffer is (re)created.
    label: &'static str,
    /// Current allocation size of `raw`, in bytes.
    size: u64,
    /// Usage flags passed to `wgpu` whenever the buffer is (re)created.
    usage: wgpu::BufferUsages,
    /// The underlying GPU buffer; replaced by a fresh allocation when the buffer grows.
    pub(crate) raw: wgpu::Buffer,
    /// Ties the element type `T` to the buffer without storing any `T` on the CPU side.
    type_: PhantomData<T>,
}
18
19impl<T: bytemuck::Pod> Buffer<T> {
20 pub fn new(
21 device: &wgpu::Device,
22 label: &'static str,
23 amount: usize,
24 usage: wgpu::BufferUsages,
25 ) -> Self {
26 let size = next_copy_size::<T>(amount);
27
28 let raw = device.create_buffer(&wgpu::BufferDescriptor {
29 label: Some(label),
30 size,
31 usage,
32 mapped_at_creation: false,
33 });
34
35 Self {
36 label,
37 size,
38 usage,
39 raw,
40 type_: PhantomData,
41 }
42 }
43
44 pub fn resize(&mut self, device: &wgpu::Device, new_count: usize) -> bool {
45 let new_size = next_copy_size::<T>(new_count);
46
47 if self.size < new_size {
48 self.raw = device.create_buffer(&wgpu::BufferDescriptor {
49 label: Some(self.label),
50 size: new_size,
51 usage: self.usage,
52 mapped_at_creation: false,
53 });
54
55 self.size = new_size;
56
57 true
58 } else {
59 false
60 }
61 }
62
63 pub fn write(
65 &mut self,
66 device: &wgpu::Device,
67 encoder: &mut wgpu::CommandEncoder,
68 belt: &mut wgpu::util::StagingBelt,
69 offset: usize,
70 contents: &[T],
71 ) -> usize {
72 let bytes: &[u8] = bytemuck::cast_slice(contents);
73 let mut bytes_written = 0;
74
75 while bytes_written + MAX_WRITE_SIZE < bytes.len() {
77 belt.write_buffer(
78 encoder,
79 &self.raw,
80 (offset + bytes_written) as u64,
81 MAX_WRITE_SIZE_U64,
82 device,
83 )
84 .copy_from_slice(&bytes[bytes_written..bytes_written + MAX_WRITE_SIZE]);
85
86 bytes_written += MAX_WRITE_SIZE;
87 }
88
89 let bytes_left = ((bytes.len() - bytes_written) as u64)
92 .try_into()
93 .expect("non-empty write");
94
95 belt.write_buffer(
97 encoder,
98 &self.raw,
99 (offset + bytes_written) as u64,
100 bytes_left,
101 device,
102 )
103 .copy_from_slice(&bytes[bytes_written..]);
104
105 bytes.len()
106 }
107
108 pub fn slice(&self, bounds: impl RangeBounds<wgpu::BufferAddress>) -> wgpu::BufferSlice<'_> {
109 self.raw.slice(bounds)
110 }
111
112 pub fn range(&self, start: usize, end: usize) -> wgpu::BufferSlice<'_> {
113 self.slice(
114 start as u64 * std::mem::size_of::<T>() as u64
115 ..end as u64 * std::mem::size_of::<T>() as u64,
116 )
117 }
118}
119
120fn next_copy_size<T>(amount: usize) -> u64 {
121 let align_mask = wgpu::COPY_BUFFER_ALIGNMENT - 1;
122
123 (((std::mem::size_of::<T>() * amount).next_power_of_two() as u64 + align_mask) & !align_mask)
124 .max(wgpu::COPY_BUFFER_ALIGNMENT)
125}