Vendor things

This commit is contained in:
John Doty 2024-03-08 11:03:01 -08:00
parent 5deceec006
commit 977e3c17e5
19434 changed files with 10682014 additions and 0 deletions

View file

@ -0,0 +1,486 @@
use crate::front::wgsl::parse::number::Number;
use crate::{Arena, FastHashSet, Handle, Span};
use std::hash::Hash;
#[derive(Debug, Default)]
pub struct TranslationUnit<'a> {
pub decls: Arena<GlobalDecl<'a>>,
/// The common expressions arena for the entire translation unit.
///
/// All functions, global initializers, array lengths, etc. store their
/// expressions here. We apportion these out to individual Naga
/// [`Function`]s' expression arenas at lowering time. Keeping them all in a
/// single arena simplifies handling of things like array lengths (which are
/// effectively global and thus don't clearly belong to any function) and
/// initializers (which can appear in both function-local and module-scope
/// contexts).
///
/// [`Function`]: crate::Function
pub expressions: Arena<Expression<'a>>,
/// Non-user-defined types, like `vec4<f32>` or `array<i32, 10>`.
///
/// These are referred to by `Handle<ast::Type<'a>>` values.
/// User-defined types are referred to by name until lowering.
pub types: Arena<Type<'a>>,
}
#[derive(Debug, Clone, Copy)]
pub struct Ident<'a> {
pub name: &'a str,
pub span: Span,
}
#[derive(Debug)]
pub enum IdentExpr<'a> {
Unresolved(&'a str),
Local(Handle<Local>),
}
/// A reference to a module-scope definition or predeclared object.
///
/// Each [`GlobalDecl`] holds a set of these values, to be resolved to
/// specific definitions later. To support de-duplication, `Eq` and
/// `Hash` on a `Dependency` value consider only the name, not the
/// source location at which the reference occurs.
#[derive(Debug)]
pub struct Dependency<'a> {
/// The name referred to.
pub ident: &'a str,
/// The location at which the reference to that name occurs.
pub usage: Span,
}
impl Hash for Dependency<'_> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.ident.hash(state);
}
}
impl PartialEq for Dependency<'_> {
fn eq(&self, other: &Self) -> bool {
self.ident == other.ident
}
}
impl Eq for Dependency<'_> {}
/// A module-scope declaration.
#[derive(Debug)]
pub struct GlobalDecl<'a> {
pub kind: GlobalDeclKind<'a>,
/// Names of all module-scope or predeclared objects this
/// declaration uses.
pub dependencies: FastHashSet<Dependency<'a>>,
}
#[derive(Debug)]
pub enum GlobalDeclKind<'a> {
Fn(Function<'a>),
Var(GlobalVariable<'a>),
Const(Const<'a>),
Struct(Struct<'a>),
Type(TypeAlias<'a>),
}
#[derive(Debug)]
pub struct FunctionArgument<'a> {
pub name: Ident<'a>,
pub ty: Handle<Type<'a>>,
pub binding: Option<crate::Binding>,
pub handle: Handle<Local>,
}
#[derive(Debug)]
pub struct FunctionResult<'a> {
pub ty: Handle<Type<'a>>,
pub binding: Option<crate::Binding>,
}
#[derive(Debug)]
pub struct EntryPoint {
pub stage: crate::ShaderStage,
pub early_depth_test: Option<crate::EarlyDepthTest>,
pub workgroup_size: [u32; 3],
}
#[cfg(doc)]
use crate::front::wgsl::lower::{RuntimeExpressionContext, StatementContext};
#[derive(Debug)]
pub struct Function<'a> {
pub entry_point: Option<EntryPoint>,
pub name: Ident<'a>,
pub arguments: Vec<FunctionArgument<'a>>,
pub result: Option<FunctionResult<'a>>,
/// Local variable and function argument arena.
///
/// Note that the `Local` here is actually a zero-sized type. The AST keeps
/// all the detailed information about locals - names, types, etc. - in
/// [`LocalDecl`] statements. For arguments, that information is kept in
/// [`arguments`]. This `Arena`'s only role is to assign a unique `Handle`
/// to each of them, and track their definitions' spans for use in
/// diagnostics.
///
/// In the AST, when an [`Ident`] expression refers to a local variable or
/// argument, its [`IdentExpr`] holds the referent's `Handle<Local>` in this
/// arena.
///
/// During lowering, [`LocalDecl`] statements add entries to a per-function
/// table that maps `Handle<Local>` values to their Naga representations,
/// accessed via [`StatementContext::local_table`] and
/// [`RuntimeExpressionContext::local_table`]. This table is then consulted when
/// lowering subsequent [`Ident`] expressions.
///
/// [`LocalDecl`]: StatementKind::LocalDecl
/// [`arguments`]: Function::arguments
/// [`Ident`]: Expression::Ident
/// [`StatementContext::local_table`]: StatementContext::local_table
/// [`RuntimeExpressionContext::local_table`]: RuntimeExpressionContext::local_table
pub locals: Arena<Local>,
pub body: Block<'a>,
}
#[derive(Debug)]
pub struct GlobalVariable<'a> {
pub name: Ident<'a>,
pub space: crate::AddressSpace,
pub binding: Option<crate::ResourceBinding>,
pub ty: Handle<Type<'a>>,
pub init: Option<Handle<Expression<'a>>>,
}
#[derive(Debug)]
pub struct StructMember<'a> {
pub name: Ident<'a>,
pub ty: Handle<Type<'a>>,
pub binding: Option<crate::Binding>,
pub align: Option<(u32, Span)>,
pub size: Option<(u32, Span)>,
}
#[derive(Debug)]
pub struct Struct<'a> {
pub name: Ident<'a>,
pub members: Vec<StructMember<'a>>,
}
#[derive(Debug)]
pub struct TypeAlias<'a> {
pub name: Ident<'a>,
pub ty: Handle<Type<'a>>,
}
#[derive(Debug)]
pub struct Const<'a> {
pub name: Ident<'a>,
pub ty: Option<Handle<Type<'a>>>,
pub init: Handle<Expression<'a>>,
}
/// The size of an [`Array`] or [`BindingArray`].
///
/// [`Array`]: Type::Array
/// [`BindingArray`]: Type::BindingArray
#[derive(Debug, Copy, Clone)]
pub enum ArraySize<'a> {
/// The length as a constant expression.
Constant(Handle<Expression<'a>>),
Dynamic,
}
#[derive(Debug)]
pub enum Type<'a> {
Scalar {
kind: crate::ScalarKind,
width: crate::Bytes,
},
Vector {
size: crate::VectorSize,
kind: crate::ScalarKind,
width: crate::Bytes,
},
Matrix {
columns: crate::VectorSize,
rows: crate::VectorSize,
width: crate::Bytes,
},
Atomic {
kind: crate::ScalarKind,
width: crate::Bytes,
},
Pointer {
base: Handle<Type<'a>>,
space: crate::AddressSpace,
},
Array {
base: Handle<Type<'a>>,
size: ArraySize<'a>,
},
Image {
dim: crate::ImageDimension,
arrayed: bool,
class: crate::ImageClass,
},
Sampler {
comparison: bool,
},
AccelerationStructure,
RayQuery,
RayDesc,
RayIntersection,
BindingArray {
base: Handle<Type<'a>>,
size: ArraySize<'a>,
},
/// A user-defined type, like a struct or a type alias.
User(Ident<'a>),
}
#[derive(Debug, Default)]
pub struct Block<'a> {
pub stmts: Vec<Statement<'a>>,
}
#[derive(Debug)]
pub struct Statement<'a> {
pub kind: StatementKind<'a>,
pub span: Span,
}
#[derive(Debug)]
pub enum StatementKind<'a> {
LocalDecl(LocalDecl<'a>),
Block(Block<'a>),
If {
condition: Handle<Expression<'a>>,
accept: Block<'a>,
reject: Block<'a>,
},
Switch {
selector: Handle<Expression<'a>>,
cases: Vec<SwitchCase<'a>>,
},
Loop {
body: Block<'a>,
continuing: Block<'a>,
break_if: Option<Handle<Expression<'a>>>,
},
Break,
Continue,
Return {
value: Option<Handle<Expression<'a>>>,
},
Kill,
Call {
function: Ident<'a>,
arguments: Vec<Handle<Expression<'a>>>,
},
Assign {
target: Handle<Expression<'a>>,
op: Option<crate::BinaryOperator>,
value: Handle<Expression<'a>>,
},
Increment(Handle<Expression<'a>>),
Decrement(Handle<Expression<'a>>),
Ignore(Handle<Expression<'a>>),
}
#[derive(Debug)]
pub enum SwitchValue {
I32(i32),
U32(u32),
Default,
}
#[derive(Debug)]
pub struct SwitchCase<'a> {
pub value: SwitchValue,
pub value_span: Span,
pub body: Block<'a>,
pub fall_through: bool,
}
/// A type at the head of a [`Construct`] expression.
///
/// WGSL has two types of [`type constructor expressions`]:
///
/// - Those that fully specify the type being constructed, like
/// `vec3<f32>(x,y,z)`, which obviously constructs a `vec3<f32>`.
///
/// - Those that leave the component type of the composite being constructed
/// implicit, to be inferred from the argument types, like `vec3(x,y,z)`,
/// which constructs a `vec3<T>` where `T` is the type of `x`, `y`, and `z`.
///
/// This enum represents the head type of both cases. The `PartialFoo` variants
/// represent the second case, where the component type is implicit.
///
/// This does not cover structs or types referred to by type aliases. See the
/// documentation for [`Construct`] and [`Call`] expressions for details.
///
/// [`Construct`]: Expression::Construct
/// [`type constructor expressions`]: https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr
/// [`Call`]: Expression::Call
#[derive(Debug)]
pub enum ConstructorType<'a> {
/// A scalar type or conversion: `f32(1)`.
Scalar {
kind: crate::ScalarKind,
width: crate::Bytes,
},
/// A vector construction whose component type is inferred from the
/// argument: `vec3(1.0)`.
PartialVector { size: crate::VectorSize },
/// A vector construction whose component type is written out:
/// `vec3<f32>(1.0)`.
Vector {
size: crate::VectorSize,
kind: crate::ScalarKind,
width: crate::Bytes,
},
/// A matrix construction whose component type is inferred from the
/// argument: `mat2x2(1,2,3,4)`.
PartialMatrix {
columns: crate::VectorSize,
rows: crate::VectorSize,
},
/// A matrix construction whose component type is written out:
/// `mat2x2<f32>(1,2,3,4)`.
Matrix {
columns: crate::VectorSize,
rows: crate::VectorSize,
width: crate::Bytes,
},
/// An array whose component type and size are inferred from the arguments:
/// `array(3,4,5)`.
PartialArray,
/// An array whose component type and size are written out:
/// `array<u32, 4>(3,4,5)`.
Array {
base: Handle<Type<'a>>,
size: ArraySize<'a>,
},
/// Constructing a value of a known Naga IR type.
///
/// This variant is produced only during lowering, when we have Naga types
/// available, never during parsing.
Type(Handle<crate::Type>),
}
#[derive(Debug, Copy, Clone)]
pub enum Literal {
Bool(bool),
Number(Number),
}
#[cfg(doc)]
use crate::front::wgsl::lower::Lowerer;
#[derive(Debug)]
pub enum Expression<'a> {
Literal(Literal),
Ident(IdentExpr<'a>),
/// A type constructor expression.
///
/// This is only used for expressions like `KEYWORD(EXPR...)` and
/// `KEYWORD<PARAM>(EXPR...)`, where `KEYWORD` is a [type-defining keyword] like
/// `vec3`. These keywords cannot be shadowed by user definitions, so we can
/// tell that such an expression is a construction immediately.
///
/// For ordinary identifiers, we can't tell whether an expression like
/// `IDENTIFIER(EXPR, ...)` is a construction expression or a function call
/// until we know `IDENTIFIER`'s definition, so we represent those as
/// [`Call`] expressions.
///
/// [type-defining keyword]: https://gpuweb.github.io/gpuweb/wgsl/#type-defining-keywords
/// [`Call`]: Expression::Call
Construct {
ty: ConstructorType<'a>,
ty_span: Span,
components: Vec<Handle<Expression<'a>>>,
},
Unary {
op: crate::UnaryOperator,
expr: Handle<Expression<'a>>,
},
AddrOf(Handle<Expression<'a>>),
Deref(Handle<Expression<'a>>),
Binary {
op: crate::BinaryOperator,
left: Handle<Expression<'a>>,
right: Handle<Expression<'a>>,
},
/// A function call or type constructor expression.
///
/// We can't tell whether an expression like `IDENTIFIER(EXPR, ...)` is a
/// construction expression or a function call until we know `IDENTIFIER`'s
/// definition, so we represent everything of that form as one of these
/// expressions until lowering. At that point, [`Lowerer::call`] has
/// everything's definition in hand, and can decide whether to emit a Naga
/// [`Constant`], [`As`], [`Splat`], or [`Compose`] expression.
///
/// [`Lowerer::call`]: Lowerer::call
/// [`Constant`]: crate::Expression::Constant
/// [`As`]: crate::Expression::As
/// [`Splat`]: crate::Expression::Splat
/// [`Compose`]: crate::Expression::Compose
Call {
function: Ident<'a>,
arguments: Vec<Handle<Expression<'a>>>,
},
Index {
base: Handle<Expression<'a>>,
index: Handle<Expression<'a>>,
},
Member {
base: Handle<Expression<'a>>,
field: Ident<'a>,
},
Bitcast {
expr: Handle<Expression<'a>>,
to: Handle<Type<'a>>,
ty_span: Span,
},
}
#[derive(Debug)]
pub struct LocalVariable<'a> {
pub name: Ident<'a>,
pub ty: Option<Handle<Type<'a>>>,
pub init: Option<Handle<Expression<'a>>>,
pub handle: Handle<Local>,
}
#[derive(Debug)]
pub struct Let<'a> {
pub name: Ident<'a>,
pub ty: Option<Handle<Type<'a>>>,
pub init: Handle<Expression<'a>>,
pub handle: Handle<Local>,
}
#[derive(Debug)]
pub enum LocalDecl<'a> {
Var(LocalVariable<'a>),
Let(Let<'a>),
}
#[derive(Debug)]
/// A placeholder for a local variable declaration.
///
/// See [`Function::locals`] for more information.
pub struct Local;

View file

@ -0,0 +1,236 @@
use super::Error;
use crate::Span;
pub fn map_address_space(word: &str, span: Span) -> Result<crate::AddressSpace, Error<'_>> {
match word {
"private" => Ok(crate::AddressSpace::Private),
"workgroup" => Ok(crate::AddressSpace::WorkGroup),
"uniform" => Ok(crate::AddressSpace::Uniform),
"storage" => Ok(crate::AddressSpace::Storage {
access: crate::StorageAccess::default(),
}),
"push_constant" => Ok(crate::AddressSpace::PushConstant),
"function" => Ok(crate::AddressSpace::Function),
_ => Err(Error::UnknownAddressSpace(span)),
}
}
pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>> {
Ok(match word {
"position" => crate::BuiltIn::Position { invariant: false },
// vertex
"vertex_index" => crate::BuiltIn::VertexIndex,
"instance_index" => crate::BuiltIn::InstanceIndex,
"view_index" => crate::BuiltIn::ViewIndex,
// fragment
"front_facing" => crate::BuiltIn::FrontFacing,
"frag_depth" => crate::BuiltIn::FragDepth,
"primitive_index" => crate::BuiltIn::PrimitiveIndex,
"sample_index" => crate::BuiltIn::SampleIndex,
"sample_mask" => crate::BuiltIn::SampleMask,
// compute
"global_invocation_id" => crate::BuiltIn::GlobalInvocationId,
"local_invocation_id" => crate::BuiltIn::LocalInvocationId,
"local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
"workgroup_id" => crate::BuiltIn::WorkGroupId,
"num_workgroups" => crate::BuiltIn::NumWorkGroups,
_ => return Err(Error::UnknownBuiltin(span)),
})
}
pub fn map_interpolation(word: &str, span: Span) -> Result<crate::Interpolation, Error<'_>> {
match word {
"linear" => Ok(crate::Interpolation::Linear),
"flat" => Ok(crate::Interpolation::Flat),
"perspective" => Ok(crate::Interpolation::Perspective),
_ => Err(Error::UnknownAttribute(span)),
}
}
pub fn map_sampling(word: &str, span: Span) -> Result<crate::Sampling, Error<'_>> {
match word {
"center" => Ok(crate::Sampling::Center),
"centroid" => Ok(crate::Sampling::Centroid),
"sample" => Ok(crate::Sampling::Sample),
_ => Err(Error::UnknownAttribute(span)),
}
}
pub fn map_storage_format(word: &str, span: Span) -> Result<crate::StorageFormat, Error<'_>> {
use crate::StorageFormat as Sf;
Ok(match word {
"r8unorm" => Sf::R8Unorm,
"r8snorm" => Sf::R8Snorm,
"r8uint" => Sf::R8Uint,
"r8sint" => Sf::R8Sint,
"r16unorm" => Sf::R16Unorm,
"r16snorm" => Sf::R16Snorm,
"r16uint" => Sf::R16Uint,
"r16sint" => Sf::R16Sint,
"r16float" => Sf::R16Float,
"rg8unorm" => Sf::Rg8Unorm,
"rg8snorm" => Sf::Rg8Snorm,
"rg8uint" => Sf::Rg8Uint,
"rg8sint" => Sf::Rg8Sint,
"r32uint" => Sf::R32Uint,
"r32sint" => Sf::R32Sint,
"r32float" => Sf::R32Float,
"rg16unorm" => Sf::Rg16Unorm,
"rg16snorm" => Sf::Rg16Snorm,
"rg16uint" => Sf::Rg16Uint,
"rg16sint" => Sf::Rg16Sint,
"rg16float" => Sf::Rg16Float,
"rgba8unorm" => Sf::Rgba8Unorm,
"rgba8snorm" => Sf::Rgba8Snorm,
"rgba8uint" => Sf::Rgba8Uint,
"rgba8sint" => Sf::Rgba8Sint,
"rgb10a2unorm" => Sf::Rgb10a2Unorm,
"rg11b10float" => Sf::Rg11b10Float,
"rg32uint" => Sf::Rg32Uint,
"rg32sint" => Sf::Rg32Sint,
"rg32float" => Sf::Rg32Float,
"rgba16unorm" => Sf::Rgba16Unorm,
"rgba16snorm" => Sf::Rgba16Snorm,
"rgba16uint" => Sf::Rgba16Uint,
"rgba16sint" => Sf::Rgba16Sint,
"rgba16float" => Sf::Rgba16Float,
"rgba32uint" => Sf::Rgba32Uint,
"rgba32sint" => Sf::Rgba32Sint,
"rgba32float" => Sf::Rgba32Float,
_ => return Err(Error::UnknownStorageFormat(span)),
})
}
pub fn get_scalar_type(word: &str) -> Option<(crate::ScalarKind, crate::Bytes)> {
match word {
// "f16" => Some((crate::ScalarKind::Float, 2)),
"f32" => Some((crate::ScalarKind::Float, 4)),
"f64" => Some((crate::ScalarKind::Float, 8)),
"i32" => Some((crate::ScalarKind::Sint, 4)),
"u32" => Some((crate::ScalarKind::Uint, 4)),
"bool" => Some((crate::ScalarKind::Bool, crate::BOOL_WIDTH)),
_ => None,
}
}
pub fn map_derivative(word: &str) -> Option<(crate::DerivativeAxis, crate::DerivativeControl)> {
use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
match word {
"dpdxCoarse" => Some((Axis::X, Ctrl::Coarse)),
"dpdyCoarse" => Some((Axis::Y, Ctrl::Coarse)),
"fwidthCoarse" => Some((Axis::Width, Ctrl::Coarse)),
"dpdxFine" => Some((Axis::X, Ctrl::Fine)),
"dpdyFine" => Some((Axis::Y, Ctrl::Fine)),
"fwidthFine" => Some((Axis::Width, Ctrl::Fine)),
"dpdx" => Some((Axis::X, Ctrl::None)),
"dpdy" => Some((Axis::Y, Ctrl::None)),
"fwidth" => Some((Axis::Width, Ctrl::None)),
_ => None,
}
}
pub fn map_relational_fun(word: &str) -> Option<crate::RelationalFunction> {
match word {
"any" => Some(crate::RelationalFunction::Any),
"all" => Some(crate::RelationalFunction::All),
_ => None,
}
}
pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
use crate::MathFunction as Mf;
Some(match word {
// comparison
"abs" => Mf::Abs,
"min" => Mf::Min,
"max" => Mf::Max,
"clamp" => Mf::Clamp,
"saturate" => Mf::Saturate,
// trigonometry
"cos" => Mf::Cos,
"cosh" => Mf::Cosh,
"sin" => Mf::Sin,
"sinh" => Mf::Sinh,
"tan" => Mf::Tan,
"tanh" => Mf::Tanh,
"acos" => Mf::Acos,
"acosh" => Mf::Acosh,
"asin" => Mf::Asin,
"asinh" => Mf::Asinh,
"atan" => Mf::Atan,
"atanh" => Mf::Atanh,
"atan2" => Mf::Atan2,
"radians" => Mf::Radians,
"degrees" => Mf::Degrees,
// decomposition
"ceil" => Mf::Ceil,
"floor" => Mf::Floor,
"round" => Mf::Round,
"fract" => Mf::Fract,
"trunc" => Mf::Trunc,
"modf" => Mf::Modf,
"frexp" => Mf::Frexp,
"ldexp" => Mf::Ldexp,
// exponent
"exp" => Mf::Exp,
"exp2" => Mf::Exp2,
"log" => Mf::Log,
"log2" => Mf::Log2,
"pow" => Mf::Pow,
// geometry
"dot" => Mf::Dot,
"outerProduct" => Mf::Outer,
"cross" => Mf::Cross,
"distance" => Mf::Distance,
"length" => Mf::Length,
"normalize" => Mf::Normalize,
"faceForward" => Mf::FaceForward,
"reflect" => Mf::Reflect,
"refract" => Mf::Refract,
// computational
"sign" => Mf::Sign,
"fma" => Mf::Fma,
"mix" => Mf::Mix,
"step" => Mf::Step,
"smoothstep" => Mf::SmoothStep,
"sqrt" => Mf::Sqrt,
"inverseSqrt" => Mf::InverseSqrt,
"transpose" => Mf::Transpose,
"determinant" => Mf::Determinant,
// bits
"countTrailingZeros" => Mf::CountTrailingZeros,
"countLeadingZeros" => Mf::CountLeadingZeros,
"countOneBits" => Mf::CountOneBits,
"reverseBits" => Mf::ReverseBits,
"extractBits" => Mf::ExtractBits,
"insertBits" => Mf::InsertBits,
"firstTrailingBit" => Mf::FindLsb,
"firstLeadingBit" => Mf::FindMsb,
// data packing
"pack4x8snorm" => Mf::Pack4x8snorm,
"pack4x8unorm" => Mf::Pack4x8unorm,
"pack2x16snorm" => Mf::Pack2x16snorm,
"pack2x16unorm" => Mf::Pack2x16unorm,
"pack2x16float" => Mf::Pack2x16float,
// data unpacking
"unpack4x8snorm" => Mf::Unpack4x8snorm,
"unpack4x8unorm" => Mf::Unpack4x8unorm,
"unpack2x16snorm" => Mf::Unpack2x16snorm,
"unpack2x16unorm" => Mf::Unpack2x16unorm,
"unpack2x16float" => Mf::Unpack2x16float,
_ => return None,
})
}
pub fn map_conservative_depth(
word: &str,
span: Span,
) -> Result<crate::ConservativeDepth, Error<'_>> {
use crate::ConservativeDepth as Cd;
match word {
"greater_equal" => Ok(Cd::GreaterEqual),
"less_equal" => Ok(Cd::LessEqual),
"unchanged" => Ok(Cd::Unchanged),
_ => Err(Error::UnknownConservativeDepth(span)),
}
}

View file

@ -0,0 +1,723 @@
use super::{number::consume_number, Error, ExpectedToken};
use crate::front::wgsl::error::NumberError;
use crate::front::wgsl::parse::{conv, Number};
use crate::Span;
type TokenSpan<'a> = (Token<'a>, Span);
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Token<'a> {
Separator(char),
Paren(char),
Attribute,
Number(Result<Number, NumberError>),
Word(&'a str),
Operation(char),
LogicalOperation(char),
ShiftOperation(char),
AssignmentOperation(char),
IncrementOperation,
DecrementOperation,
Arrow,
Unknown(char),
Trivia,
End,
}
fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str) {
let pos = input.find(|c| !what(c)).unwrap_or(input.len());
input.split_at(pos)
}
/// Return the token at the start of `input`.
///
/// If `generic` is `false`, then the bit shift operators `>>` or `<<`
/// are valid lookahead tokens for the current parser state (see [§3.1
/// Parsing] in the WGSL specification). In other words:
///
/// - If `generic` is `true`, then we are expecting an angle bracket
/// around a generic type parameter, like the `<` and `>` in
/// `vec3<f32>`, so interpret `<` and `>` as `Token::Paren` tokens,
/// even if they're part of `<<` or `>>` sequences.
///
/// - Otherwise, interpret `<<` and `>>` as shift operators:
/// `Token::LogicalOperation` tokens.
///
/// [§3.1 Parsing]: https://gpuweb.github.io/gpuweb/wgsl/#parsing
fn consume_token(input: &str, generic: bool) -> (Token<'_>, &str) {
let mut chars = input.chars();
let cur = match chars.next() {
Some(c) => c,
None => return (Token::End, ""),
};
match cur {
':' | ';' | ',' => (Token::Separator(cur), chars.as_str()),
'.' => {
let og_chars = chars.as_str();
match chars.next() {
Some('0'..='9') => consume_number(input),
_ => (Token::Separator(cur), og_chars),
}
}
'@' => (Token::Attribute, chars.as_str()),
'(' | ')' | '{' | '}' | '[' | ']' => (Token::Paren(cur), chars.as_str()),
'<' | '>' => {
let og_chars = chars.as_str();
match chars.next() {
Some('=') if !generic => (Token::LogicalOperation(cur), chars.as_str()),
Some(c) if c == cur && !generic => {
let og_chars = chars.as_str();
match chars.next() {
Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
_ => (Token::ShiftOperation(cur), og_chars),
}
}
_ => (Token::Paren(cur), og_chars),
}
}
'0'..='9' => consume_number(input),
'/' => {
let og_chars = chars.as_str();
match chars.next() {
Some('/') => {
let _ = chars.position(is_comment_end);
(Token::Trivia, chars.as_str())
}
Some('*') => {
let mut depth = 1;
let mut prev = None;
for c in &mut chars {
match (prev, c) {
(Some('*'), '/') => {
prev = None;
depth -= 1;
if depth == 0 {
return (Token::Trivia, chars.as_str());
}
}
(Some('/'), '*') => {
prev = None;
depth += 1;
}
_ => {
prev = Some(c);
}
}
}
(Token::End, "")
}
Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
_ => (Token::Operation(cur), og_chars),
}
}
'-' => {
let og_chars = chars.as_str();
match chars.next() {
Some('>') => (Token::Arrow, chars.as_str()),
Some('0'..='9' | '.') => consume_number(input),
Some('-') => (Token::DecrementOperation, chars.as_str()),
Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
_ => (Token::Operation(cur), og_chars),
}
}
'+' => {
let og_chars = chars.as_str();
match chars.next() {
Some('+') => (Token::IncrementOperation, chars.as_str()),
Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
_ => (Token::Operation(cur), og_chars),
}
}
'*' | '%' | '^' => {
let og_chars = chars.as_str();
match chars.next() {
Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
_ => (Token::Operation(cur), og_chars),
}
}
'~' => (Token::Operation(cur), chars.as_str()),
'=' | '!' => {
let og_chars = chars.as_str();
match chars.next() {
Some('=') => (Token::LogicalOperation(cur), chars.as_str()),
_ => (Token::Operation(cur), og_chars),
}
}
'&' | '|' => {
let og_chars = chars.as_str();
match chars.next() {
Some(c) if c == cur => (Token::LogicalOperation(cur), chars.as_str()),
Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
_ => (Token::Operation(cur), og_chars),
}
}
_ if is_blankspace(cur) => {
let (_, rest) = consume_any(input, is_blankspace);
(Token::Trivia, rest)
}
_ if is_word_start(cur) => {
let (word, rest) = consume_any(input, is_word_part);
(Token::Word(word), rest)
}
_ => (Token::Unknown(cur), chars.as_str()),
}
}
/// Returns whether or not a char is a comment end
/// (Unicode Pattern_White_Space excluding U+0020, U+0009, U+200E and U+200F)
const fn is_comment_end(c: char) -> bool {
match c {
'\u{000a}'..='\u{000d}' | '\u{0085}' | '\u{2028}' | '\u{2029}' => true,
_ => false,
}
}
/// Returns whether or not a char is a blankspace (Unicode Pattern_White_Space)
const fn is_blankspace(c: char) -> bool {
match c {
'\u{0020}'
| '\u{0009}'..='\u{000d}'
| '\u{0085}'
| '\u{200e}'
| '\u{200f}'
| '\u{2028}'
| '\u{2029}' => true,
_ => false,
}
}
/// Returns whether or not a char is a word start (Unicode XID_Start + '_')
fn is_word_start(c: char) -> bool {
c == '_' || unicode_xid::UnicodeXID::is_xid_start(c)
}
/// Returns whether or not a char is a word part (Unicode XID_Continue)
fn is_word_part(c: char) -> bool {
unicode_xid::UnicodeXID::is_xid_continue(c)
}
#[derive(Clone)]
pub(in crate::front::wgsl) struct Lexer<'a> {
input: &'a str,
pub(in crate::front::wgsl) source: &'a str,
// The byte offset of the end of the last non-trivia token.
last_end_offset: usize,
}
impl<'a> Lexer<'a> {
pub(in crate::front::wgsl) const fn new(input: &'a str) -> Self {
Lexer {
input,
source: input,
last_end_offset: 0,
}
}
/// Calls the function with a lexer and returns the result of the function as well as the span for everything the function parsed
///
/// # Examples
/// ```ignore
/// let lexer = Lexer::new("5");
/// let (value, span) = lexer.capture_span(Lexer::next_uint_literal);
/// assert_eq!(value, 5);
/// ```
#[inline]
pub fn capture_span<T, E>(
&mut self,
inner: impl FnOnce(&mut Self) -> Result<T, E>,
) -> Result<(T, Span), E> {
let start = self.current_byte_offset();
let res = inner(self)?;
let end = self.current_byte_offset();
Ok((res, Span::from(start..end)))
}
pub(in crate::front::wgsl) fn start_byte_offset(&mut self) -> usize {
loop {
// Eat all trivia because `next` doesn't eat trailing trivia.
let (token, rest) = consume_token(self.input, false);
if let Token::Trivia = token {
self.input = rest;
} else {
return self.current_byte_offset();
}
}
}
fn peek_token_and_rest(&mut self) -> (TokenSpan<'a>, &'a str) {
let mut cloned = self.clone();
let token = cloned.next();
let rest = cloned.input;
(token, rest)
}
const fn current_byte_offset(&self) -> usize {
self.source.len() - self.input.len()
}
pub(in crate::front::wgsl) fn span_from(&self, offset: usize) -> Span {
Span::from(offset..self.last_end_offset)
}
/// Return the next non-whitespace token from `self`.
///
/// Assume we are a parse state where bit shift operators may
/// occur, but not angle brackets.
#[must_use]
pub(in crate::front::wgsl) fn next(&mut self) -> TokenSpan<'a> {
self.next_impl(false)
}
/// Return the next non-whitespace token from `self`.
///
/// Assume we are in a parse state where angle brackets may occur,
/// but not bit shift operators.
#[must_use]
pub(in crate::front::wgsl) fn next_generic(&mut self) -> TokenSpan<'a> {
self.next_impl(true)
}
/// Return the next non-whitespace token from `self`, with a span.
///
/// See [`consume_token`] for the meaning of `generic`.
fn next_impl(&mut self, generic: bool) -> TokenSpan<'a> {
let mut start_byte_offset = self.current_byte_offset();
loop {
let (token, rest) = consume_token(self.input, generic);
self.input = rest;
match token {
Token::Trivia => start_byte_offset = self.current_byte_offset(),
_ => {
self.last_end_offset = self.current_byte_offset();
return (token, self.span_from(start_byte_offset));
}
}
}
}
#[must_use]
pub(in crate::front::wgsl) fn peek(&mut self) -> TokenSpan<'a> {
let (token, _) = self.peek_token_and_rest();
token
}
pub(in crate::front::wgsl) fn expect_span(
&mut self,
expected: Token<'a>,
) -> Result<Span, Error<'a>> {
let next = self.next();
if next.0 == expected {
Ok(next.1)
} else {
Err(Error::Unexpected(next.1, ExpectedToken::Token(expected)))
}
}
pub(in crate::front::wgsl) fn expect(&mut self, expected: Token<'a>) -> Result<(), Error<'a>> {
self.expect_span(expected)?;
Ok(())
}
pub(in crate::front::wgsl) fn expect_generic_paren(
&mut self,
expected: char,
) -> Result<(), Error<'a>> {
let next = self.next_generic();
if next.0 == Token::Paren(expected) {
Ok(())
} else {
Err(Error::Unexpected(
next.1,
ExpectedToken::Token(Token::Paren(expected)),
))
}
}
/// If the next token matches it is skipped and true is returned
pub(in crate::front::wgsl) fn skip(&mut self, what: Token<'_>) -> bool {
let (peeked_token, rest) = self.peek_token_and_rest();
if peeked_token.0 == what {
self.input = rest;
true
} else {
false
}
}
pub(in crate::front::wgsl) fn next_ident_with_span(
&mut self,
) -> Result<(&'a str, Span), Error<'a>> {
match self.next() {
(Token::Word(word), span) if word == "_" => {
Err(Error::InvalidIdentifierUnderscore(span))
}
(Token::Word(word), span) if word.starts_with("__") => {
Err(Error::ReservedIdentifierPrefix(span))
}
(Token::Word(word), span) => Ok((word, span)),
other => Err(Error::Unexpected(other.1, ExpectedToken::Identifier)),
}
}
pub(in crate::front::wgsl) fn next_ident(
&mut self,
) -> Result<super::ast::Ident<'a>, Error<'a>> {
let ident = self
.next_ident_with_span()
.map(|(name, span)| super::ast::Ident { name, span })?;
if crate::keywords::wgsl::RESERVED.contains(&ident.name) {
return Err(Error::ReservedKeyword(ident.span));
}
Ok(ident)
}
/// Parses a generic scalar type, for example `<f32>`.
pub(in crate::front::wgsl) fn next_scalar_generic(
&mut self,
) -> Result<(crate::ScalarKind, crate::Bytes), Error<'a>> {
self.expect_generic_paren('<')?;
let pair = match self.next() {
(Token::Word(word), span) => {
conv::get_scalar_type(word).ok_or(Error::UnknownScalarType(span))
}
(_, span) => Err(Error::UnknownScalarType(span)),
}?;
self.expect_generic_paren('>')?;
Ok(pair)
}
/// Parses a generic scalar type, for example `<f32>`.
///
/// Returns the span covering the inner type, excluding the brackets.
pub(in crate::front::wgsl) fn next_scalar_generic_with_span(
&mut self,
) -> Result<(crate::ScalarKind, crate::Bytes, Span), Error<'a>> {
self.expect_generic_paren('<')?;
let pair = match self.next() {
(Token::Word(word), span) => conv::get_scalar_type(word)
.map(|(a, b)| (a, b, span))
.ok_or(Error::UnknownScalarType(span)),
(_, span) => Err(Error::UnknownScalarType(span)),
}?;
self.expect_generic_paren('>')?;
Ok(pair)
}
pub(in crate::front::wgsl) fn next_storage_access(
&mut self,
) -> Result<crate::StorageAccess, Error<'a>> {
let (ident, span) = self.next_ident_with_span()?;
match ident {
"read" => Ok(crate::StorageAccess::LOAD),
"write" => Ok(crate::StorageAccess::STORE),
"read_write" => Ok(crate::StorageAccess::LOAD | crate::StorageAccess::STORE),
_ => Err(Error::UnknownAccess(span)),
}
}
pub(in crate::front::wgsl) fn next_format_generic(
&mut self,
) -> Result<(crate::StorageFormat, crate::StorageAccess), Error<'a>> {
self.expect(Token::Paren('<'))?;
let (ident, ident_span) = self.next_ident_with_span()?;
let format = conv::map_storage_format(ident, ident_span)?;
self.expect(Token::Separator(','))?;
let access = self.next_storage_access()?;
self.expect(Token::Paren('>'))?;
Ok((format, access))
}
pub(in crate::front::wgsl) fn open_arguments(&mut self) -> Result<(), Error<'a>> {
self.expect(Token::Paren('('))
}
pub(in crate::front::wgsl) fn close_arguments(&mut self) -> Result<(), Error<'a>> {
let _ = self.skip(Token::Separator(','));
self.expect(Token::Paren(')'))
}
pub(in crate::front::wgsl) fn next_argument(&mut self) -> Result<bool, Error<'a>> {
let paren = Token::Paren(')');
if self.skip(Token::Separator(',')) {
Ok(!self.skip(paren))
} else {
self.expect(paren).map(|()| false)
}
}
}
#[cfg(test)]
fn sub_test(source: &str, expected_tokens: &[Token]) {
let mut lex = Lexer::new(source);
for &token in expected_tokens {
assert_eq!(lex.next().0, token);
}
assert_eq!(lex.next().0, Token::End);
}
#[test]
fn test_numbers() {
// WGSL spec examples //
// decimal integer
sub_test(
"0x123 0X123u 1u 123 0 0i 0x3f",
&[
Token::Number(Ok(Number::I32(291))),
Token::Number(Ok(Number::U32(291))),
Token::Number(Ok(Number::U32(1))),
Token::Number(Ok(Number::I32(123))),
Token::Number(Ok(Number::I32(0))),
Token::Number(Ok(Number::I32(0))),
Token::Number(Ok(Number::I32(63))),
],
);
// decimal floating point
sub_test(
"0.e+4f 01. .01 12.34 .0f 0h 1e-3 0xa.fp+2 0x1P+4f 0X.3 0x3p+2h 0X1.fp-4 0x3.2p+2h",
&[
Token::Number(Ok(Number::F32(0.))),
Token::Number(Ok(Number::F32(1.))),
Token::Number(Ok(Number::F32(0.01))),
Token::Number(Ok(Number::F32(12.34))),
Token::Number(Ok(Number::F32(0.))),
Token::Number(Err(NumberError::UnimplementedF16)),
Token::Number(Ok(Number::F32(0.001))),
Token::Number(Ok(Number::F32(43.75))),
Token::Number(Ok(Number::F32(16.))),
Token::Number(Ok(Number::F32(0.1875))),
Token::Number(Err(NumberError::UnimplementedF16)),
Token::Number(Ok(Number::F32(0.12109375))),
Token::Number(Err(NumberError::UnimplementedF16)),
],
);
// MIN / MAX //
// min / max decimal signed integer
sub_test(
"-2147483648i 2147483647i -2147483649i 2147483648i",
&[
Token::Number(Ok(Number::I32(i32::MIN))),
Token::Number(Ok(Number::I32(i32::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
Token::Number(Err(NumberError::NotRepresentable)),
],
);
// min / max decimal unsigned integer
sub_test(
"0u 4294967295u -1u 4294967296u",
&[
Token::Number(Ok(Number::U32(u32::MIN))),
Token::Number(Ok(Number::U32(u32::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
Token::Number(Err(NumberError::NotRepresentable)),
],
);
// min / max hexadecimal signed integer
sub_test(
"-0x80000000i 0x7FFFFFFFi -0x80000001i 0x80000000i",
&[
Token::Number(Ok(Number::I32(i32::MIN))),
Token::Number(Ok(Number::I32(i32::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
Token::Number(Err(NumberError::NotRepresentable)),
],
);
// min / max hexadecimal unsigned integer
sub_test(
"0x0u 0xFFFFFFFFu -0x1u 0x100000000u",
&[
Token::Number(Ok(Number::U32(u32::MIN))),
Token::Number(Ok(Number::U32(u32::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
Token::Number(Err(NumberError::NotRepresentable)),
],
);
/// ≈ 2^-126 * 2^23 (= 2^149)
const SMALLEST_POSITIVE_SUBNORMAL_F32: f32 = 1e-45;
/// ≈ 2^-126 * (1 2^23)
const LARGEST_SUBNORMAL_F32: f32 = 1.1754942e-38;
/// ≈ 2^-126
const SMALLEST_POSITIVE_NORMAL_F32: f32 = f32::MIN_POSITIVE;
/// ≈ 1 2^24
const LARGEST_F32_LESS_THAN_ONE: f32 = 0.99999994;
/// ≈ 1 + 2^23
const SMALLEST_F32_LARGER_THAN_ONE: f32 = 1.0000001;
/// ≈ -(2^127 * (2 2^23))
const SMALLEST_NORMAL_F32: f32 = f32::MIN;
/// ≈ 2^127 * (2 2^23)
const LARGEST_NORMAL_F32: f32 = f32::MAX;
// decimal floating point
sub_test(
"1e-45f 1.1754942e-38f 1.17549435e-38f 0.99999994f 1.0000001f -3.40282347e+38f 3.40282347e+38f",
&[
Token::Number(Ok(Number::F32(
SMALLEST_POSITIVE_SUBNORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_SUBNORMAL_F32,
))),
Token::Number(Ok(Number::F32(
SMALLEST_POSITIVE_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_F32_LESS_THAN_ONE,
))),
Token::Number(Ok(Number::F32(
SMALLEST_F32_LARGER_THAN_ONE,
))),
Token::Number(Ok(Number::F32(
SMALLEST_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_NORMAL_F32,
))),
],
);
sub_test(
"-3.40282367e+38f 3.40282367e+38f",
&[
Token::Number(Err(NumberError::NotRepresentable)), // ≈ -2^128
Token::Number(Err(NumberError::NotRepresentable)), // ≈ 2^128
],
);
// hexadecimal floating point
sub_test(
"0x1p-149f 0x7FFFFFp-149f 0x1p-126f 0xFFFFFFp-24f 0x800001p-23f -0xFFFFFFp+104f 0xFFFFFFp+104f",
&[
Token::Number(Ok(Number::F32(
SMALLEST_POSITIVE_SUBNORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_SUBNORMAL_F32,
))),
Token::Number(Ok(Number::F32(
SMALLEST_POSITIVE_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_F32_LESS_THAN_ONE,
))),
Token::Number(Ok(Number::F32(
SMALLEST_F32_LARGER_THAN_ONE,
))),
Token::Number(Ok(Number::F32(
SMALLEST_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_NORMAL_F32,
))),
],
);
sub_test(
"-0x1p128f 0x1p128f 0x1.000001p0f",
&[
Token::Number(Err(NumberError::NotRepresentable)), // = -2^128
Token::Number(Err(NumberError::NotRepresentable)), // = 2^128
Token::Number(Err(NumberError::NotRepresentable)),
],
);
}
#[test]
fn test_tokens() {
sub_test("id123_OK", &[Token::Word("id123_OK")]);
sub_test(
"92No",
&[Token::Number(Ok(Number::I32(92))), Token::Word("No")],
);
sub_test(
"2u3o",
&[
Token::Number(Ok(Number::U32(2))),
Token::Number(Ok(Number::I32(3))),
Token::Word("o"),
],
);
sub_test(
"2.4f44po",
&[
Token::Number(Ok(Number::F32(2.4))),
Token::Number(Ok(Number::I32(44))),
Token::Word("po"),
],
);
sub_test(
"Δέλτα réflexion Кызыл 𐰓𐰏𐰇 朝焼け سلام 검정 שָׁלוֹם गुलाबी փիրուզ",
&[
Token::Word("Δέλτα"),
Token::Word("réflexion"),
Token::Word("Кызыл"),
Token::Word("𐰓𐰏𐰇"),
Token::Word("朝焼け"),
Token::Word("سلام"),
Token::Word("검정"),
Token::Word("שָׁלוֹם"),
Token::Word("गुलाबी"),
Token::Word("փիրուզ"),
],
);
sub_test("æNoø", &[Token::Word("æNoø")]);
sub_test("No¾", &[Token::Word("No"), Token::Unknown('¾')]);
sub_test("No好", &[Token::Word("No好")]);
sub_test("_No", &[Token::Word("_No")]);
sub_test(
"*/*/***/*//=/*****//",
&[
Token::Operation('*'),
Token::AssignmentOperation('/'),
Token::Operation('/'),
],
);
}
#[test]
fn test_variable_decl() {
sub_test(
"@group(0 ) var< uniform> texture: texture_multisampled_2d <f32 >;",
&[
Token::Attribute,
Token::Word("group"),
Token::Paren('('),
Token::Number(Ok(Number::I32(0))),
Token::Paren(')'),
Token::Word("var"),
Token::Paren('<'),
Token::Word("uniform"),
Token::Paren('>'),
Token::Word("texture"),
Token::Separator(':'),
Token::Word("texture_multisampled_2d"),
Token::Paren('<'),
Token::Word("f32"),
Token::Paren('>'),
Token::Separator(';'),
],
);
sub_test(
"var<storage,read_write> buffer: array<u32>;",
&[
Token::Word("var"),
Token::Paren('<'),
Token::Word("storage"),
Token::Separator(','),
Token::Word("read_write"),
Token::Paren('>'),
Token::Word("buffer"),
Token::Separator(':'),
Token::Word("array"),
Token::Paren('<'),
Token::Word("u32"),
Token::Paren('>'),
Token::Separator(';'),
],
);
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,443 @@
use std::borrow::Cow;
use crate::front::wgsl::error::NumberError;
use crate::front::wgsl::parse::lexer::Token;
/// When using this type assume no Abstract Int/Float for now
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Number {
/// Abstract Int (-2^63 ≤ i < 2^63)
AbstractInt(i64),
/// Abstract Float (IEEE-754 binary64)
AbstractFloat(f64),
/// Concrete i32
I32(i32),
/// Concrete u32
U32(u32),
/// Concrete f32
F32(f32),
}
impl Number {
/// Convert abstract numbers to a plausible concrete counterpart.
///
/// Return concrete numbers unchanged. If the conversion would be
/// lossy, return an error.
fn abstract_to_concrete(self) -> Result<Number, NumberError> {
match self {
Number::AbstractInt(num) => i32::try_from(num)
.map(Number::I32)
.map_err(|_| NumberError::NotRepresentable),
Number::AbstractFloat(num) => {
let num = num as f32;
if num.is_finite() {
Ok(Number::F32(num))
} else {
Err(NumberError::NotRepresentable)
}
}
num => Ok(num),
}
}
}
// TODO: when implementing Creation-Time Expressions, remove the ability to match the minus sign
pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) {
let (result, rest) = parse(input);
(
Token::Number(result.and_then(Number::abstract_to_concrete)),
rest,
)
}
enum Kind {
Int(IntKind),
Float(FloatKind),
}
enum IntKind {
I32,
U32,
}
enum FloatKind {
F32,
F16,
}
// The following regexes (from the WGSL spec) will be matched:
// int_literal:
// | / 0 [iu]? /
// | / [1-9][0-9]* [iu]? /
// | / 0[xX][0-9a-fA-F]+ [iu]? /
// decimal_float_literal:
// | / 0 [fh] /
// | / [1-9][0-9]* [fh] /
// | / [0-9]* \.[0-9]+ ([eE][+-]?[0-9]+)? [fh]? /
// | / [0-9]+ \.[0-9]* ([eE][+-]?[0-9]+)? [fh]? /
// | / [0-9]+ [eE][+-]?[0-9]+ [fh]? /
// hex_float_literal:
// | / 0[xX][0-9a-fA-F]* \.[0-9a-fA-F]+ ([pP][+-]?[0-9]+ [fh]?)? /
// | / 0[xX][0-9a-fA-F]+ \.[0-9a-fA-F]* ([pP][+-]?[0-9]+ [fh]?)? /
// | / 0[xX][0-9a-fA-F]+ [pP][+-]?[0-9]+ [fh]? /
// You could visualize the regex below via https://debuggex.com to get a rough idea what `parse` is doing
// -?(?:0[xX](?:([0-9a-fA-F]+\.[0-9a-fA-F]*|[0-9a-fA-F]*\.[0-9a-fA-F]+)(?:([pP][+-]?[0-9]+)([fh]?))?|([0-9a-fA-F]+)([pP][+-]?[0-9]+)([fh]?)|([0-9a-fA-F]+)([iu]?))|((?:[0-9]+[eE][+-]?[0-9]+|(?:[0-9]+\.[0-9]*|[0-9]*\.[0-9]+)(?:[eE][+-]?[0-9]+)?))([fh]?)|((?:[0-9]|[1-9][0-9]+))([iufh]?))
fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
/// returns `true` and consumes `X` bytes from the given byte buffer
/// if the given `X` nr of patterns are found at the start of the buffer
macro_rules! consume {
($bytes:ident, $($pattern:pat),*) => {
match $bytes {
&[$($pattern),*, ref rest @ ..] => { $bytes = rest; true },
_ => false,
}
};
}
/// consumes one byte from the given byte buffer
/// if one of the given patterns are found at the start of the buffer
/// returning the corresponding expr for the matched pattern
macro_rules! consume_map {
($bytes:ident, [$($pattern:pat_param => $to:expr),*]) => {
match $bytes {
$( &[$pattern, ref rest @ ..] => { $bytes = rest; Some($to) }, )*
_ => None,
}
};
}
/// consumes all consecutive bytes matched by the `0-9` pattern from the given byte buffer
/// returning the number of consumed bytes
macro_rules! consume_dec_digits {
($bytes:ident) => {{
let start_len = $bytes.len();
while let &[b'0'..=b'9', ref rest @ ..] = $bytes {
$bytes = rest;
}
start_len - $bytes.len()
}};
}
/// consumes all consecutive bytes matched by the `0-9 | a-f | A-F` pattern from the given byte buffer
/// returning the number of consumed bytes
macro_rules! consume_hex_digits {
($bytes:ident) => {{
let start_len = $bytes.len();
while let &[b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F', ref rest @ ..] = $bytes {
$bytes = rest;
}
start_len - $bytes.len()
}};
}
/// maps the given `&[u8]` (tail of the initial `input: &str`) to a `&str`
macro_rules! rest_to_str {
($bytes:ident) => {
&input[input.len() - $bytes.len()..]
};
}
struct ExtractSubStr<'a>(&'a str);
impl<'a> ExtractSubStr<'a> {
/// given an `input` and a `start` (tail of the `input`)
/// creates a new [`ExtractSubStr`](`Self`)
fn start(input: &'a str, start: &'a [u8]) -> Self {
let start = input.len() - start.len();
Self(&input[start..])
}
/// given an `end` (tail of the initial `input`)
/// returns a substring of `input`
fn end(&self, end: &'a [u8]) -> &'a str {
let end = self.0.len() - end.len();
&self.0[..end]
}
}
let mut bytes = input.as_bytes();
let general_extract = ExtractSubStr::start(input, bytes);
let is_negative = consume!(bytes, b'-');
if consume!(bytes, b'0', b'x' | b'X') {
let digits_extract = ExtractSubStr::start(input, bytes);
let consumed = consume_hex_digits!(bytes);
if consume!(bytes, b'.') {
let consumed_after_period = consume_hex_digits!(bytes);
if consumed + consumed_after_period == 0 {
return (Err(NumberError::Invalid), rest_to_str!(bytes));
}
let significand = general_extract.end(bytes);
if consume!(bytes, b'p' | b'P') {
consume!(bytes, b'+' | b'-');
let consumed = consume_dec_digits!(bytes);
if consumed == 0 {
return (Err(NumberError::Invalid), rest_to_str!(bytes));
}
let number = general_extract.end(bytes);
let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
(parse_hex_float(number, kind), rest_to_str!(bytes))
} else {
(
parse_hex_float_missing_exponent(significand, None),
rest_to_str!(bytes),
)
}
} else {
if consumed == 0 {
return (Err(NumberError::Invalid), rest_to_str!(bytes));
}
let significand = general_extract.end(bytes);
let digits = digits_extract.end(bytes);
let exp_extract = ExtractSubStr::start(input, bytes);
if consume!(bytes, b'p' | b'P') {
consume!(bytes, b'+' | b'-');
let consumed = consume_dec_digits!(bytes);
if consumed == 0 {
return (Err(NumberError::Invalid), rest_to_str!(bytes));
}
let exponent = exp_extract.end(bytes);
let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
(
parse_hex_float_missing_period(significand, exponent, kind),
rest_to_str!(bytes),
)
} else {
let kind = consume_map!(bytes, [b'i' => IntKind::I32, b'u' => IntKind::U32]);
(
parse_hex_int(is_negative, digits, kind),
rest_to_str!(bytes),
)
}
}
} else {
let is_first_zero = bytes.first() == Some(&b'0');
let consumed = consume_dec_digits!(bytes);
if consume!(bytes, b'.') {
let consumed_after_period = consume_dec_digits!(bytes);
if consumed + consumed_after_period == 0 {
return (Err(NumberError::Invalid), rest_to_str!(bytes));
}
if consume!(bytes, b'e' | b'E') {
consume!(bytes, b'+' | b'-');
let consumed = consume_dec_digits!(bytes);
if consumed == 0 {
return (Err(NumberError::Invalid), rest_to_str!(bytes));
}
}
let number = general_extract.end(bytes);
let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
(parse_dec_float(number, kind), rest_to_str!(bytes))
} else {
if consumed == 0 {
return (Err(NumberError::Invalid), rest_to_str!(bytes));
}
if consume!(bytes, b'e' | b'E') {
consume!(bytes, b'+' | b'-');
let consumed = consume_dec_digits!(bytes);
if consumed == 0 {
return (Err(NumberError::Invalid), rest_to_str!(bytes));
}
let number = general_extract.end(bytes);
let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
(parse_dec_float(number, kind), rest_to_str!(bytes))
} else {
// make sure the multi-digit numbers don't start with zero
if consumed > 1 && is_first_zero {
return (Err(NumberError::Invalid), rest_to_str!(bytes));
}
let digits_with_sign = general_extract.end(bytes);
let kind = consume_map!(bytes, [
b'i' => Kind::Int(IntKind::I32),
b'u' => Kind::Int(IntKind::U32),
b'f' => Kind::Float(FloatKind::F32),
b'h' => Kind::Float(FloatKind::F16)
]);
(
parse_dec(is_negative, digits_with_sign, kind),
rest_to_str!(bytes),
)
}
}
}
}
fn parse_hex_float_missing_exponent(
// format: -?0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ )
significand: &str,
kind: Option<FloatKind>,
) -> Result<Number, NumberError> {
let hexf_input = format!("{}{}", significand, "p0");
parse_hex_float(&hexf_input, kind)
}
fn parse_hex_float_missing_period(
// format: -?0[xX] [0-9a-fA-F]+
significand: &str,
// format: [pP][+-]?[0-9]+
exponent: &str,
kind: Option<FloatKind>,
) -> Result<Number, NumberError> {
let hexf_input = format!("{significand}.{exponent}");
parse_hex_float(&hexf_input, kind)
}
fn parse_hex_int(
is_negative: bool,
// format: [0-9a-fA-F]+
digits: &str,
kind: Option<IntKind>,
) -> Result<Number, NumberError> {
let digits_with_sign = if is_negative {
Cow::Owned(format!("-{digits}"))
} else {
Cow::Borrowed(digits)
};
parse_int(&digits_with_sign, kind, 16, is_negative)
}
fn parse_dec(
is_negative: bool,
// format: -? ( [0-9] | [1-9][0-9]+ )
digits_with_sign: &str,
kind: Option<Kind>,
) -> Result<Number, NumberError> {
match kind {
None => parse_int(digits_with_sign, None, 10, is_negative),
Some(Kind::Int(kind)) => parse_int(digits_with_sign, Some(kind), 10, is_negative),
Some(Kind::Float(kind)) => parse_dec_float(digits_with_sign, Some(kind)),
}
}
// Float parsing notes
// The following chapters of IEEE 754-2019 are relevant:
//
// 7.4 Overflow (largest finite number is exceeded by what would have been
// the rounded floating-point result were the exponent range unbounded)
//
// 7.5 Underflow (tiny non-zero result is detected;
// for decimal formats tininess is detected before rounding when a non-zero result
// computed as though both the exponent range and the precision were unbounded
// would lie strictly between 2^126)
//
// 7.6 Inexact (rounded result differs from what would have been computed
// were both exponent range and precision unbounded)
// The WGSL spec requires us to error:
// on overflow for decimal floating point literals
// on overflow and inexact for hexadecimal floating point literals
// (underflow is not mentioned)
// hexf_parse errors on overflow, underflow, inexact
// rust std lib float from str handles overflow, underflow, inexact transparently (rounds and will not error)
// Therefore we only check for overflow manually for decimal floating point literals
// input format: -?0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) [pP][+-]?[0-9]+
fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
match kind {
None => match hexf_parse::parse_hexf64(input, false) {
Ok(num) => Ok(Number::AbstractFloat(num)),
// can only be ParseHexfErrorKind::Inexact but we can't check since it's private
_ => Err(NumberError::NotRepresentable),
},
Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) {
Ok(num) => Ok(Number::F32(num)),
// can only be ParseHexfErrorKind::Inexact but we can't check since it's private
_ => Err(NumberError::NotRepresentable),
},
Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
}
}
// input format: -? ( [0-9]+\.[0-9]* | [0-9]*\.[0-9]+ ) ([eE][+-]?[0-9]+)?
// | -? [0-9]+ [eE][+-]?[0-9]+
fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
match kind {
None => {
let num = input.parse::<f64>().unwrap(); // will never fail
num.is_finite()
.then_some(Number::AbstractFloat(num))
.ok_or(NumberError::NotRepresentable)
}
Some(FloatKind::F32) => {
let num = input.parse::<f32>().unwrap(); // will never fail
num.is_finite()
.then_some(Number::F32(num))
.ok_or(NumberError::NotRepresentable)
}
Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
}
}
fn parse_int(
input: &str,
kind: Option<IntKind>,
radix: u32,
is_negative: bool,
) -> Result<Number, NumberError> {
fn map_err(e: core::num::ParseIntError) -> NumberError {
match *e.kind() {
core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => {
NumberError::NotRepresentable
}
_ => unreachable!(),
}
}
match kind {
None => match i64::from_str_radix(input, radix) {
Ok(num) => Ok(Number::AbstractInt(num)),
Err(e) => Err(map_err(e)),
},
Some(IntKind::I32) => match i32::from_str_radix(input, radix) {
Ok(num) => Ok(Number::I32(num)),
Err(e) => Err(map_err(e)),
},
Some(IntKind::U32) if is_negative => Err(NumberError::NotRepresentable),
Some(IntKind::U32) => match u32::from_str_radix(input, radix) {
Ok(num) => Ok(Number::U32(num)),
Err(e) => Err(map_err(e)),
},
}
}