class Einsum

An unoptimized, pure-Ruby implementation of a subset of NumPy's einsum.

See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html

typed: strong
frozen_string_literal: true

Constants

FormatError
Label
VERSION

Public Class Methods

einsum(format, *operands)

Evaluates the (extended) Einstein summation convention on the operands.

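Operands must be Array-like, nested one level per label (e.g. the labels 'ij' expect an array of arrays). Their elements must respond to * and +; sums start from the Integer 0, so 0 + element must also be defined.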

Examples:

Implicit mode:

Einsum.einsum('ij,jk', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => dot product: [[7, 10], [15, 22]]
Einsum.einsum('ij,kj', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => inner product: [[ 5, 11], [11, 25]]

Explicit mode:

Einsum.einsum('ij,jk->ik', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => dot product: [[7, 10], [15, 22]]
Einsum.einsum('ij,kj->ik', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => inner product: [[ 5, 11], [11, 25]]
Einsum.einsum('ij,jk->', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => 54
Einsum.einsum('ij,kj->', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => 52
# File lib/einsum.rb, line 44
# Evaluates the (extended) Einstein summation convention on +operands+.
#
# +format+ is a comma-separated list of input label strings, one per
# operand and one lowercase letter per axis. It may be followed by '->'
# and an output label string (explicit mode).
#
# Without '->' (implicit mode), the output labels are every label that
# occurs exactly once across the inputs, in alphabetical order. Labels
# absent from the output are summed over. A label repeated within one
# input string indexes that operand's diagonal, because both axes share
# one loop variable.
#
# @param format [String] e.g. 'ij,jk' or 'ij,jk->ik'
# @param operands [Array] nested Array-like operands, one nesting level per
#   label; elements must respond to * and +
# @return [Object] the accumulated sum (starting from 0) when the output
#   label string is empty; otherwise the template built by +empty+ for the
#   output shape, filled in place
# @raise [FormatError] if the format is malformed; if the operand count
#   differs from the number of input label strings; if a label has no
#   matching operand axis; if a label has inconsistent dimensions; or if an
#   output label is missing from the inputs or appears more than once
def einsum(format, *operands)
  labels = {}

  # check syntax of format string: lowercase label strings separated by
  # commas, optionally followed by '->' and a (possibly empty) output string
  unless format.match?(/\A([a-z]+(,[a-z]+)*)(->[a-z]*)?\z/)
    raise FormatError, "invalid format: #{format}"
  end

  # chop up format string; +explicit+ is '->' in explicit mode, '' otherwise
  inputs, explicit, output = format.partition('->')
  inputs = inputs.split(',')
  if operands.length != inputs.length
    raise FormatError, "provides #{operands.length} operands for #{inputs.length} input labels strings"
  end

  # check labels against operands, recording each label's dimension and
  # how often it occurs across all input label strings
  inputs.zip(operands).each_with_index do |(input, operand), pos|
    input.each_char.with_index do |label, axis|
      unless (dim = dim(operand, axis))
        raise FormatError, "no axis in operand #{pos} corresponds to label #{label}"
      end

      known = labels[label]
      if known && known.dimension != dim
        raise FormatError, "inconsistent dimension for label #{label}: #{known.dimension} and #{dim}"
      end

      labels[label] ||= Label.new(dim)
      labels[label].increment
    end
  end

  # implicit mode: output every label occurring exactly once in the inputs,
  # alphabetically; when there is none, output stays '' (scalar result)
  output = labels.select { |_, info| info.count == 1 }.keys.sort.join if explicit.empty?

  external = output.chars

  # compute shape of the result, one dimension per output label
  shape = external.map do |label|
    raise FormatError, "output label #{label} not present in input labels" unless labels[label]

    labels[label].dimension
  end

  # a repeated output label would become two nested loops sharing one block
  # variable, silently accumulating into the diagonal several times over;
  # reject it, as NumPy does
  repeated = external.find { |label| external.count(label) > 1 }
  raise FormatError, "output label #{repeated} appears more than once" if repeated

  # template for result: a scalar 0, or +empty+ of the output shape
  # NOTE(review): +empty+ is defined elsewhere; presumably it builds nested
  # arrays filled with the given value - confirm
  result = shape.empty? ? 0 : empty(shape, 0)

  # generate code: loop over each output axis in output-label order, then
  # over the remaining (summed) labels, adding the product of the operands'
  # cells to the matching result cell
  internal = labels.keys.sort - external
  subscript = ->(names) { names.map { |name| "[#{name}]" }.join }

  code = (external + internal).map { |label| "#{labels[label].dimension}.times do |#{label}|" }
  code << "result#{subscript.call(external)} +="
  factors = inputs.each_with_index.map { |input, i| "operands[#{i}]#{subscript.call(input.chars)}" }
  code << factors.join(" *\n")
  code.concat(['end'] * (external.length + internal.length))

  # evaluate the generated code in this method's binding so it can read
  # +operands+ and update +result+ (including reassigning it in the scalar
  # case). Caller-influenced text in the code is limited to the labels, which
  # the format check restricts to single lowercase letters used only as block
  # parameters. The loop bounds come from +dim+ and are assumed to be
  # Integers (NOTE(review): confirm).
  # rubocop:disable Security/Eval
  binding.eval(code.join("\n"))
  # rubocop:enable Security/Eval

  result
end