Vendor things

This commit is contained in:
John Doty 2024-03-08 11:03:01 -08:00
parent 5deceec006
commit 977e3c17e5
19434 changed files with 10682014 additions and 0 deletions

View file

@ -0,0 +1,435 @@
/*!
Definitions for index bounds checking.
*/
use crate::{valid, Handle, UniqueArena};
use bit_set::BitSet;
/// How should code generated by Naga do bounds checks?
///
/// When a vector, matrix, or array index is out of bounds—either negative, or
/// greater than or equal to the number of elements in the type—WGSL requires
/// that some other index of the implementation's choice that is in bounds is
/// used instead. (There are no types with zero elements.)
///
/// Similarly, when out-of-bounds coordinates, array indices, or sample indices
/// are presented to the WGSL `textureLoad` and `textureStore` operations, the
/// operation is redirected to do something safe.
///
/// Different users of Naga will prefer different defaults:
///
/// - When used as part of a WebGPU implementation, the WGSL specification
/// requires the `Restrict` behavior for array, vector, and matrix accesses,
/// and either the `Restrict` or `ReadZeroSkipWrite` behaviors for texture
/// accesses.
///
/// - When used by the `wgpu` crate for native development, `wgpu` selects
/// `ReadZeroSkipWrite` as its default.
///
/// - Naga's own default is `Unchecked`, so that shader translations
/// are as faithful to the original as possible.
///
/// Sometimes the underlying hardware and drivers can perform bounds checks
/// themselves, in a way that performs better than the checks Naga would inject.
/// If you're using native checks like this, then having Naga inject its own
/// checks as well would be redundant, and the `Unchecked` policy is
/// appropriate.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum BoundsCheckPolicy {
/// Replace out-of-bounds indexes with some arbitrary in-bounds index.
///
/// (This does not necessarily mean clamping. For example, interpreting the
/// index as unsigned and taking the minimum with the largest valid index
/// would also be a valid implementation. That would map negative indices to
/// the last element, not the first.)
Restrict,
/// Out-of-bounds reads return zero, and writes have no effect.
///
/// When applied to a chain of accesses, like `a[i][j].b[k]`, all index
/// expressions are evaluated, regardless of whether prior or later index
/// expressions were in bounds. But all the accesses per se are skipped
/// if any index is out of bounds.
ReadZeroSkipWrite,
/// Naga adds no checks to indexing operations. Generate the fastest code
/// possible. This is the default for Naga, as a translator, but consumers
/// should consider defaulting to a safer behavior.
Unchecked,
}
/// Policies for injecting bounds checks during code generation.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct BoundsCheckPolicies {
/// How should the generated code handle array, vector, or matrix indices
/// that are out of range?
#[cfg_attr(feature = "deserialize", serde(default))]
pub index: BoundsCheckPolicy,
/// How should the generated code handle array, vector, or matrix indices
/// that are out of range, when those values live in a [`GlobalVariable`] in
/// the [`Storage`] or [`Uniform`] address spaces?
///
/// Some graphics hardware provides "robust buffer access", a feature that
/// ensures that using a pointer cannot access memory outside the 'buffer'
/// that it was derived from. In Naga terms, this means that the hardware
/// ensures that pointers computed by applying [`Access`] and
/// [`AccessIndex`] expressions to a [`GlobalVariable`] whose [`space`] is
/// [`Storage`] or [`Uniform`] will never read or write memory outside that
/// global variable.
///
/// When hardware offers such a feature, it is probably undesirable to have
/// Naga inject bounds checking code for such accesses, since the hardware
/// can probably provide the same protection more efficiently. However,
/// bounds checks are still needed on accesses to indexable values that do
/// not live in buffers, like local variables.
///
/// So, this option provides a separate policy that applies only to accesses
/// to storage and uniform globals. When depending on hardware bounds
/// checking, this policy can be `Unchecked` to avoid unnecessary overhead.
///
/// When special hardware support is not available, this should probably be
/// the same as `index_bounds_check_policy`.
///
/// [`GlobalVariable`]: crate::GlobalVariable
/// [`space`]: crate::GlobalVariable::space
/// [`Restrict`]: crate::back::BoundsCheckPolicy::Restrict
/// [`ReadZeroSkipWrite`]: crate::back::BoundsCheckPolicy::ReadZeroSkipWrite
/// [`Access`]: crate::Expression::Access
/// [`AccessIndex`]: crate::Expression::AccessIndex
/// [`Storage`]: crate::AddressSpace::Storage
/// [`Uniform`]: crate::AddressSpace::Uniform
#[cfg_attr(feature = "deserialize", serde(default))]
pub buffer: BoundsCheckPolicy,
/// How should the generated code handle image texel loads that are out
/// of range?
///
/// This controls the behavior of [`ImageLoad`] expressions when a coordinate,
/// texture array index, level of detail, or multisampled sample number is out of range.
///
/// [`ImageLoad`]: crate::Expression::ImageLoad
#[cfg_attr(feature = "deserialize", serde(default))]
pub image_load: BoundsCheckPolicy,
/// How should the generated code handle image texel stores that are out
/// of range?
///
/// This controls the behavior of [`ImageStore`] statements when a coordinate,
/// texture array index, level of detail, or multisampled sample number is out of range.
///
/// This policy should't be needed since all backends should ignore OOB writes.
///
/// [`ImageStore`]: crate::Statement::ImageStore
#[cfg_attr(feature = "deserialize", serde(default))]
pub image_store: BoundsCheckPolicy,
/// How should the generated code handle binding array indexes that are out of bounds.
#[cfg_attr(feature = "deserialize", serde(default))]
pub binding_array: BoundsCheckPolicy,
}
/// The default `BoundsCheckPolicy` is `Unchecked`.
impl Default for BoundsCheckPolicy {
fn default() -> Self {
BoundsCheckPolicy::Unchecked
}
}
impl BoundsCheckPolicies {
/// Determine which policy applies to `base`.
///
/// `base` is the "base" expression (the expression being indexed) of a `Access`
/// and `AccessIndex` expression. This is either a pointer, a value, being directly
/// indexed, or a binding array.
///
/// See the documentation for [`BoundsCheckPolicy`] for details about
/// when each policy applies.
pub fn choose_policy(
&self,
base: Handle<crate::Expression>,
types: &UniqueArena<crate::Type>,
info: &valid::FunctionInfo,
) -> BoundsCheckPolicy {
let ty = info[base].ty.inner_with(types);
if let crate::TypeInner::BindingArray { .. } = *ty {
return self.binding_array;
}
match ty.pointer_space() {
Some(crate::AddressSpace::Storage { access: _ } | crate::AddressSpace::Uniform) => {
self.buffer
}
// This covers other address spaces, but also accessing vectors and
// matrices by value, where no pointer is involved.
_ => self.index,
}
}
/// Return `true` if any of `self`'s policies are `policy`.
pub fn contains(&self, policy: BoundsCheckPolicy) -> bool {
self.index == policy
|| self.buffer == policy
|| self.image_load == policy
|| self.image_store == policy
}
}
/// An index that may be statically known, or may need to be computed at runtime.
///
/// This enum lets us handle both [`Access`] and [`AccessIndex`] expressions
/// with the same code.
///
/// [`Access`]: crate::Expression::Access
/// [`AccessIndex`]: crate::Expression::AccessIndex
#[derive(Clone, Copy, Debug)]
pub enum GuardedIndex {
Known(u32),
Expression(Handle<crate::Expression>),
}
/// Build a set of expressions used as indices, to cache in temporary variables when
/// emitted.
///
/// Given the bounds-check policies `policies`, construct a `BitSet` containing the handle
/// indices of all the expressions in `function` that are ever used as guarded indices
/// under the [`ReadZeroSkipWrite`] policy. The `module` argument must be the module to
/// which `function` belongs, and `info` should be that function's analysis results.
///
/// Such index expressions will be used twice in the generated code: first for the
/// comparison to see if the index is in bounds, and then for the access itself, should
/// the comparison succeed. To avoid computing the expressions twice, the generated code
/// should cache them in temporary variables.
///
/// Why do we need to build such a set in advance, instead of just processing access
/// expressions as we encounter them? Whether an expression needs to be cached depends on
/// whether it appears as something like the [`index`] operand of an [`Access`] expression
/// or the [`level`] operand of an [`ImageLoad`] expression, and on the index bounds check
/// policies that apply to those accesses. But [`Emit`] statements just identify a range
/// of expressions by index; there's no good way to tell what an expression is used
/// for. The only way to do it is to just iterate over all the expressions looking for
/// relevant `Access` expressions --- which is what this function does.
///
/// Simple expressions like variable loads and constants don't make sense to cache: it's
/// no better than just re-evaluating them. But constants are not covered by `Emit`
/// statements, and `Load`s are always cached to ensure they occur at the right time, so
/// we don't bother filtering them out from this set.
///
/// Fortunately, we don't need to deal with [`ImageStore`] statements here. When we emit
/// code for a statement, the writer isn't in the middle of an expression, so we can just
/// emit declarations for temporaries, initialized appropriately.
///
/// None of these concerns apply for SPIR-V output, since it's easy to just reuse an
/// instruction ID in two places; that has the same semantics as a temporary variable, and
/// it's inherent in the design of SPIR-V. This function is more useful for text-based
/// back ends.
///
/// [`ReadZeroSkipWrite`]: BoundsCheckPolicy::ReadZeroSkipWrite
/// [`index`]: crate::Expression::Access::index
/// [`Access`]: crate::Expression::Access
/// [`level`]: crate::Expression::ImageLoad::level
/// [`ImageLoad`]: crate::Expression::ImageLoad
/// [`Emit`]: crate::Statement::Emit
/// [`ImageStore`]: crate::Statement::ImageStore
pub fn find_checked_indexes(
module: &crate::Module,
function: &crate::Function,
info: &crate::valid::FunctionInfo,
policies: BoundsCheckPolicies,
) -> BitSet {
use crate::Expression as Ex;
let mut guarded_indices = BitSet::new();
// Don't bother scanning if we never need `ReadZeroSkipWrite`.
if policies.contains(BoundsCheckPolicy::ReadZeroSkipWrite) {
for (_handle, expr) in function.expressions.iter() {
// There's no need to handle `AccessIndex` expressions, as their
// indices never need to be cached.
match *expr {
Ex::Access { base, index } => {
if policies.choose_policy(base, &module.types, info)
== BoundsCheckPolicy::ReadZeroSkipWrite
&& access_needs_check(
base,
GuardedIndex::Expression(index),
module,
function,
info,
)
.is_some()
{
guarded_indices.insert(index.index());
}
}
Ex::ImageLoad {
coordinate,
array_index,
sample,
level,
..
} => {
if policies.image_load == BoundsCheckPolicy::ReadZeroSkipWrite {
guarded_indices.insert(coordinate.index());
if let Some(array_index) = array_index {
guarded_indices.insert(array_index.index());
}
if let Some(sample) = sample {
guarded_indices.insert(sample.index());
}
if let Some(level) = level {
guarded_indices.insert(level.index());
}
}
}
_ => {}
}
}
}
guarded_indices
}
/// Determine whether `index` is statically known to be in bounds for `base`.
///
/// If we can't be sure that the index is in bounds, return the limit within
/// which valid indices must fall.
///
/// The return value is one of the following:
///
/// - `Some(Known(n))` indicates that `n` is the largest valid index.
///
/// - `Some(Computed(global))` indicates that the largest valid index is one
/// less than the length of the array that is the last member of the
/// struct held in `global`.
///
/// - `None` indicates that the index need not be checked, either because it
/// is statically known to be in bounds, or because the applicable policy
/// is `Unchecked`.
///
/// This function only handles subscriptable types: arrays, vectors, and
/// matrices. It does not handle struct member indices; those never require
/// run-time checks, so it's best to deal with them further up the call
/// chain.
pub fn access_needs_check(
base: Handle<crate::Expression>,
mut index: GuardedIndex,
module: &crate::Module,
function: &crate::Function,
info: &crate::valid::FunctionInfo,
) -> Option<IndexableLength> {
let base_inner = info[base].ty.inner_with(&module.types);
// Unwrap safety: `Err` here indicates unindexable base types and invalid
// length constants, but `access_needs_check` is only used by back ends, so
// validation should have caught those problems.
let length = base_inner.indexable_length(module).unwrap();
index.try_resolve_to_constant(function, module);
if let (&GuardedIndex::Known(index), &IndexableLength::Known(length)) = (&index, &length) {
if index < length {
// Index is statically known to be in bounds, no check needed.
return None;
}
};
Some(length)
}
impl GuardedIndex {
/// Make a `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible.
///
/// Return values that are already `Known` unchanged.
fn try_resolve_to_constant(&mut self, function: &crate::Function, module: &crate::Module) {
if let GuardedIndex::Expression(expr) = *self {
if let Ok(value) = module
.to_ctx()
.eval_expr_to_u32_from(expr, &function.expressions)
{
*self = GuardedIndex::Known(value);
}
}
}
}
#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)]
pub enum IndexableLengthError {
#[error("Type is not indexable, and has no length (validation error)")]
TypeNotIndexable,
#[error("Array length constant {0:?} is invalid")]
InvalidArrayLength(Handle<crate::Expression>),
}
impl crate::TypeInner {
/// Return the length of a subscriptable type.
///
/// The `self` parameter should be a handle to a vector, matrix, or array
/// type, a pointer to one of those, or a value pointer. Arrays may be
/// fixed-size, dynamically sized, or sized by a specializable constant.
/// This function does not handle struct member references, as with
/// `AccessIndex`.
///
/// The value returned is appropriate for bounds checks on subscripting.
///
/// Return an error if `self` does not describe a subscriptable type at all.
pub fn indexable_length(
&self,
module: &crate::Module,
) -> Result<IndexableLength, IndexableLengthError> {
use crate::TypeInner as Ti;
let known_length = match *self {
Ti::Vector { size, .. } => size as _,
Ti::Matrix { columns, .. } => columns as _,
Ti::Array { size, .. } | Ti::BindingArray { size, .. } => {
return size.to_indexable_length(module);
}
Ti::ValuePointer {
size: Some(size), ..
} => size as _,
Ti::Pointer { base, .. } => {
// When assigning types to expressions, ResolveContext::Resolve
// does a separate sub-match here instead of a full recursion,
// so we'll do the same.
let base_inner = &module.types[base].inner;
match *base_inner {
Ti::Vector { size, .. } => size as _,
Ti::Matrix { columns, .. } => columns as _,
Ti::Array { size, .. } | Ti::BindingArray { size, .. } => {
return size.to_indexable_length(module)
}
_ => return Err(IndexableLengthError::TypeNotIndexable),
}
}
_ => return Err(IndexableLengthError::TypeNotIndexable),
};
Ok(IndexableLength::Known(known_length))
}
}
/// The number of elements in an indexable type.
///
/// This summarizes the length of vectors, matrices, and arrays in a way that is
/// convenient for indexing and bounds-checking code.
#[derive(Debug)]
pub enum IndexableLength {
/// Values of this type always have the given number of elements.
Known(u32),
/// The number of elements is determined at runtime.
Dynamic,
}
impl crate::ArraySize {
pub const fn to_indexable_length(
self,
_module: &crate::Module,
) -> Result<IndexableLength, IndexableLengthError> {
Ok(match self {
Self::Constant(length) => IndexableLength::Known(length.get()),
Self::Dynamic => IndexableLength::Dynamic,
})
}
}

View file

@ -0,0 +1,252 @@
use crate::arena::Handle;
use std::{fmt::Display, num::NonZeroU32, ops};
/// A newtype struct where its only valid values are powers of 2
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Alignment(NonZeroU32);
impl Alignment {
pub const ONE: Self = Self(unsafe { NonZeroU32::new_unchecked(1) });
pub const TWO: Self = Self(unsafe { NonZeroU32::new_unchecked(2) });
pub const FOUR: Self = Self(unsafe { NonZeroU32::new_unchecked(4) });
pub const EIGHT: Self = Self(unsafe { NonZeroU32::new_unchecked(8) });
pub const SIXTEEN: Self = Self(unsafe { NonZeroU32::new_unchecked(16) });
pub const MIN_UNIFORM: Self = Self::SIXTEEN;
pub const fn new(n: u32) -> Option<Self> {
if n.is_power_of_two() {
// SAFETY: value can't be 0 since we just checked if it's a power of 2
Some(Self(unsafe { NonZeroU32::new_unchecked(n) }))
} else {
None
}
}
/// # Panics
/// If `width` is not a power of 2
pub fn from_width(width: u8) -> Self {
Self::new(width as u32).unwrap()
}
/// Returns whether or not `n` is a multiple of this alignment.
pub const fn is_aligned(&self, n: u32) -> bool {
// equivalent to: `n % self.0.get() == 0` but much faster
n & (self.0.get() - 1) == 0
}
/// Round `n` up to the nearest alignment boundary.
pub const fn round_up(&self, n: u32) -> u32 {
// equivalent to:
// match n % self.0.get() {
// 0 => n,
// rem => n + (self.0.get() - rem),
// }
let mask = self.0.get() - 1;
(n + mask) & !mask
}
}
impl Display for Alignment {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.get().fmt(f)
}
}
impl ops::Mul<u32> for Alignment {
type Output = u32;
fn mul(self, rhs: u32) -> Self::Output {
self.0.get() * rhs
}
}
impl ops::Mul for Alignment {
type Output = Alignment;
fn mul(self, rhs: Alignment) -> Self::Output {
// SAFETY: both lhs and rhs are powers of 2, the result will be a power of 2
Self(unsafe { NonZeroU32::new_unchecked(self.0.get() * rhs.0.get()) })
}
}
impl From<crate::VectorSize> for Alignment {
fn from(size: crate::VectorSize) -> Self {
match size {
crate::VectorSize::Bi => Alignment::TWO,
crate::VectorSize::Tri => Alignment::FOUR,
crate::VectorSize::Quad => Alignment::FOUR,
}
}
}
/// Size and alignment information for a type.
#[derive(Clone, Copy, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct TypeLayout {
pub size: u32,
pub alignment: Alignment,
}
impl TypeLayout {
/// Produce the stride as if this type is a base of an array.
pub const fn to_stride(&self) -> u32 {
self.alignment.round_up(self.size)
}
}
/// Helper processor that derives the sizes of all types.
///
/// `Layouter` uses the default layout algorithm/table, described in
/// [WGSL §4.3.7, "Memory Layout"]
///
/// A `Layouter` may be indexed by `Handle<Type>` values: `layouter[handle]` is the
/// layout of the type whose handle is `handle`.
///
/// [WGSL §4.3.7, "Memory Layout"](https://gpuweb.github.io/gpuweb/wgsl/#memory-layouts)
#[derive(Debug, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Layouter {
/// Layouts for types in an arena, indexed by `Handle` index.
layouts: Vec<TypeLayout>,
}
impl ops::Index<Handle<crate::Type>> for Layouter {
type Output = TypeLayout;
fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
&self.layouts[handle.index()]
}
}
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
pub enum LayoutErrorInner {
#[error("Array element type {0:?} doesn't exist")]
InvalidArrayElementType(Handle<crate::Type>),
#[error("Struct member[{0}] type {1:?} doesn't exist")]
InvalidStructMemberType(u32, Handle<crate::Type>),
#[error("Type width must be a power of two")]
NonPowerOfTwoWidth,
}
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
#[error("Error laying out type {ty:?}: {inner}")]
pub struct LayoutError {
pub ty: Handle<crate::Type>,
pub inner: LayoutErrorInner,
}
impl LayoutErrorInner {
const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
LayoutError { ty, inner: self }
}
}
impl Layouter {
/// Remove all entries from this `Layouter`, retaining storage.
pub fn clear(&mut self) {
self.layouts.clear();
}
/// Extend this `Layouter` with layouts for any new entries in `types`.
///
/// Ensure that every type in `types` has a corresponding [TypeLayout] in
/// [`self.layouts`].
///
/// Some front ends need to be able to compute layouts for existing types
/// while module construction is still in progress and new types are still
/// being added. This function assumes that the `TypeLayout` values already
/// present in `self.layouts` cover their corresponding entries in `types`,
/// and extends `self.layouts` as needed to cover the rest. Thus, a front
/// end can call this function at any time, passing its current type and
/// constant arenas, and then assume that layouts are available for all
/// types.
#[allow(clippy::or_fun_call)]
pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> {
use crate::TypeInner as Ti;
for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) {
let size = ty.inner.size(gctx);
let layout = match ty.inner {
Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => {
let alignment = Alignment::new(width as u32)
.ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
TypeLayout { size, alignment }
}
Ti::Vector {
size: vec_size,
width,
..
} => {
let alignment = Alignment::new(width as u32)
.ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
TypeLayout {
size,
alignment: Alignment::from(vec_size) * alignment,
}
}
Ti::Matrix {
columns: _,
rows,
width,
} => {
let alignment = Alignment::new(width as u32)
.ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
TypeLayout {
size,
alignment: Alignment::from(rows) * alignment,
}
}
Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
size,
alignment: Alignment::ONE,
},
Ti::Array {
base,
stride: _,
size: _,
} => TypeLayout {
size,
alignment: if base < ty_handle {
self[base].alignment
} else {
return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
},
},
Ti::Struct { span, ref members } => {
let mut alignment = Alignment::ONE;
for (index, member) in members.iter().enumerate() {
alignment = if member.ty < ty_handle {
alignment.max(self[member.ty].alignment)
} else {
return Err(LayoutErrorInner::InvalidStructMemberType(
index as u32,
member.ty,
)
.with(ty_handle));
};
}
TypeLayout {
size: span,
alignment,
}
}
Ti::Image { .. }
| Ti::Sampler { .. }
| Ti::AccelerationStructure
| Ti::RayQuery
| Ti::BindingArray { .. } => TypeLayout {
size,
alignment: Alignment::ONE,
},
};
debug_assert!(size <= layout.size);
self.layouts.push(layout);
}
Ok(())
}
}

636
third-party/vendor/naga/src/proc/mod.rs vendored Normal file
View file

@ -0,0 +1,636 @@
/*!
[`Module`](super::Module) processing functionality.
*/
pub mod index;
mod layouter;
mod namer;
mod terminator;
mod typifier;
pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
pub use namer::{EntryPointIndex, NameKey, Namer};
pub use terminator::ensure_block_returns;
pub use typifier::{ResolveContext, ResolveError, TypeResolution};
impl From<super::StorageFormat> for super::ScalarKind {
fn from(format: super::StorageFormat) -> Self {
use super::{ScalarKind as Sk, StorageFormat as Sf};
match format {
Sf::R8Unorm => Sk::Float,
Sf::R8Snorm => Sk::Float,
Sf::R8Uint => Sk::Uint,
Sf::R8Sint => Sk::Sint,
Sf::R16Uint => Sk::Uint,
Sf::R16Sint => Sk::Sint,
Sf::R16Float => Sk::Float,
Sf::Rg8Unorm => Sk::Float,
Sf::Rg8Snorm => Sk::Float,
Sf::Rg8Uint => Sk::Uint,
Sf::Rg8Sint => Sk::Sint,
Sf::R32Uint => Sk::Uint,
Sf::R32Sint => Sk::Sint,
Sf::R32Float => Sk::Float,
Sf::Rg16Uint => Sk::Uint,
Sf::Rg16Sint => Sk::Sint,
Sf::Rg16Float => Sk::Float,
Sf::Rgba8Unorm => Sk::Float,
Sf::Rgba8Snorm => Sk::Float,
Sf::Rgba8Uint => Sk::Uint,
Sf::Rgba8Sint => Sk::Sint,
Sf::Rgb10a2Unorm => Sk::Float,
Sf::Rg11b10Float => Sk::Float,
Sf::Rg32Uint => Sk::Uint,
Sf::Rg32Sint => Sk::Sint,
Sf::Rg32Float => Sk::Float,
Sf::Rgba16Uint => Sk::Uint,
Sf::Rgba16Sint => Sk::Sint,
Sf::Rgba16Float => Sk::Float,
Sf::Rgba32Uint => Sk::Uint,
Sf::Rgba32Sint => Sk::Sint,
Sf::Rgba32Float => Sk::Float,
Sf::R16Unorm => Sk::Float,
Sf::R16Snorm => Sk::Float,
Sf::Rg16Unorm => Sk::Float,
Sf::Rg16Snorm => Sk::Float,
Sf::Rgba16Unorm => Sk::Float,
Sf::Rgba16Snorm => Sk::Float,
}
}
}
impl super::ScalarKind {
pub const fn is_numeric(self) -> bool {
match self {
crate::ScalarKind::Sint | crate::ScalarKind::Uint | crate::ScalarKind::Float => true,
crate::ScalarKind::Bool => false,
}
}
}
impl PartialEq for crate::Literal {
fn eq(&self, other: &Self) -> bool {
match (*self, *other) {
(Self::F64(a), Self::F64(b)) => a.to_bits() == b.to_bits(),
(Self::F32(a), Self::F32(b)) => a.to_bits() == b.to_bits(),
(Self::U32(a), Self::U32(b)) => a == b,
(Self::I32(a), Self::I32(b)) => a == b,
(Self::Bool(a), Self::Bool(b)) => a == b,
_ => false,
}
}
}
impl Eq for crate::Literal {}
impl std::hash::Hash for crate::Literal {
fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
match *self {
Self::F64(v) => {
hasher.write_u8(0);
v.to_bits().hash(hasher);
}
Self::F32(v) => {
hasher.write_u8(1);
v.to_bits().hash(hasher);
}
Self::U32(v) => {
hasher.write_u8(2);
v.hash(hasher);
}
Self::I32(v) => {
hasher.write_u8(3);
v.hash(hasher);
}
Self::Bool(v) => {
hasher.write_u8(4);
v.hash(hasher);
}
}
}
}
impl crate::Literal {
pub const fn new(value: u8, kind: crate::ScalarKind, width: crate::Bytes) -> Option<Self> {
match (value, kind, width) {
(value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
(value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
(value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
(value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
(1, crate::ScalarKind::Bool, 4) => Some(Self::Bool(true)),
(0, crate::ScalarKind::Bool, 4) => Some(Self::Bool(false)),
_ => None,
}
}
pub const fn zero(kind: crate::ScalarKind, width: crate::Bytes) -> Option<Self> {
Self::new(0, kind, width)
}
pub const fn one(kind: crate::ScalarKind, width: crate::Bytes) -> Option<Self> {
Self::new(1, kind, width)
}
pub const fn width(&self) -> crate::Bytes {
match *self {
Self::F64(_) => 8,
Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
Self::Bool(_) => 1,
}
}
pub const fn scalar_kind(&self) -> crate::ScalarKind {
match *self {
Self::F64(_) | Self::F32(_) => crate::ScalarKind::Float,
Self::U32(_) => crate::ScalarKind::Uint,
Self::I32(_) => crate::ScalarKind::Sint,
Self::Bool(_) => crate::ScalarKind::Bool,
}
}
pub const fn ty_inner(&self) -> crate::TypeInner {
crate::TypeInner::Scalar {
kind: self.scalar_kind(),
width: self.width(),
}
}
}
pub const POINTER_SPAN: u32 = 4;
impl super::TypeInner {
pub const fn scalar_kind(&self) -> Option<super::ScalarKind> {
match *self {
super::TypeInner::Scalar { kind, .. } | super::TypeInner::Vector { kind, .. } => {
Some(kind)
}
super::TypeInner::Matrix { .. } => Some(super::ScalarKind::Float),
_ => None,
}
}
pub const fn scalar_width(&self) -> Option<u8> {
// Multiply by 8 to get the bit width
match *self {
super::TypeInner::Scalar { width, .. } | super::TypeInner::Vector { width, .. } => {
Some(width * 8)
}
super::TypeInner::Matrix { width, .. } => Some(width * 8),
_ => None,
}
}
pub const fn pointer_space(&self) -> Option<crate::AddressSpace> {
match *self {
Self::Pointer { space, .. } => Some(space),
Self::ValuePointer { space, .. } => Some(space),
_ => None,
}
}
pub fn is_atomic_pointer(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
match *self {
crate::TypeInner::Pointer { base, .. } => match types[base].inner {
crate::TypeInner::Atomic { .. } => true,
_ => false,
},
_ => false,
}
}
/// Get the size of this type.
pub fn size(&self, _gctx: GlobalCtx) -> u32 {
match *self {
Self::Scalar { kind: _, width } | Self::Atomic { kind: _, width } => width as u32,
Self::Vector {
size,
kind: _,
width,
} => size as u32 * width as u32,
// matrices are treated as arrays of aligned columns
Self::Matrix {
columns,
rows,
width,
} => Alignment::from(rows) * width as u32 * columns as u32,
Self::Pointer { .. } | Self::ValuePointer { .. } => POINTER_SPAN,
Self::Array {
base: _,
size,
stride,
} => {
let count = match size {
super::ArraySize::Constant(count) => count.get(),
// A dynamically-sized array has to have at least one element
super::ArraySize::Dynamic => 1,
};
count * stride
}
Self::Struct { span, .. } => span,
Self::Image { .. }
| Self::Sampler { .. }
| Self::AccelerationStructure
| Self::RayQuery
| Self::BindingArray { .. } => 0,
}
}
/// Return the canonical form of `self`, or `None` if it's already in
/// canonical form.
///
/// Certain types have multiple representations in `TypeInner`. This
/// function converts all forms of equivalent types to a single
/// representative of their class, so that simply applying `Eq` to the
/// result indicates whether the types are equivalent, as far as Naga IR is
/// concerned.
pub fn canonical_form(
&self,
types: &crate::UniqueArena<crate::Type>,
) -> Option<crate::TypeInner> {
use crate::TypeInner as Ti;
match *self {
Ti::Pointer { base, space } => match types[base].inner {
Ti::Scalar { kind, width } => Some(Ti::ValuePointer {
size: None,
kind,
width,
space,
}),
Ti::Vector { size, kind, width } => Some(Ti::ValuePointer {
size: Some(size),
kind,
width,
space,
}),
_ => None,
},
_ => None,
}
}
/// Compare `self` and `rhs` as types.
///
/// This is mostly the same as `<TypeInner as Eq>::eq`, but it treats
/// `ValuePointer` and `Pointer` types as equivalent.
///
/// When you know that one side of the comparison is never a pointer, it's
/// fine to not bother with canonicalization, and just compare `TypeInner`
/// values with `==`.
pub fn equivalent(
&self,
rhs: &crate::TypeInner,
types: &crate::UniqueArena<crate::Type>,
) -> bool {
let left = self.canonical_form(types);
let right = rhs.canonical_form(types);
left.as_ref().unwrap_or(self) == right.as_ref().unwrap_or(rhs)
}
pub fn is_dynamically_sized(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
use crate::TypeInner as Ti;
match *self {
Ti::Array { size, .. } => size == crate::ArraySize::Dynamic,
Ti::Struct { ref members, .. } => members
.last()
.map(|last| types[last.ty].inner.is_dynamically_sized(types))
.unwrap_or(false),
_ => false,
}
}
pub fn components(&self) -> Option<u32> {
Some(match *self {
Self::Vector { size, .. } => size as u32,
Self::Matrix { columns, .. } => columns as u32,
Self::Array {
size: crate::ArraySize::Constant(len),
..
} => len.get(),
Self::Struct { ref members, .. } => members.len() as u32,
_ => return None,
})
}
pub fn component_type(&self, index: usize) -> Option<TypeResolution> {
Some(match *self {
Self::Vector { kind, width, .. } => {
TypeResolution::Value(crate::TypeInner::Scalar { kind, width })
}
Self::Matrix { rows, width, .. } => TypeResolution::Value(crate::TypeInner::Vector {
size: rows,
kind: crate::ScalarKind::Float,
width,
}),
Self::Array {
base,
size: crate::ArraySize::Constant(_),
..
} => TypeResolution::Handle(base),
Self::Struct { ref members, .. } => TypeResolution::Handle(members[index].ty),
_ => return None,
})
}
}
impl super::AddressSpace {
pub fn access(self) -> crate::StorageAccess {
use crate::StorageAccess as Sa;
match self {
crate::AddressSpace::Function
| crate::AddressSpace::Private
| crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
crate::AddressSpace::Uniform => Sa::LOAD,
crate::AddressSpace::Storage { access } => access,
crate::AddressSpace::Handle => Sa::LOAD,
crate::AddressSpace::PushConstant => Sa::LOAD,
}
}
}
impl super::MathFunction {
pub const fn argument_count(&self) -> usize {
match *self {
// comparison
Self::Abs => 1,
Self::Min => 2,
Self::Max => 2,
Self::Clamp => 3,
Self::Saturate => 1,
// trigonometry
Self::Cos => 1,
Self::Cosh => 1,
Self::Sin => 1,
Self::Sinh => 1,
Self::Tan => 1,
Self::Tanh => 1,
Self::Acos => 1,
Self::Asin => 1,
Self::Atan => 1,
Self::Atan2 => 2,
Self::Asinh => 1,
Self::Acosh => 1,
Self::Atanh => 1,
Self::Radians => 1,
Self::Degrees => 1,
// decomposition
Self::Ceil => 1,
Self::Floor => 1,
Self::Round => 1,
Self::Fract => 1,
Self::Trunc => 1,
Self::Modf => 2,
Self::Frexp => 2,
Self::Ldexp => 2,
// exponent
Self::Exp => 1,
Self::Exp2 => 1,
Self::Log => 1,
Self::Log2 => 1,
Self::Pow => 2,
// geometry
Self::Dot => 2,
Self::Outer => 2,
Self::Cross => 2,
Self::Distance => 2,
Self::Length => 1,
Self::Normalize => 1,
Self::FaceForward => 3,
Self::Reflect => 2,
Self::Refract => 3,
// computational
Self::Sign => 1,
Self::Fma => 3,
Self::Mix => 3,
Self::Step => 2,
Self::SmoothStep => 3,
Self::Sqrt => 1,
Self::InverseSqrt => 1,
Self::Inverse => 1,
Self::Transpose => 1,
Self::Determinant => 1,
// bits
Self::CountTrailingZeros => 1,
Self::CountLeadingZeros => 1,
Self::CountOneBits => 1,
Self::ReverseBits => 1,
Self::ExtractBits => 3,
Self::InsertBits => 4,
Self::FindLsb => 1,
Self::FindMsb => 1,
// data packing
Self::Pack4x8snorm => 1,
Self::Pack4x8unorm => 1,
Self::Pack2x16snorm => 1,
Self::Pack2x16unorm => 1,
Self::Pack2x16float => 1,
// data unpacking
Self::Unpack4x8snorm => 1,
Self::Unpack4x8unorm => 1,
Self::Unpack2x16snorm => 1,
Self::Unpack2x16unorm => 1,
Self::Unpack2x16float => 1,
}
}
}
impl crate::Expression {
/// Returns true if the expression is considered emitted at the start of a function.
pub const fn needs_pre_emit(&self) -> bool {
match *self {
Self::Literal(_)
| Self::Constant(_)
| Self::ZeroValue(_)
| Self::FunctionArgument(_)
| Self::GlobalVariable(_)
| Self::LocalVariable(_) => true,
_ => false,
}
}
/// Return true if this expression is a dynamic array index, for [`Access`].
///
/// This method returns true if this expression is a dynamically computed
/// index, and as such can only be used to index matrices and arrays when
/// they appear behind a pointer. See the documentation for [`Access`] for
/// details.
///
/// Note, this does not check the _type_ of the given expression. It's up to
/// the caller to establish that the `Access` expression is well-typed
/// through other means, like [`ResolveContext`].
///
/// [`Access`]: crate::Expression::Access
/// [`ResolveContext`]: crate::proc::ResolveContext
pub fn is_dynamic_index(&self, module: &crate::Module) -> bool {
if let Self::Constant(handle) = *self {
let constant = &module.constants[handle];
!matches!(constant.r#override, crate::Override::None)
} else {
true
}
}
}
impl crate::Function {
/// Return the global variable being accessed by the expression `pointer`.
///
/// Assuming that `pointer` is a series of `Access` and `AccessIndex`
/// expressions that ultimately access some part of a `GlobalVariable`,
/// return a handle for that global.
///
/// If the expression does not ultimately access a global variable, return
/// `None`.
pub fn originating_global(
&self,
mut pointer: crate::Handle<crate::Expression>,
) -> Option<crate::Handle<crate::GlobalVariable>> {
loop {
pointer = match self.expressions[pointer] {
crate::Expression::Access { base, .. } => base,
crate::Expression::AccessIndex { base, .. } => base,
crate::Expression::GlobalVariable(handle) => return Some(handle),
crate::Expression::LocalVariable(_) => return None,
crate::Expression::FunctionArgument(_) => return None,
// There are no other expressions that produce pointer values.
_ => unreachable!(),
}
}
}
}
impl crate::SampleLevel {
pub const fn implicit_derivatives(&self) -> bool {
match *self {
Self::Auto | Self::Bias(_) => true,
Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
}
}
}
impl crate::Binding {
pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
match *self {
crate::Binding::BuiltIn(built_in) => Some(built_in),
Self::Location { .. } => None,
}
}
}
impl super::SwizzleComponent {
pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
pub const fn index(&self) -> u32 {
match *self {
Self::X => 0,
Self::Y => 1,
Self::Z => 2,
Self::W => 3,
}
}
pub const fn from_index(idx: u32) -> Self {
match idx {
0 => Self::X,
1 => Self::Y,
2 => Self::Z,
_ => Self::W,
}
}
}
impl super::ImageClass {
pub const fn is_multisampled(self) -> bool {
match self {
crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
crate::ImageClass::Storage { .. } => false,
}
}
pub const fn is_mipmapped(self) -> bool {
match self {
crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
crate::ImageClass::Storage { .. } => false,
}
}
}
impl crate::Module {
pub const fn to_ctx(&self) -> GlobalCtx<'_> {
GlobalCtx {
types: &self.types,
constants: &self.constants,
const_expressions: &self.const_expressions,
}
}
}
#[derive(Debug)]
pub(super) enum U32EvalError {
NonConst,
Negative,
}
#[derive(Clone, Copy)]
pub struct GlobalCtx<'a> {
pub types: &'a crate::UniqueArena<crate::Type>,
pub constants: &'a crate::Arena<crate::Constant>,
pub const_expressions: &'a crate::Arena<crate::Expression>,
}
impl GlobalCtx<'_> {
/// Try to evaluate the expression in `self.const_expressions` using its `handle` and return it as a `u32`.
#[allow(dead_code)]
pub(super) fn eval_expr_to_u32(
&self,
handle: crate::Handle<crate::Expression>,
) -> Result<u32, U32EvalError> {
self.eval_expr_to_u32_from(handle, self.const_expressions)
}
/// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`.
pub(super) fn eval_expr_to_u32_from(
&self,
handle: crate::Handle<crate::Expression>,
arena: &crate::Arena<crate::Expression>,
) -> Result<u32, U32EvalError> {
fn get(
gctx: GlobalCtx,
handle: crate::Handle<crate::Expression>,
arena: &crate::Arena<crate::Expression>,
) -> Result<u32, U32EvalError> {
match arena[handle] {
crate::Expression::Literal(crate::Literal::U32(value)) => Ok(value),
crate::Expression::Literal(crate::Literal::I32(value)) => {
value.try_into().map_err(|_| U32EvalError::Negative)
}
crate::Expression::ZeroValue(ty)
if matches!(
gctx.types[ty].inner,
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: _
}
) =>
{
Ok(0)
}
_ => Err(U32EvalError::NonConst),
}
}
match arena[handle] {
crate::Expression::Constant(c) => {
get(*self, self.constants[c].init, self.const_expressions)
}
_ => get(*self, handle, arena),
}
}
}
#[test]
fn test_matrix_size() {
let module = crate::Module::default();
assert_eq!(
crate::TypeInner::Matrix {
columns: crate::VectorSize::Tri,
rows: crate::VectorSize::Tri,
width: 4
}
.size(module.to_ctx()),
48,
);
}

View file

@ -0,0 +1,271 @@
use crate::{arena::Handle, FastHashMap, FastHashSet};
use std::borrow::Cow;
use std::hash::{Hash, Hasher};
pub type EntryPointIndex = u16;
const SEPARATOR: char = '_';
#[derive(Debug, Eq, Hash, PartialEq)]
pub enum NameKey {
Constant(Handle<crate::Constant>),
GlobalVariable(Handle<crate::GlobalVariable>),
Type(Handle<crate::Type>),
StructMember(Handle<crate::Type>, u32),
Function(Handle<crate::Function>),
FunctionArgument(Handle<crate::Function>, u32),
FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>),
EntryPoint(EntryPointIndex),
EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>),
EntryPointArgument(EntryPointIndex, u32),
}
/// This processor assigns names to all the things in a module
/// that may need identifiers in a textual backend.
#[derive(Default)]
pub struct Namer {
/// The last numeric suffix used for each base name. Zero means "no suffix".
unique: FastHashMap<String, u32>,
keywords: FastHashSet<&'static str>,
keywords_case_insensitive: FastHashSet<AsciiUniCase<&'static str>>,
reserved_prefixes: Vec<&'static str>,
}
impl Namer {
/// Return a form of `string` suitable for use as the base of an identifier.
///
/// - Drop leading digits.
/// - Retain only alphanumeric and `_` characters.
/// - Avoid prefixes in [`Namer::reserved_prefixes`].
///
/// The return value is a valid identifier prefix in all of Naga's output languages,
/// and it never ends with a `SEPARATOR` character.
/// It is used as a key into the unique table.
fn sanitize<'s>(&self, string: &'s str) -> Cow<'s, str> {
let string = string
.trim_start_matches(|c: char| c.is_numeric())
.trim_end_matches(SEPARATOR);
let base = if !string.is_empty()
&& string
.chars()
.all(|c: char| c.is_ascii_alphanumeric() || c == '_')
{
Cow::Borrowed(string)
} else {
let mut filtered = string
.chars()
.filter(|&c| c.is_ascii_alphanumeric() || c == '_')
.collect::<String>();
let stripped_len = filtered.trim_end_matches(SEPARATOR).len();
filtered.truncate(stripped_len);
if filtered.is_empty() {
filtered.push_str("unnamed");
}
Cow::Owned(filtered)
};
for prefix in &self.reserved_prefixes {
if base.starts_with(prefix) {
return format!("gen_{base}").into();
}
}
base
}
/// Return a new identifier based on `label_raw`.
///
/// The result:
/// - is a valid identifier even if `label_raw` is not
/// - conflicts with no keywords listed in `Namer::keywords`, and
/// - is different from any identifier previously constructed by this
/// `Namer`.
///
/// Guarantee uniqueness by applying a numeric suffix when necessary. If `label_raw`
/// itself ends with digits, separate them from the suffix with an underscore.
pub fn call(&mut self, label_raw: &str) -> String {
use std::fmt::Write as _; // for write!-ing to Strings
let base = self.sanitize(label_raw);
debug_assert!(!base.is_empty() && !base.ends_with(SEPARATOR));
// This would seem to be a natural place to use `HashMap::entry`. However, `entry`
// requires an owned key, and we'd like to avoid heap-allocating strings we're
// just going to throw away. The approach below double-hashes only when we create
// a new entry, in which case the heap allocation of the owned key was more
// expensive anyway.
match self.unique.get_mut(base.as_ref()) {
Some(count) => {
*count += 1;
// Add the suffix. This may fit in base's existing allocation.
let mut suffixed = base.into_owned();
write!(suffixed, "{}{}", SEPARATOR, *count).unwrap();
suffixed
}
None => {
let mut suffixed = base.to_string();
if base.ends_with(char::is_numeric)
|| self.keywords.contains(base.as_ref())
|| self
.keywords_case_insensitive
.contains(&AsciiUniCase(base.as_ref()))
{
suffixed.push(SEPARATOR);
}
debug_assert!(!self.keywords.contains::<str>(&suffixed));
// `self.unique` wants to own its keys. This allocates only if we haven't
// already done so earlier.
self.unique.insert(base.into_owned(), 0);
suffixed
}
}
}
pub fn call_or(&mut self, label: &Option<String>, fallback: &str) -> String {
self.call(match *label {
Some(ref name) => name,
None => fallback,
})
}
/// Enter a local namespace for things like structs.
///
/// Struct member names only need to be unique amongst themselves, not
/// globally. This function temporarily establishes a fresh, empty naming
/// context for the duration of the call to `body`.
fn namespace(&mut self, capacity: usize, body: impl FnOnce(&mut Self)) {
let fresh = FastHashMap::with_capacity_and_hasher(capacity, Default::default());
let outer = std::mem::replace(&mut self.unique, fresh);
body(self);
self.unique = outer;
}
pub fn reset(
&mut self,
module: &crate::Module,
reserved_keywords: &[&'static str],
extra_reserved_keywords: &[&'static str],
reserved_keywords_case_insensitive: &[&'static str],
reserved_prefixes: &[&'static str],
output: &mut FastHashMap<NameKey, String>,
) {
self.reserved_prefixes.clear();
self.reserved_prefixes.extend(reserved_prefixes.iter());
self.unique.clear();
self.keywords.clear();
self.keywords.extend(reserved_keywords.iter());
self.keywords.extend(extra_reserved_keywords.iter());
debug_assert!(reserved_keywords_case_insensitive
.iter()
.all(|s| s.is_ascii()));
self.keywords_case_insensitive.clear();
self.keywords_case_insensitive.extend(
reserved_keywords_case_insensitive
.iter()
.map(|string| (AsciiUniCase(*string))),
);
let mut temp = String::new();
for (ty_handle, ty) in module.types.iter() {
let ty_name = self.call_or(&ty.name, "type");
output.insert(NameKey::Type(ty_handle), ty_name);
if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
// struct members have their own namespace, because access is always prefixed
self.namespace(members.len(), |namer| {
for (index, member) in members.iter().enumerate() {
let name = namer.call_or(&member.name, "member");
output.insert(NameKey::StructMember(ty_handle, index as u32), name);
}
})
}
}
for (ep_index, ep) in module.entry_points.iter().enumerate() {
let ep_name = self.call(&ep.name);
output.insert(NameKey::EntryPoint(ep_index as _), ep_name);
for (index, arg) in ep.function.arguments.iter().enumerate() {
let name = self.call_or(&arg.name, "param");
output.insert(
NameKey::EntryPointArgument(ep_index as _, index as u32),
name,
);
}
for (handle, var) in ep.function.local_variables.iter() {
let name = self.call_or(&var.name, "local");
output.insert(NameKey::EntryPointLocal(ep_index as _, handle), name);
}
}
for (fun_handle, fun) in module.functions.iter() {
let fun_name = self.call_or(&fun.name, "function");
output.insert(NameKey::Function(fun_handle), fun_name);
for (index, arg) in fun.arguments.iter().enumerate() {
let name = self.call_or(&arg.name, "param");
output.insert(NameKey::FunctionArgument(fun_handle, index as u32), name);
}
for (handle, var) in fun.local_variables.iter() {
let name = self.call_or(&var.name, "local");
output.insert(NameKey::FunctionLocal(fun_handle, handle), name);
}
}
for (handle, var) in module.global_variables.iter() {
let name = self.call_or(&var.name, "global");
output.insert(NameKey::GlobalVariable(handle), name);
}
for (handle, constant) in module.constants.iter() {
let label = match constant.name {
Some(ref name) => name,
None => {
use std::fmt::Write;
// Try to be more descriptive about the constant values
temp.clear();
write!(temp, "const_{}", output[&NameKey::Type(constant.ty)]).unwrap();
&temp
}
};
let name = self.call(label);
output.insert(NameKey::Constant(handle), name);
}
}
}
/// A string wrapper type with an ascii case insensitive Eq and Hash impl
struct AsciiUniCase<S: AsRef<str> + ?Sized>(S);
impl<S: AsRef<str>> PartialEq<Self> for AsciiUniCase<S> {
#[inline]
fn eq(&self, other: &Self) -> bool {
self.0.as_ref().eq_ignore_ascii_case(other.0.as_ref())
}
}
impl<S: AsRef<str>> Eq for AsciiUniCase<S> {}
impl<S: AsRef<str>> Hash for AsciiUniCase<S> {
#[inline]
fn hash<H: Hasher>(&self, hasher: &mut H) {
for byte in self
.0
.as_ref()
.as_bytes()
.iter()
.map(|b| b.to_ascii_lowercase())
{
hasher.write_u8(byte);
}
}
}
#[test]
fn test() {
let mut namer = Namer::default();
assert_eq!(namer.call("x"), "x");
assert_eq!(namer.call("x"), "x_1");
assert_eq!(namer.call("x1"), "x1_");
}

View file

@ -0,0 +1,44 @@
/// Ensure that the given block has return statements
/// at the end of its control flow.
///
/// Note: we don't want to blindly append a return statement
/// to the end, because it may be either redundant or invalid,
/// e.g. when the user already has returns in if/else branches.
pub fn ensure_block_returns(block: &mut crate::Block) {
use crate::Statement as S;
match block.last_mut() {
Some(&mut S::Block(ref mut b)) => {
ensure_block_returns(b);
}
Some(&mut S::If {
condition: _,
ref mut accept,
ref mut reject,
}) => {
ensure_block_returns(accept);
ensure_block_returns(reject);
}
Some(&mut S::Switch {
selector: _,
ref mut cases,
}) => {
for case in cases.iter_mut() {
if !case.fall_through {
ensure_block_returns(&mut case.body);
}
}
}
Some(&mut (S::Emit(_) | S::Break | S::Continue | S::Return { .. } | S::Kill)) => (),
Some(
&mut (S::Loop { .. }
| S::Store { .. }
| S::ImageStore { .. }
| S::Call { .. }
| S::RayQuery { .. }
| S::Atomic { .. }
| S::WorkGroupUniformLoad { .. }
| S::Barrier(_)),
)
| None => block.push(S::Return { value: None }, Default::default()),
}
}

View file

@ -0,0 +1,894 @@
use crate::arena::{Arena, Handle, UniqueArena};
use thiserror::Error;
/// The result of computing an expression's type.
///
/// This is the (Rust) type returned by [`ResolveContext::resolve`] to represent
/// the (Naga) type it ascribes to some expression.
///
/// You might expect such a function to simply return a `Handle<Type>`. However,
/// we want type resolution to be a read-only process, and that would limit the
/// possible results to types already present in the expression's associated
/// `UniqueArena<Type>`. Naga IR does have certain expressions whose types are
/// not certain to be present.
///
/// So instead, type resolution returns a `TypeResolution` enum: either a
/// [`Handle`], referencing some type in the arena, or a [`Value`], holding a
/// free-floating [`TypeInner`]. This extends the range to cover anything that
/// can be represented with a `TypeInner` referring to the existing arena.
///
/// What sorts of expressions can have types not available in the arena?
///
/// - An [`Access`] or [`AccessIndex`] expression applied to a [`Vector`] or
/// [`Matrix`] must have a [`Scalar`] or [`Vector`] type. But since `Vector`
/// and `Matrix` represent their element and column types implicitly, not
/// via a handle, there may not be a suitable type in the expression's
/// associated arena. Instead, resolving such an expression returns a
/// `TypeResolution::Value(TypeInner::X { ... })`, where `X` is `Scalar` or
/// `Vector`.
///
/// - Similarly, the type of an [`Access`] or [`AccessIndex`] expression
/// applied to a *pointer to* a vector or matrix must produce a *pointer to*
/// a scalar or vector type. These cannot be represented with a
/// [`TypeInner::Pointer`], since the `Pointer`'s `base` must point into the
/// arena, and as before, we cannot assume that a suitable scalar or vector
/// type is there. So we take things one step further and provide
/// [`TypeInner::ValuePointer`], specifically for the case of pointers to
/// scalars or vectors. This type fits in a `TypeInner` and is exactly
/// equivalent to a `Pointer` to a `Vector` or `Scalar`.
///
/// So, for example, the type of an `Access` expression applied to a value of type:
///
/// ```ignore
/// TypeInner::Matrix { columns, rows, width }
/// ```
///
/// might be:
///
/// ```ignore
/// TypeResolution::Value(TypeInner::Vector {
/// size: rows,
/// kind: ScalarKind::Float,
/// width,
/// })
/// ```
///
/// and the type of an access to a pointer of address space `space` to such a
/// matrix might be:
///
/// ```ignore
/// TypeResolution::Value(TypeInner::ValuePointer {
/// size: Some(rows),
/// kind: ScalarKind::Float,
/// width,
/// space,
/// })
/// ```
///
/// [`Handle`]: TypeResolution::Handle
/// [`Value`]: TypeResolution::Value
///
/// [`Access`]: crate::Expression::Access
/// [`AccessIndex`]: crate::Expression::AccessIndex
///
/// [`TypeInner`]: crate::TypeInner
/// [`Matrix`]: crate::TypeInner::Matrix
/// [`Pointer`]: crate::TypeInner::Pointer
/// [`Scalar`]: crate::TypeInner::Scalar
/// [`ValuePointer`]: crate::TypeInner::ValuePointer
/// [`Vector`]: crate::TypeInner::Vector
///
/// [`TypeInner::Pointer`]: crate::TypeInner::Pointer
/// [`TypeInner::ValuePointer`]: crate::TypeInner::ValuePointer
#[derive(Debug, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum TypeResolution {
/// A type stored in the associated arena.
Handle(Handle<crate::Type>),
/// A free-floating [`TypeInner`], representing a type that may not be
/// available in the associated arena. However, the `TypeInner` itself may
/// contain `Handle<Type>` values referring to types from the arena.
///
/// [`TypeInner`]: crate::TypeInner
Value(crate::TypeInner),
}
impl TypeResolution {
pub const fn handle(&self) -> Option<Handle<crate::Type>> {
match *self {
Self::Handle(handle) => Some(handle),
Self::Value(_) => None,
}
}
pub fn inner_with<'a>(&'a self, arena: &'a UniqueArena<crate::Type>) -> &'a crate::TypeInner {
match *self {
Self::Handle(handle) => &arena[handle].inner,
Self::Value(ref inner) => inner,
}
}
}
// Clone is only implemented for numeric variants of `TypeInner`.
impl Clone for TypeResolution {
fn clone(&self) -> Self {
use crate::TypeInner as Ti;
match *self {
Self::Handle(handle) => Self::Handle(handle),
Self::Value(ref v) => Self::Value(match *v {
Ti::Scalar { kind, width } => Ti::Scalar { kind, width },
Ti::Vector { size, kind, width } => Ti::Vector { size, kind, width },
Ti::Matrix {
rows,
columns,
width,
} => Ti::Matrix {
rows,
columns,
width,
},
Ti::Pointer { base, space } => Ti::Pointer { base, space },
Ti::ValuePointer {
size,
kind,
width,
space,
} => Ti::ValuePointer {
size,
kind,
width,
space,
},
_ => unreachable!("Unexpected clone type: {:?}", v),
}),
}
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ResolveError {
#[error("Index {index} is out of bounds for expression {expr:?}")]
OutOfBoundsIndex {
expr: Handle<crate::Expression>,
index: u32,
},
#[error("Invalid access into expression {expr:?}, indexed: {indexed}")]
InvalidAccess {
expr: Handle<crate::Expression>,
indexed: bool,
},
#[error("Invalid sub-access into type {ty:?}, indexed: {indexed}")]
InvalidSubAccess {
ty: Handle<crate::Type>,
indexed: bool,
},
#[error("Invalid scalar {0:?}")]
InvalidScalar(Handle<crate::Expression>),
#[error("Invalid vector {0:?}")]
InvalidVector(Handle<crate::Expression>),
#[error("Invalid pointer {0:?}")]
InvalidPointer(Handle<crate::Expression>),
#[error("Invalid image {0:?}")]
InvalidImage(Handle<crate::Expression>),
#[error("Function {name} not defined")]
FunctionNotDefined { name: String },
#[error("Function without return type")]
FunctionReturnsVoid,
#[error("Incompatible operands: {0}")]
IncompatibleOperands(String),
#[error("Function argument {0} doesn't exist")]
FunctionArgumentNotFound(u32),
#[error("Special type is not registered within the module")]
MissingSpecialType,
}
pub struct ResolveContext<'a> {
pub constants: &'a Arena<crate::Constant>,
pub types: &'a UniqueArena<crate::Type>,
pub special_types: &'a crate::SpecialTypes,
pub global_vars: &'a Arena<crate::GlobalVariable>,
pub local_vars: &'a Arena<crate::LocalVariable>,
pub functions: &'a Arena<crate::Function>,
pub arguments: &'a [crate::FunctionArgument],
}
impl<'a> ResolveContext<'a> {
/// Initialize a resolve context from the module.
pub const fn with_locals(
module: &'a crate::Module,
local_vars: &'a Arena<crate::LocalVariable>,
arguments: &'a [crate::FunctionArgument],
) -> Self {
Self {
constants: &module.constants,
types: &module.types,
special_types: &module.special_types,
global_vars: &module.global_variables,
local_vars,
functions: &module.functions,
arguments,
}
}
/// Determine the type of `expr`.
///
/// The `past` argument must be a closure that can resolve the types of any
/// expressions that `expr` refers to. These can be gathered by caching the
/// results of prior calls to `resolve`, perhaps as done by the
/// [`front::Typifier`] utility type.
///
/// Type resolution is a read-only process: this method takes `self` by
/// shared reference. However, this means that we cannot add anything to
/// `self.types` that we might need to describe `expr`. To work around this,
/// this method returns a [`TypeResolution`], rather than simply returning a
/// `Handle<Type>`; see the documentation for [`TypeResolution`] for
/// details.
///
/// [`front::Typifier`]: crate::front::Typifier
pub fn resolve(
&self,
expr: &crate::Expression,
past: impl Fn(Handle<crate::Expression>) -> Result<&'a TypeResolution, ResolveError>,
) -> Result<TypeResolution, ResolveError> {
use crate::TypeInner as Ti;
let types = self.types;
Ok(match *expr {
crate::Expression::Access { base, .. } => match *past(base)?.inner_with(types) {
// Arrays and matrices can only be indexed dynamically behind a
// pointer, but that's a validation error, not a type error, so
// go ahead provide a type here.
Ti::Array { base, .. } => TypeResolution::Handle(base),
Ti::Matrix { rows, width, .. } => TypeResolution::Value(Ti::Vector {
size: rows,
kind: crate::ScalarKind::Float,
width,
}),
Ti::Vector {
size: _,
kind,
width,
} => TypeResolution::Value(Ti::Scalar { kind, width }),
Ti::ValuePointer {
size: Some(_),
kind,
width,
space,
} => TypeResolution::Value(Ti::ValuePointer {
size: None,
kind,
width,
space,
}),
Ti::Pointer { base, space } => {
TypeResolution::Value(match types[base].inner {
Ti::Array { base, .. } => Ti::Pointer { base, space },
Ti::Vector {
size: _,
kind,
width,
} => Ti::ValuePointer {
size: None,
kind,
width,
space,
},
// Matrices are only dynamically indexed behind a pointer
Ti::Matrix {
columns: _,
rows,
width,
} => Ti::ValuePointer {
kind: crate::ScalarKind::Float,
size: Some(rows),
width,
space,
},
Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
ref other => {
log::error!("Access sub-type {:?}", other);
return Err(ResolveError::InvalidSubAccess {
ty: base,
indexed: false,
});
}
})
}
Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
ref other => {
log::error!("Access type {:?}", other);
return Err(ResolveError::InvalidAccess {
expr: base,
indexed: false,
});
}
},
crate::Expression::AccessIndex { base, index } => {
match *past(base)?.inner_with(types) {
Ti::Vector { size, kind, width } => {
if index >= size as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
TypeResolution::Value(Ti::Scalar { kind, width })
}
Ti::Matrix {
columns,
rows,
width,
} => {
if index >= columns as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
TypeResolution::Value(crate::TypeInner::Vector {
size: rows,
kind: crate::ScalarKind::Float,
width,
})
}
Ti::Array { base, .. } => TypeResolution::Handle(base),
Ti::Struct { ref members, .. } => {
let member = members
.get(index as usize)
.ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
TypeResolution::Handle(member.ty)
}
Ti::ValuePointer {
size: Some(size),
kind,
width,
space,
} => {
if index >= size as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
TypeResolution::Value(Ti::ValuePointer {
size: None,
kind,
width,
space,
})
}
Ti::Pointer {
base: ty_base,
space,
} => TypeResolution::Value(match types[ty_base].inner {
Ti::Array { base, .. } => Ti::Pointer { base, space },
Ti::Vector { size, kind, width } => {
if index >= size as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
Ti::ValuePointer {
size: None,
kind,
width,
space,
}
}
Ti::Matrix {
rows,
columns,
width,
} => {
if index >= columns as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
Ti::ValuePointer {
size: Some(rows),
kind: crate::ScalarKind::Float,
width,
space,
}
}
Ti::Struct { ref members, .. } => {
let member = members
.get(index as usize)
.ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
Ti::Pointer {
base: member.ty,
space,
}
}
Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
ref other => {
log::error!("Access index sub-type {:?}", other);
return Err(ResolveError::InvalidSubAccess {
ty: ty_base,
indexed: true,
});
}
}),
Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
ref other => {
log::error!("Access index type {:?}", other);
return Err(ResolveError::InvalidAccess {
expr: base,
indexed: true,
});
}
}
}
crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) {
Ti::Scalar { kind, width } => {
TypeResolution::Value(Ti::Vector { size, kind, width })
}
ref other => {
log::error!("Scalar type {:?}", other);
return Err(ResolveError::InvalidScalar(value));
}
},
crate::Expression::Swizzle {
size,
vector,
pattern: _,
} => match *past(vector)?.inner_with(types) {
Ti::Vector {
size: _,
kind,
width,
} => TypeResolution::Value(Ti::Vector { size, kind, width }),
ref other => {
log::error!("Vector type {:?}", other);
return Err(ResolveError::InvalidVector(vector));
}
},
crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()),
crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty),
crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty),
crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty),
crate::Expression::FunctionArgument(index) => {
let arg = self
.arguments
.get(index as usize)
.ok_or(ResolveError::FunctionArgumentNotFound(index))?;
TypeResolution::Handle(arg.ty)
}
crate::Expression::GlobalVariable(h) => {
let var = &self.global_vars[h];
if var.space == crate::AddressSpace::Handle {
TypeResolution::Handle(var.ty)
} else {
TypeResolution::Value(Ti::Pointer {
base: var.ty,
space: var.space,
})
}
}
crate::Expression::LocalVariable(h) => {
let var = &self.local_vars[h];
TypeResolution::Value(Ti::Pointer {
base: var.ty,
space: crate::AddressSpace::Function,
})
}
crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) {
Ti::Pointer { base, space: _ } => {
if let Ti::Atomic { kind, width } = types[base].inner {
TypeResolution::Value(Ti::Scalar { kind, width })
} else {
TypeResolution::Handle(base)
}
}
Ti::ValuePointer {
size,
kind,
width,
space: _,
} => TypeResolution::Value(match size {
Some(size) => Ti::Vector { size, kind, width },
None => Ti::Scalar { kind, width },
}),
ref other => {
log::error!("Pointer type {:?}", other);
return Err(ResolveError::InvalidPointer(pointer));
}
},
crate::Expression::ImageSample {
image,
gather: Some(_),
..
} => match *past(image)?.inner_with(types) {
Ti::Image { class, .. } => TypeResolution::Value(Ti::Vector {
kind: match class {
crate::ImageClass::Sampled { kind, multi: _ } => kind,
_ => crate::ScalarKind::Float,
},
width: 4,
size: crate::VectorSize::Quad,
}),
ref other => {
log::error!("Image type {:?}", other);
return Err(ResolveError::InvalidImage(image));
}
},
crate::Expression::ImageSample { image, .. }
| crate::Expression::ImageLoad { image, .. } => match *past(image)?.inner_with(types) {
Ti::Image { class, .. } => TypeResolution::Value(match class {
crate::ImageClass::Depth { multi: _ } => Ti::Scalar {
kind: crate::ScalarKind::Float,
width: 4,
},
crate::ImageClass::Sampled { kind, multi: _ } => Ti::Vector {
kind,
width: 4,
size: crate::VectorSize::Quad,
},
crate::ImageClass::Storage { format, .. } => Ti::Vector {
kind: format.into(),
width: 4,
size: crate::VectorSize::Quad,
},
}),
ref other => {
log::error!("Image type {:?}", other);
return Err(ResolveError::InvalidImage(image));
}
},
crate::Expression::ImageQuery { image, query } => TypeResolution::Value(match query {
crate::ImageQuery::Size { level: _ } => match *past(image)?.inner_with(types) {
Ti::Image { dim, .. } => match dim {
crate::ImageDimension::D1 => Ti::Scalar {
kind: crate::ScalarKind::Uint,
width: 4,
},
crate::ImageDimension::D2 | crate::ImageDimension::Cube => Ti::Vector {
size: crate::VectorSize::Bi,
kind: crate::ScalarKind::Uint,
width: 4,
},
crate::ImageDimension::D3 => Ti::Vector {
size: crate::VectorSize::Tri,
kind: crate::ScalarKind::Uint,
width: 4,
},
},
ref other => {
log::error!("Image type {:?}", other);
return Err(ResolveError::InvalidImage(image));
}
},
crate::ImageQuery::NumLevels
| crate::ImageQuery::NumLayers
| crate::ImageQuery::NumSamples => Ti::Scalar {
kind: crate::ScalarKind::Uint,
width: 4,
},
}),
crate::Expression::Unary { expr, .. } => past(expr)?.clone(),
crate::Expression::Binary { op, left, right } => match op {
crate::BinaryOperator::Add
| crate::BinaryOperator::Subtract
| crate::BinaryOperator::Divide
| crate::BinaryOperator::Modulo => past(left)?.clone(),
crate::BinaryOperator::Multiply => {
let (res_left, res_right) = (past(left)?, past(right)?);
match (res_left.inner_with(types), res_right.inner_with(types)) {
(
&Ti::Matrix {
columns: _,
rows,
width,
},
&Ti::Matrix { columns, .. },
) => TypeResolution::Value(Ti::Matrix {
columns,
rows,
width,
}),
(
&Ti::Matrix {
columns: _,
rows,
width,
},
&Ti::Vector { .. },
) => TypeResolution::Value(Ti::Vector {
size: rows,
kind: crate::ScalarKind::Float,
width,
}),
(
&Ti::Vector { .. },
&Ti::Matrix {
columns,
rows: _,
width,
},
) => TypeResolution::Value(Ti::Vector {
size: columns,
kind: crate::ScalarKind::Float,
width,
}),
(&Ti::Scalar { .. }, _) => res_right.clone(),
(_, &Ti::Scalar { .. }) => res_left.clone(),
(&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(),
(tl, tr) => {
return Err(ResolveError::IncompatibleOperands(format!(
"{tl:?} * {tr:?}"
)))
}
}
}
crate::BinaryOperator::Equal
| crate::BinaryOperator::NotEqual
| crate::BinaryOperator::Less
| crate::BinaryOperator::LessEqual
| crate::BinaryOperator::Greater
| crate::BinaryOperator::GreaterEqual
| crate::BinaryOperator::LogicalAnd
| crate::BinaryOperator::LogicalOr => {
let kind = crate::ScalarKind::Bool;
let width = crate::BOOL_WIDTH;
let inner = match *past(left)?.inner_with(types) {
Ti::Scalar { .. } => Ti::Scalar { kind, width },
Ti::Vector { size, .. } => Ti::Vector { size, kind, width },
ref other => {
return Err(ResolveError::IncompatibleOperands(format!(
"{op:?}({other:?}, _)"
)))
}
};
TypeResolution::Value(inner)
}
crate::BinaryOperator::And
| crate::BinaryOperator::ExclusiveOr
| crate::BinaryOperator::InclusiveOr
| crate::BinaryOperator::ShiftLeft
| crate::BinaryOperator::ShiftRight => past(left)?.clone(),
},
crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty),
crate::Expression::Select { accept, .. } => past(accept)?.clone(),
crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
crate::Expression::Relational { fun, argument } => match fun {
crate::RelationalFunction::All | crate::RelationalFunction::Any => {
TypeResolution::Value(Ti::Scalar {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
})
}
crate::RelationalFunction::IsNan
| crate::RelationalFunction::IsInf
| crate::RelationalFunction::IsFinite
| crate::RelationalFunction::IsNormal => match *past(argument)?.inner_with(types) {
Ti::Scalar { .. } => TypeResolution::Value(Ti::Scalar {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
}),
Ti::Vector { size, .. } => TypeResolution::Value(Ti::Vector {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
size,
}),
ref other => {
return Err(ResolveError::IncompatibleOperands(format!(
"{fun:?}({other:?})"
)))
}
},
},
crate::Expression::Math {
fun,
arg,
arg1,
arg2: _,
arg3: _,
} => {
use crate::MathFunction as Mf;
let res_arg = past(arg)?;
match fun {
// comparison
Mf::Abs |
Mf::Min |
Mf::Max |
Mf::Clamp |
Mf::Saturate |
// trigonometry
Mf::Cos |
Mf::Cosh |
Mf::Sin |
Mf::Sinh |
Mf::Tan |
Mf::Tanh |
Mf::Acos |
Mf::Asin |
Mf::Atan |
Mf::Atan2 |
Mf::Asinh |
Mf::Acosh |
Mf::Atanh |
Mf::Radians |
Mf::Degrees |
// decomposition
Mf::Ceil |
Mf::Floor |
Mf::Round |
Mf::Fract |
Mf::Trunc |
Mf::Modf |
Mf::Frexp |
Mf::Ldexp |
// exponent
Mf::Exp |
Mf::Exp2 |
Mf::Log |
Mf::Log2 |
Mf::Pow => res_arg.clone(),
// geometry
Mf::Dot => match *res_arg.inner_with(types) {
Ti::Vector {
kind,
size: _,
width,
} => TypeResolution::Value(Ti::Scalar { kind, width }),
ref other =>
return Err(ResolveError::IncompatibleOperands(
format!("{fun:?}({other:?}, _)")
)),
},
Mf::Outer => {
let arg1 = arg1.ok_or_else(|| ResolveError::IncompatibleOperands(
format!("{fun:?}(_, None)")
))?;
match (res_arg.inner_with(types), past(arg1)?.inner_with(types)) {
(&Ti::Vector {kind: _, size: columns,width}, &Ti::Vector{ size: rows, .. }) => TypeResolution::Value(Ti::Matrix { columns, rows, width }),
(left, right) =>
return Err(ResolveError::IncompatibleOperands(
format!("{fun:?}({left:?}, {right:?})")
)),
}
},
Mf::Cross => res_arg.clone(),
Mf::Distance |
Mf::Length => match *res_arg.inner_with(types) {
Ti::Scalar {width,kind} |
Ti::Vector {width,kind,size:_} => TypeResolution::Value(Ti::Scalar { kind, width }),
ref other => return Err(ResolveError::IncompatibleOperands(
format!("{fun:?}({other:?})")
)),
},
Mf::Normalize |
Mf::FaceForward |
Mf::Reflect |
Mf::Refract => res_arg.clone(),
// computational
Mf::Sign |
Mf::Fma |
Mf::Mix |
Mf::Step |
Mf::SmoothStep |
Mf::Sqrt |
Mf::InverseSqrt => res_arg.clone(),
Mf::Transpose => match *res_arg.inner_with(types) {
Ti::Matrix {
columns,
rows,
width,
} => TypeResolution::Value(Ti::Matrix {
columns: rows,
rows: columns,
width,
}),
ref other => return Err(ResolveError::IncompatibleOperands(
format!("{fun:?}({other:?})")
)),
},
Mf::Inverse => match *res_arg.inner_with(types) {
Ti::Matrix {
columns,
rows,
width,
} if columns == rows => TypeResolution::Value(Ti::Matrix {
columns,
rows,
width,
}),
ref other => return Err(ResolveError::IncompatibleOperands(
format!("{fun:?}({other:?})")
)),
},
Mf::Determinant => match *res_arg.inner_with(types) {
Ti::Matrix {
width,
..
} => TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Float, width }),
ref other => return Err(ResolveError::IncompatibleOperands(
format!("{fun:?}({other:?})")
)),
},
// bits
Mf::CountTrailingZeros |
Mf::CountLeadingZeros |
Mf::CountOneBits |
Mf::ReverseBits |
Mf::ExtractBits |
Mf::InsertBits |
Mf::FindLsb |
Mf::FindMsb => match *res_arg.inner_with(types) {
Ti::Scalar { kind: kind @ (crate::ScalarKind::Sint | crate::ScalarKind::Uint), width } =>
TypeResolution::Value(Ti::Scalar { kind, width }),
Ti::Vector { size, kind: kind @ (crate::ScalarKind::Sint | crate::ScalarKind::Uint), width } =>
TypeResolution::Value(Ti::Vector { size, kind, width }),
ref other => return Err(ResolveError::IncompatibleOperands(
format!("{fun:?}({other:?})")
)),
},
// data packing
Mf::Pack4x8snorm |
Mf::Pack4x8unorm |
Mf::Pack2x16snorm |
Mf::Pack2x16unorm |
Mf::Pack2x16float => TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Uint, width: 4 }),
// data unpacking
Mf::Unpack4x8snorm |
Mf::Unpack4x8unorm => TypeResolution::Value(Ti::Vector { size: crate::VectorSize::Quad, kind: crate::ScalarKind::Float, width: 4 }),
Mf::Unpack2x16snorm |
Mf::Unpack2x16unorm |
Mf::Unpack2x16float => TypeResolution::Value(Ti::Vector { size: crate::VectorSize::Bi, kind: crate::ScalarKind::Float, width: 4 }),
}
}
crate::Expression::As {
expr,
kind,
convert,
} => match *past(expr)?.inner_with(types) {
Ti::Scalar { kind: _, width } => TypeResolution::Value(Ti::Scalar {
kind,
width: convert.unwrap_or(width),
}),
Ti::Vector {
kind: _,
size,
width,
} => TypeResolution::Value(Ti::Vector {
kind,
size,
width: convert.unwrap_or(width),
}),
Ti::Matrix {
columns,
rows,
width,
} => TypeResolution::Value(Ti::Matrix {
columns,
rows,
width: convert.unwrap_or(width),
}),
ref other => {
return Err(ResolveError::IncompatibleOperands(format!(
"{other:?} as {kind:?}"
)))
}
},
crate::Expression::CallResult(function) => {
let result = self.functions[function]
.result
.as_ref()
.ok_or(ResolveError::FunctionReturnsVoid)?;
TypeResolution::Handle(result.ty)
}
crate::Expression::ArrayLength(_) => TypeResolution::Value(Ti::Scalar {
kind: crate::ScalarKind::Uint,
width: 4,
}),
crate::Expression::RayQueryProceedResult => TypeResolution::Value(Ti::Scalar {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
}),
crate::Expression::RayQueryGetIntersection { .. } => {
let result = self
.special_types
.ray_intersection
.ok_or(ResolveError::MissingSpecialType)?;
TypeResolution::Handle(result)
}
})
}
}
#[test]
fn test_error_size() {
use std::mem::size_of;
assert_eq!(size_of::<ResolveError>(), 32);
}