You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

754 lines
16 KiB

package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"log"
"os"
"regexp"
"strconv"
"strings"
)
// identation is the prefix printed before every line of the report.
// NOTE(review): misspelling of "indentation"; kept because main uses it.
const (
	identation = " "
)

// Package-level analysis state: filled in while parsing the command line and
// walking the ASTs, then read back by main to build the report.
var (
	allowedImp map[string]map[string]bool // Allowed imports: package (path or name used in the file) -> allowed function names, "*" for the whole package
	allowedFun map[string]bool            // Names callable without being flagged: allowed built-ins, --cast types, func-typed parameters
	// A slice is necessary to keep every call instance.
	callX []nodePos // Keeps the name of the called functions and the position in the file.
	// A map is enough for function declarations because they are unique.
	funcDeclPkg      map[string]*funcBody // Keeps the name of the function associated to its body and its position in the file.
	allArrayTypes    = true               // false once "--no-array" is given
	arraysInstances  []nodePos            // array/slice types (identifier element type) found in walked function bodies
	forStmts         []nodePos            // for statements found in walked function bodies
	basicLits        []nodePos            // char and string literals found in walked function bodies
	illegals         []illegal            // every violation found; printed by main
	notAllowedArrayT []string             // element types listed after "--no-array"; empty means every array type is forbidden
	// Types allowed as conversion functions when "--cast" is given.
	predeclaredTypes = []string{"bool", "byte", "complex64", "complex128",
		"error", "float32", "float64", "int", "int8",
		"int16", "int32", "int64", "rune", "string",
		"uint", "uint8", "uint16", "uint32", "uint64",
		"uintptr",
	}
	relativeImports []string            // names of the relative imports of the file under test
	importPkg       map[string]*pkgFunc // relative import name -> selectors used on it in the file
	pkgName         []string            // names of the packages parsed by getPkgFunc
	allImports      map[string]bool     // every import path seen, also keyed by its last path element
	openImports     []string            // relative imports already parsed by getPkgFunc
	funcOccurrences map[string]int      // occurrence count of each called identifier and each "X.Sel" selector
)
// pkgFunc describes one relatively-imported package of the file under test:
// the selector expressions used on it in that file and its import path.
type pkgFunc struct {
	functions []string // "pkg.Name" selectors seen in the analyzed file (see fileVisitor.Visit)
	path      string   // import path as written in the import spec
}
// funcImp pairs a package and a function name with a source position.
// NOTE(review): not referenced anywhere in the visible source; possibly dead code.
type funcImp struct {
	pkg, fun string
	pos      token.Pos
}
// All visitors. Each one carries the FileSet of the code it walks so that it
// can report node positions.

// callVisitor walks one function body, collecting the functions it calls and
// recording literals, array types, for loops and illegal package accesses in
// the package-level state.
type callVisitor struct {
	Calls []string // called names: plain identifiers, or "pkg.Func" when registered in funcDeclPkg
	Fset  *token.FileSet
}

// fileVisitor walks the file under test.
type fileVisitor struct {
	funcDecl   []string  // names of the function declarations of the file
	funcCalls  []string  // names of functions called through a plain identifier
	selectExpr []string  // "X.Sel" selector expressions whose X is an identifier
	arrayType  []nodePos // array/slice types whose element type is an identifier
	Fset       *token.FileSet
}

// pkgVisitor walks the packages of the directory (and of its relative
// imports), registering function bodies in funcDeclPkg.
type pkgVisitor struct {
	Fset *token.FileSet
}

// impVisitor records the imports of the walked packages.
type impVisitor struct {
	Fset            *token.FileSet
	relativeImports []string // relative import paths, followed recursively by getPkgFunc
}
// locate is implemented by every visitor: getPos formats the position of a
// node using the visitor's FileSet.
// NOTE(review): nothing in the visible source uses locate as an interface type.
type locate interface {
	getPos(ast.Node) string
}
// getPos renders the position of n according to the visitor's FileSet.
func (i *impVisitor) getPos(n ast.Node) string {
	where := i.Fset.Position(n.Pos())
	return where.String()
}

// getPos renders the position of n according to the visitor's FileSet.
func (fv *fileVisitor) getPos(n ast.Node) string {
	where := fv.Fset.Position(n.Pos())
	return where.String()
}

// getPos renders the position of n according to the visitor's FileSet.
func (p *pkgVisitor) getPos(n ast.Node) string {
	where := p.Fset.Position(n.Pos())
	return where.String()
}

// getPos renders the position of n according to the visitor's FileSet.
func (c *callVisitor) getPos(n ast.Node) string {
	where := c.Fset.Position(n.Pos())
	return where.String()
}
// illegal is one entry of the cheating report: the kind of violation, the
// offending name and where it was found.
type illegal struct {
	T    string
	Name string
	Pos  string
}

// String formats the entry as "<type> <name> <position>".
func (i *illegal) String() string {
	return strings.Join([]string{i.T, i.Name, i.Pos}, " ")
}
// getPkgFunc parses every package found in the directory path and walks it:
// impVisitor records the imports (allImports and relative imports) and
// pkgVisitor registers function declarations in funcDeclPkg. Each relative
// import not opened yet is then parsed recursively, relative to path.
// A parse error is fatal: it is printed and the program exits with status 1.
func getPkgFunc(path string, fsetPkg *token.FileSet) {
	i := &impVisitor{Fset: fsetPkg}
	p := &pkgVisitor{Fset: fsetPkg}
	pkgs, err := parser.ParseDir(fsetPkg, path, nil, parser.AllErrors)
	if err != nil {
		fmt.Println(err.Error())
		os.Exit(1)
	}
	for pkgname, astPkg := range pkgs {
		// With a nil importer and universe, NewPackage reports unresolved
		// identifiers (even built-ins) as errors; the package node is still
		// usable for walking, so the error is deliberately ignored.
		pkg, _ := ast.NewPackage(fsetPkg, astPkg.Files, nil, nil)
		pkgName = append(pkgName, pkgname)
		ast.Walk(i, pkg)
		ast.Walk(p, pkg)
		// i is shared by every package of the directory, so relativeImports
		// also holds the imports seen in previous iterations.
		for _, v := range i.relativeImports {
			if isIn(v, openImports) {
				// Already parsed (this also stops import cycles). Keep going:
				// the remaining imports may still be new.
				// NOTE(review): openImports holds paths relative to different
				// directories, so two distinct packages can share a key.
				continue
			}
			openImports = append(openImports, v)
			getPkgFunc(path+"/"+v, fsetPkg)
		}
	}
}
// splitArgs splits a command-line argument on single spaces. Consecutive
// spaces produce empty elements, and an empty input yields [""].
func splitArgs(args string) []string {
	return strings.Split(args, " ")
}
// rightFile returns the first space-separated word of args that contains
// ".go", or "" when there is none.
func rightFile(args string) string {
	words := splitArgs(args)
	for idx := 0; idx < len(words); idx++ {
		if word := words[idx]; strings.Contains(word, ".go") {
			return word
		}
	}
	return ""
}
// allowCastingAndImp registers each word of allowedImports as allowed (see
// allow); "--cast" additionally allows conversions to the predeclared types.
// The words after "--no-array" are not registered: they become the element
// types forbidden in array/slice types, and allArrayTypes is switched off.
func allowCastingAndImp(allowedImports []string) {
	casted := false
	for idx, word := range allowedImports {
		casted = allow(word, casted)
		if word != "--no-array" {
			continue
		}
		allArrayTypes = false
		notAllowedArrayT = append(notAllowedArrayT, allowedImports[idx+1:]...)
		return
	}
}
// flags holds the options given as "--name=value" words in the allowed list.
type flags struct {
	l struct { // flag for char or string literal
		noLit   bool   // true -> literals matching pattern are forbidden
		pattern string // regular expression that a forbidden literal matches
	}
}

// parseFlags extracts the "--name=value" options from args; words without
// "=" are ignored. Only the first "=" separates the name from the value, so
// a pattern may itself contain "=".
// TODO: treat all the flags in this function.
// For now, only --no-lit="{PATTERN}"
func parseFlags(args []string) *flags {
	f := &flags{}
	for _, arg := range args {
		if !strings.Contains(arg, "=") {
			continue
		}
		kv := strings.SplitN(arg, "=", 2)
		if kv[0] == "--no-lit" {
			f.l.noLit = true
			f.l.pattern = kv[1]
		}
	}
	return f
}
// removeAmount strips a trailing "#N" amount suffix (N made of decimal
// digits, possibly empty) from an allowed-list entry: "fmt.Println#3" ->
// "fmt.Println". Only that suffix is removed, so names that themselves end
// in digits survive intact ("math.Log10", "sha256.Sum256"). Anything else
// is returned unchanged.
func removeAmount(s string) string {
	hash := strings.LastIndexByte(s, '#')
	if hash < 0 {
		return s
	}
	for _, c := range s[hash+1:] {
		if c < '0' || c > '9' {
			return s
		}
	}
	return s[:hash]
}
// allowedAmount enforces the "name#N" entries of allowedImports: when
// occurrences counts more than N uses of name, an "illegal-amount" entry is
// added to illegals. Entries without "#" are skipped; a non-numeric N makes
// the program panic through log.Panic.
func allowedAmount(occurrences map[string]int, allowedImports []string) {
	for _, entry := range allowedImports {
		// parts is non-nil for a "pkg.Func" selector; otherwise pkg holds the
		// entry itself (a built-in function).
		pkg, parts := trimRelativeImport(entry)
		function := pkg
		if parts != nil {
			function = strings.Join(parts, ".")
		}
		if !strings.ContainsRune(function, '#') {
			continue
		}
		name := removeAmount(function)
		limitText := strings.TrimPrefix(function, name+"#")
		limit, err := strconv.Atoi(limitText)
		if err != nil {
			log.Panic(err)
		}
		if used := occurrences[name]; used > limit {
			illegals = append(illegals, illegal{
				T:    "illegal-amount",
				Name: name + " allowed count " + limitText + " your count " + strconv.Itoa(used),
			})
		}
	}
}
// main is the entry point. os.Args[1] names the Go file to check (the first
// space-separated word containing ".go"); the optional os.Args[2] is a
// space-separated list of allowed functions/imports and options ("--cast",
// "--no-array", "--no-for", "--no-lit=PATTERN", "name#N").
// It prints a "Parsing" section and a "Cheating" section, and exits with
// status 1 when any illegal construct was found.
func main() {
	if len(os.Args) < 2 {
		fmt.Println("No file or directory")
		return
	}
	var allowedImports []string
	if len(os.Args) > 2 {
		allowedImports = splitArgs(os.Args[2])
	}
	// Register the allowed functions/imports and the --cast/--no-array options.
	allowCastingAndImp(allowedImports)
	flag := parseFlags(allowedImports)
	filename := strings.TrimSpace(rightFile(os.Args[1]))
	// The directory containing the file is parsed as a whole so that calls to
	// functions declared in sibling files can be followed.
	split := strings.Split(filename, "/")
	path := strings.Join(split[:len(split)-1], "/")
	if path == "" {
		path = "."
	}
	fsetFile := token.NewFileSet()
	fsetPkg := token.NewFileSet()
	fmt.Println("Parsing")
	file, err := parser.ParseFile(fsetFile, filename, nil, parser.AllErrors)
	if err != nil {
		fmt.Println(err.Error())
		os.Exit(1)
	}
	// Collect the names of all functions declared in the file. The walk also
	// checks the imports and fills importPkg, which pkgVisitor (run by
	// getPkgFunc) reads, so it comes first.
	w := &fileVisitor{Fset: fsetFile}
	ast.Walk(w, file)
	getPkgFunc(path, fsetPkg)
	// Check every function of the file, following its calls recursively.
	for _, v := range w.funcDecl {
		isFuncAllowed(v, fsetPkg)
	}
	// TODO: Parsing the arguments for the --max-occurrences flag
	allowedAmount(funcOccurrences, allowedImports)
	// NOTE(review): parseFlags never returns nil, so this check always passes.
	if flag != nil {
		flag.unallowLits()
	}
	analyzeArrayT()
	analyzeForStmt(allowedImports)
	// Parsing succeeded (a parse error exits above).
	fmt.Println(identation + "OK")
	fmt.Println("Cheating")
	if len(illegals) > 0 {
		for _, i := range illegals {
			fmt.Println(identation + i.String())
		}
		os.Exit(1)
	} else {
		fmt.Println(identation + "OK")
	}
}
// unallowLits reports, when --no-lit was given, every recorded char or
// string literal that matches the pattern as an "illegal-literal".
func (f flags) unallowLits() {
	if !f.l.noLit {
		return
	}
	for idx := range basicLits {
		lit := basicLits[idx]
		if f.isLitAllowed(lit.name) {
			continue
		}
		illegals = append(illegals, illegal{
			T:    "illegal-literal",
			Name: lit.name,
			Pos:  lit.position,
		})
	}
}
// isLitAllowed reports whether literal s is allowed, i.e. does not match the
// --no-lit pattern. An invalid pattern allows every literal.
func (f flags) isLitAllowed(s string) bool {
	if matched, err := regexp.MatchString(f.l.pattern, s); err == nil {
		return !matched
	}
	return true
}
// analyzeForStmt reports every recorded for statement as an "illegal-loop"
// when "--no-for" is one of args.
func analyzeForStmt(args []string) {
	if !isIn("--no-for", args) {
		return
	}
	for idx := range forStmts {
		loop := forStmts[idx]
		illegals = append(illegals, illegal{T: "illegal-loop", Name: loop.name, Pos: loop.position})
	}
}
// analyzeArrayT reports array/slice types as "illegal-array-type" when
// --no-array was given. Every recorded one is reported if no element types
// followed the flag; otherwise only those whose element type is listed.
func analyzeArrayT() {
	if allArrayTypes {
		return
	}
	banAll := len(notAllowedArrayT) == 0
	for _, arr := range arraysInstances {
		if !banAll && !isIn(arr.name, notAllowedArrayT) {
			continue
		}
		illegals = append(illegals, illegal{T: "illegal-array-type", Name: arr.name, Pos: arr.position})
	}
}
func allow(s string, casted bool) bool {
if strings.ContainsRune(s, '.') {
allowImport(s)
} else {
if allowedFun == nil {
allowedFun = make(map[string]bool)
}
if strings.ContainsAny(s, "#") {
s = removeAmount(s)
}
allowedFun[s] = true
}
if s == "--cast" && !casted {
for _, v := range predeclaredTypes {
allow(v, false)
}
return true
}
return casted
}
// isRelativeImport reports whether s looks like a relative import path, that
// is, whether it contains ".." (e.g. "../piscine"). A plain substring test is
// used instead of compiling a regular expression on every call.
func isRelativeImport(s string) bool {
	return strings.Contains(s, "..")
}
// importExprRe matches a dot with at least one (non-newline) character on each
// side. It is compiled once instead of on every isImport call.
var importExprRe = regexp.MustCompile(`.\..`)

// isImport reports whether s looks like an import expression, i.e. an
// expression like <lib>.<function>.
func isImport(s string) bool {
	return importExprRe.MatchString(s)
}
// trimRelativeImport splits an allowed-list entry. For an import expression
// that is not relative (see isImport and isRelativeImport), e.g.
// "math/rand.Intn", it returns the package path without the function
// ("math/rand") and the last path element split on dots (["rand", "Intn"]).
// It also marks that last dotted part as allowed in allowedImp under the
// first one. Any other entry is returned unchanged with a nil slice.
func trimRelativeImport(str string) (string, []string) {
	if !isImport(str) || isRelativeImport(str) {
		return str, nil
	}
	segments := strings.Split(str, "/")
	slice := strings.Split(segments[len(segments)-1], ".")
	segments[len(segments)-1] = slice[0]
	pkg := strings.Join(segments, "/")
	// allowedImp is created lazily by allowImport, but allowedAmount also
	// calls this for entries that never went through allowImport (those after
	// "--no-array"); writing into a nil map would panic.
	if allowedImp == nil {
		allowedImp = make(map[string]map[string]bool)
	}
	if allowedImp[slice[0]] == nil {
		allowedImp[slice[0]] = make(map[string]bool)
	}
	fn := slice[len(slice)-1]
	allowedImp[slice[0]][fn] = true
	return pkg, slice
}
// allowImport records an allowed import entry containing a dot: a bare
// package allows all of it (function "*"), while "pkg.Func" or "pkg.Func#N"
// allows only Func. The package path is also added to allImports.
func allowImport(s string) {
	if allowedImp == nil {
		allowedImp = make(map[string]map[string]bool)
	}
	if allImports == nil {
		allImports = make(map[string]bool)
	}
	pkg, parts := trimRelativeImport(s)
	allImports[pkg] = true
	fn := "*" // the whole package unless a function is named
	if len(parts) >= 2 {
		fn = removeAmount(parts[1])
	}
	funcs := allowedImp[pkg]
	if funcs == nil {
		funcs = make(map[string]bool)
		allowedImp[pkg] = funcs
	}
	funcs[fn] = true
}
// addToIllegals flags funcname as an "illegal-call": every recorded call to
// it and, when known, its declaration. setIllegal skips positions that were
// already reported.
func addToIllegals(funcname string) {
	for idx := range callX {
		if call := callX[idx]; call.name == funcname {
			setIllegal("illegal-call", call.name, call.position)
		}
	}
	if decl := funcDeclPkg[funcname]; decl != nil {
		setIllegal("illegal-call", funcname, decl.position)
	}
}
// Memoization state for isFuncAllowed.
var (
	// funcVerdicts caches the verdict of every function already analyzed, so
	// its body is walked only once. Walking has side effects (it records
	// calls, literals, array types, for loops and occurrence counts), and
	// walking a body several times duplicated report entries and inflated the
	// counts checked by allowedAmount.
	funcVerdicts = make(map[string]bool)
	// funcInProgress holds the functions currently being analyzed, so mutual
	// recursion (f calls g, g calls f) terminates.
	funcInProgress = make(map[string]bool)
)

// isFuncAllowed reports whether funcname may be used. It may be used if it is
// explicitly allowed, or if it is declared in the analyzed packages and every
// function it calls is itself allowed (checked recursively). Imported
// functions are not followed here: callVisitor checks package accesses.
// A disallowed function has its calls and declaration added to illegals.
// Inside a recursion cycle, a member analyzed before the cycle closes may be
// judged allowed even if another member is not. The calls to the disallowed
// member are still reported.
func isFuncAllowed(funcname string, fset *token.FileSet) bool {
	if allowedFun[funcname] {
		return true
	}
	if verdict, done := funcVerdicts[funcname]; done {
		if !verdict {
			// Calls recorded since the first analysis must be reported too.
			addToIllegals(funcname)
		}
		return verdict
	}
	if funcInProgress[funcname] {
		// Already being analyzed higher up the call chain; that frame
		// decides the verdict.
		return true
	}
	funcInProgress[funcname] = true
	verdict := analyzeFuncBody(funcname, fset)
	delete(funcInProgress, funcname)
	funcVerdicts[funcname] = verdict
	return verdict
}

// analyzeFuncBody walks the body of funcname once and checks every function it calls.
func analyzeFuncBody(funcname string, fset *token.FileSet) bool {
	if !isFuncDeclIn(funcname, funcDeclPkg) {
		addToIllegals(funcname)
		return false
	}
	bodyf := funcDeclPkg[funcname].body
	if bodyf == nil {
		addToIllegals(funcname)
		fmt.Println("Body is nil")
		return false
	}
	c := &callVisitor{Fset: fset}
	ast.Walk(c, bodyf)
	res := true
	for _, v := range c.Calls {
		if v == funcname { // direct recursion
			continue
		}
		if !isFuncAllowed(v, fset) {
			addToIllegals(funcname)
			res = false
		}
	}
	return res
}
// isFuncDeclIn reports whether fundecl holds a non-nil entry for funcname.
func isFuncDeclIn(funcname string, fundecl map[string]*funcBody) bool {
	body := fundecl[funcname]
	return body != nil
}
// isFuncCallIn reports whether some element of funP is named funcname.
// NOTE(review): not called anywhere in the visible source.
func isFuncCallIn(funcname string, funP []nodePos) bool {
	found := false
	for i := 0; i < len(funP) && !found; i++ {
		found = funP[i].name == funcname
	}
	return found
}
// isIn reports whether slc contains s.
func isIn(s string, slc []string) bool {
	for i := range slc {
		if slc[i] == s {
			return true
		}
	}
	return false
}
// funcBody keeps the body and position of a declared function, or of a
// function literal assigned to an identifier.
type funcBody struct {
	position string         // position string produced by a visitor's getPos
	body     *ast.BlockStmt // may be nil (isFuncAllowed reports "Body is nil")
}

// nodePos pairs a name with the position where it was found.
type nodePos struct {
	position string
	name     string
}
// Visit records, for the file under test:
//   - its imports (through checkImport);
//   - calls through a plain identifier;
//   - "X.Sel" selector expressions (attached to importPkg and counted in
//     funcOccurrences);
//   - array element types and declared function names;
//   - func-typed parameters, which become allowed functions;
//   - function literals assigned to an identifier, registered in funcDeclPkg.
func (fv *fileVisitor) Visit(n ast.Node) ast.Visitor {
	fv.checkImport(n)
	switch node := n.(type) {
	case *ast.CallExpr:
		if callee, ok := node.Fun.(*ast.Ident); ok {
			fv.funcCalls = append(fv.funcCalls, callee.Name)
			callX = append(callX, nodePos{position: fv.getPos(n), name: callee.Name})
		}
	case *ast.SelectorExpr:
		ident, ok := node.X.(*ast.Ident)
		if !ok {
			break
		}
		qualified := ident.Name + "." + node.Sel.Name
		fv.selectExpr = append(fv.selectExpr, qualified)
		// saves the function into the map of its package, if X is one
		if imp := importPkg[ident.Name]; imp != nil {
			imp.functions = append(imp.functions, qualified)
		}
		if funcOccurrences == nil {
			funcOccurrences = make(map[string]int)
		}
		funcOccurrences[qualified]++
	case *ast.ArrayType:
		if elt, ok := node.Elt.(*ast.Ident); ok {
			fv.arrayType = append(fv.arrayType, nodePos{name: elt.Name, position: fv.getPos(n)})
		}
	case *ast.FuncDecl:
		fv.funcDecl = append(fv.funcDecl, node.Name.Name)
		for _, field := range node.Type.Params.List {
			if _, isFunc := field.Type.(*ast.FuncType); !isFunc {
				continue
			}
			if allowedFun == nil {
				allowedFun = make(map[string]bool)
			}
			for _, param := range field.Names {
				allowedFun[param.Name] = true
			}
		}
	case *ast.AssignStmt:
		lit, ok := node.Rhs[0].(*ast.FuncLit)
		if !ok {
			break
		}
		if target, ok := node.Lhs[0].(*ast.Ident); ok {
			if funcDeclPkg == nil {
				funcDeclPkg = make(map[string]*funcBody)
			}
			funcDeclPkg[target.Name] = &funcBody{body: lit.Body, position: fv.getPos(n)}
		}
	}
	return fv
}
// positionIsIn reports whether some entry of illegals was found at pos.
func positionIsIn(illegals []illegal, pos string) bool {
	for i := range illegals {
		if illegals[i].Pos == pos {
			return true
		}
	}
	return false
}
// Visit registers function bodies in funcDeclPkg: every function declaration
// by its name (and also as "pkg.Name" when the file under test uses it
// through a relatively-imported package), and every function literal
// assigned to an identifier.
func (p *pkgVisitor) Visit(n ast.Node) ast.Visitor {
	switch node := n.(type) {
	case *ast.FuncDecl:
		if funcDeclPkg == nil {
			funcDeclPkg = make(map[string]*funcBody)
		}
		name := node.Name.Name
		funcDeclPkg[name] = &funcBody{body: node.Body, position: p.getPos(n)}
		for _, pkg := range pkgName {
			qualified := pkg + "." + name
			if imp := importPkg[pkg]; imp != nil && isIn(qualified, imp.functions) {
				funcDeclPkg[qualified] = &funcBody{body: node.Body, position: p.getPos(n)}
			}
		}
	case *ast.AssignStmt:
		lit, ok := node.Rhs[0].(*ast.FuncLit)
		if !ok {
			break
		}
		if target, ok := node.Lhs[0].(*ast.Ident); ok {
			if funcDeclPkg == nil {
				funcDeclPkg = make(map[string]*funcBody)
			}
			funcDeclPkg[target.Name] = &funcBody{body: lit.Body, position: p.getPos(n)}
		}
	}
	return p
}
// setIllegal appends an entry to illegals unless one was already recorded at
// the same position.
func setIllegal(illegalType, funcName, pos string) {
	if positionIsIn(illegals, pos) {
		return
	}
	illegals = append(illegals, illegal{T: illegalType, Name: funcName, Pos: pos})
}
// Visit analyzes one node of a function body. It records:
//   - char and string literals (any other literal stops the descent);
//   - calls through a plain identifier (in Calls, callX and funcOccurrences);
//   - array element types and for statements.
//
// It also checks selector expressions "X.Sel". A "pkg.Func" registered in
// funcDeclPkg is followed as a call. An access to an imported package that is
// not allowed is reported as "illegal-access".
func (c *callVisitor) Visit(n ast.Node) ast.Visitor {
	switch node := n.(type) {
	case *ast.BasicLit:
		if node.Kind != token.CHAR && node.Kind != token.STRING {
			return nil
		}
		basicLits = append(basicLits, nodePos{position: c.getPos(n), name: node.Value})
	case *ast.CallExpr:
		callee, ok := node.Fun.(*ast.Ident)
		if !ok {
			break
		}
		c.Calls = append(c.Calls, callee.Name)
		callX = append(callX, nodePos{position: c.getPos(n), name: callee.Name})
		if funcOccurrences == nil {
			funcOccurrences = make(map[string]int)
		}
		funcOccurrences[callee.Name]++
	case *ast.SelectorExpr:
		// A selector is an access through the dot operator; those are checked
		// for specific package functions.
		ident, ok := node.X.(*ast.Ident)
		if !ok {
			// Deep in an access: if fmt is banned but pouet isn't, pouet.fmt.x
			// must be allowed while fmt.x is not. This is the pouet.fmt.x case.
			break
		}
		qualified := ident.Name + "." + node.Sel.Name
		if funcDeclPkg[qualified] != nil {
			c.Calls = append(c.Calls, qualified)
			break
		}
		allowed := allowedImp[ident.Name]
		switch {
		case allowed == nil:
			if allImports[ident.Name] {
				setIllegal("illegal-access", qualified, c.getPos(n))
			}
		case !allowed["*"] && !allowed[node.Sel.Name]:
			// the package is not whitelisted as a whole and the name is not explicitly allowed
			setIllegal("illegal-access", qualified, c.getPos(n))
		}
	case *ast.ArrayType:
		if elt, ok := node.Elt.(*ast.Ident); ok {
			arraysInstances = append(arraysInstances, nodePos{name: elt.Name, position: c.getPos(n)})
		}
	case *ast.ForStmt:
		forStmts = append(forStmts, nodePos{name: "for", position: c.getPos(n)})
	}
	return c
}
// checkImport handles an import spec of the file under test. An import
// absent from allowedImp is reported as "illegal-import". Otherwise the
// allowed entry is moved from the import path to the name the file uses for
// it: an explicit alias, or the last element of a path containing "/".
// Relative imports are also registered in importPkg and relativeImports.
// It always returns fv.
func (fv *fileVisitor) checkImport(n ast.Node) ast.Visitor {
	spec, ok := n.(*ast.ImportSpec)
	if !ok {
		return fv
	}
	pkg := spec.Path.Value[1 : len(spec.Path.Value)-1] // strip the quotes
	if allowedImp[pkg] == nil {
		setIllegal("illegal-import", pkg, fv.getPos(n))
		return fv
	}
	// if the import is named, we need to move it to the new name
	name := ""
	if spec.Name != nil {
		name = spec.Name.Name
	} else if strings.ContainsRune(pkg, '/') {
		parts := strings.Split(pkg, "/")
		name = parts[len(parts)-1]
	}
	// Moving an entry onto its own key (e.g. `import fmt "fmt"`) would erase
	// it and turn every use of an allowed package into an illegal access.
	if name != "" && name != pkg {
		allowedImp[name] = allowedImp[pkg]
		allowedImp[pkg] = nil
	}
	if isRelativeImport(pkg) {
		if importPkg == nil {
			importPkg = make(map[string]*pkgFunc)
		}
		if name != "" {
			importPkg[name] = &pkgFunc{
				path: pkg,
			}
		}
		relativeImports = append(relativeImports, name)
	}
	return fv
}
// Visit records each import path of the walked package in allImports, both
// in full and by its last path element, and keeps the relative ones in
// i.relativeImports.
func (i *impVisitor) Visit(n ast.Node) ast.Visitor {
	spec, ok := n.(*ast.ImportSpec)
	if !ok {
		return i
	}
	quoted := spec.Path.Value
	pkg := quoted[1 : len(quoted)-1]
	if allImports == nil {
		allImports = make(map[string]bool)
	}
	allImports[pkg] = true
	// Text after the last "/" (the whole path when there is none).
	allImports[pkg[strings.LastIndex(pkg, "/")+1:]] = true
	if isRelativeImport(pkg) {
		i.relativeImports = append(i.relativeImports, pkg)
	}
	return i
}