Merge branch 'main' of github.com:ethteck/decomp.me into json-diff

This commit is contained in:
Zach Banks
2021-09-18 12:45:21 -04:00
29 changed files with 813 additions and 494 deletions
+14
View File
@@ -0,0 +1,14 @@
# This file wraps common django decorators in method_decorator for use with the APIView class
from django.utils.decorators import method_decorator
from rest_framework.response import Response
from typing import Optional, Callable
from datetime import datetime
def condition(etag_func: Optional[Callable[..., Optional[str]]] = None, last_modified_func: Optional[Callable[..., Optional[datetime]]] = None) -> Callable[..., Callable[..., Response]]:
"""
Handle Last-Modified and ETag headers.
"""
from django.views.decorators.http import condition
return method_decorator(condition(etag_func, last_modified_func))
+6 -3
View File
@@ -14,9 +14,12 @@ if TYPE_CHECKING:
class AnonymousUser(auth.models.AnonymousUser):
profile: Profile
class Request(DRFRequest):
user: Union[User, AnonymousUser]
profile: Profile
if TYPE_CHECKING:
class Request(DRFRequest):
user: Union[User, AnonymousUser]
profile: Profile
else:
Request = DRFRequest
def disable_csrf(get_response):
def middleware(request: HttpRequest):
+62
View File
@@ -3,7 +3,9 @@ from django.urls import reverse
from django.contrib.auth.models import User
from rest_framework import status
from rest_framework.test import APITestCase
import responses
from time import sleep
from .models import Compilation, Scratch, Profile
from .github import GitHubUser
@@ -224,3 +226,63 @@ class UserTests(APITestCase):
response = self.client.get(f"/api/scratch/{slug}")
self.assertEqual(response.json()["scratch"]["owner"]["username"], self.GITHUB_USER["login"])
self.assertEqual(response.json()["scratch"]["owner"]["is_you"], True)
class ScratchDetailTests(APITestCase):
def make_nop_scratch(self) -> Scratch:
response = self.client.post(reverse("scratch"), {
'arch': 'mips',
'context': '',
'target_asm': "jr $ra\nnop\n",
})
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
scratch = Scratch.objects.first()
assert scratch is not None # assert keyword instead of self.assertIsNotNone for mypy
return scratch
def test_404_head(self):
"""
Ensure that HEAD requests 404 correctly.
"""
response = self.client.head(reverse("scratch-detail", args=["doesnt_exist"]))
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_last_modified(self):
"""
Ensure that the Last-Modified header is set.
"""
scratch = self.make_nop_scratch()
response = self.client.head(reverse("scratch-detail", args=[scratch.slug]))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assert_(response.headers.get("Last-Modified") is not None)
def test_if_modified_since(self):
"""
Ensure that the If-Modified-Since header is handled.
"""
scratch = self.make_nop_scratch()
response = self.client.head(reverse("scratch-detail", args=[scratch.slug]))
self.assertEqual(response.status_code, status.HTTP_200_OK)
last_modified = response.headers.get("Last-Modified")
# should be unmodified
response = self.client.get(reverse("scratch-detail", args=[scratch.slug]), HTTP_IF_MODIFIED_SINCE=last_modified)
self.assertEqual(response.status_code, status.HTTP_304_NOT_MODIFIED)
# Last-Modified is only granular to the second
sleep(1)
# touch the scratch
old_last_updated = scratch.last_updated
scratch.slug = "newslug"
scratch.save()
self.assertNotEqual(scratch.last_updated, old_last_updated)
# should now be modified
response = self.client.get(reverse("scratch-detail", args=[scratch.slug]), HTTP_IF_MODIFIED_SINCE=last_modified)
self.assertEqual(response.status_code, status.HTTP_200_OK)
+2 -2
View File
@@ -4,8 +4,8 @@ from . import views
urlpatterns = [
path('compilers', views.compilers, name='compilers'),
path('scratch', views.scratch, name='scratch'), # TODO make this into its own view
path('scratch/<slug:slug>', views.scratch, name='scratch-detail'),
path('scratch', views.create_scratch, name='scratch'),
path('scratch/<slug:slug>', views.ScratchDetail.as_view(), name='scratch-detail'),
path('scratch/<slug:slug>/compile', views.compile, name='compile_scratch'),
path('scratch/<slug:slug>/fork', views.fork, name='fork_scratch'),
path('user', views.CurrentUser.as_view(), name="current-user"),
+118 -108
View File
@@ -9,16 +9,16 @@ from rest_framework import serializers, status
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.decorators import api_view
from typing import Any, Dict, Optional
import hashlib
import logging
from datetime import datetime
from typing import Any, Dict, Optional
from .models import Profile, Asm, Scratch
from .models import Profile, Asm, Scratch, gen_scratch_id
from .github import GitHubUser
from .middleware import Request
from coreapp.models import gen_scratch_id
from .decorators.django import condition
def get_db_asm(request_asm) -> Asm:
h = hashlib.sha256(request_asm.encode()).hexdigest()
@@ -34,125 +34,135 @@ def compilers(request):
"compiler_ids": CompilerWrapper.available_compilers(),
})
class ScratchDetail(APIView):
# type-ignored due to python/mypy#7778
def scratch_last_modified(request: Request, slug: str) -> Optional[datetime]: # type: ignore
scratch: Optional[Scratch] = Scratch.objects.filter(slug=slug).first()
if scratch:
return scratch.last_updated
else:
return None
@api_view(["GET", "POST", "PATCH"])
def scratch(request, slug=None):
"""
Get, create, or update a scratch
"""
scratch_condition = condition(last_modified_func=scratch_last_modified)
if request.method == "GET":
if not slug:
return Response("Missing slug", status=status.HTTP_400_BAD_REQUEST)
@scratch_condition
def head(self, request: Request, slug: str):
get_object_or_404(Scratch, slug=slug) # for 404
return Response()
db_scratch = get_object_or_404(Scratch, slug=slug)
@scratch_condition
def get(self, request: Request, slug: str):
scratch = get_object_or_404(Scratch, slug=slug)
if not db_scratch.owner:
if not scratch.owner and request.query_params.get("no_take_ownership") is None:
# Give ownership to this profile
profile = request.profile
logging.debug(f"Granting ownership of scratch {db_scratch} to {profile}")
logging.debug(f"Granting ownership of scratch {scratch} to {profile}")
db_scratch.owner = profile
db_scratch.save()
scratch.owner = profile
scratch.save()
return Response({
"scratch": ScratchWithMetadataSerializer(db_scratch, context={ "request": request }).data,
})
elif request.method == "POST":
if slug:
return Response({"error": "Not allowed to POST with slug"}, status=status.HTTP_400_BAD_REQUEST)
ser = ScratchCreateSerializer(data=request.data)
ser.is_valid(raise_exception=True)
data = ser.validated_data
arch = data.get("arch")
compiler = data.get("compiler", "")
if compiler:
arch = CompilerWrapper.arch_from_compiler(compiler)
if not arch:
raise serializers.ValidationError("Unknown compiler")
elif not arch:
raise serializers.ValidationError("arch not provided")
target_asm = data["target_asm"]
context = data["context"]
asm = get_db_asm(target_asm)
assembly, err = CompilerWrapper.assemble_asm(arch, asm)
if not assembly:
assert isinstance(err, str)
errors = []
for line in err.splitlines():
if "asm.s:" in line:
errors.append(line[line.find("asm.s:") + len("asm.s:") :].strip())
else:
errors.append(line)
return Response({
"error": "as_error",
"error_description": "Error when assembling target asm",
"as_errors": errors,
}, status=status.HTTP_400_BAD_REQUEST)
source_code = data.get("source_code")
if not source_code:
source_code = "void func() {}\n"
if arch == "mips":
source_code = M2CWrapper.decompile(asm.data, context) or source_code
cc_opts = data.get("compiler_flags", "")
if compiler and cc_opts:
cc_opts = CompilerWrapper.filter_cc_opts(compiler, cc_opts)
scratch_data = {
"slug": gen_scratch_id(),
"arch": arch,
"compiler": compiler,
"cc_opts": cc_opts,
"context": context,
"source_code": source_code,
"target_assembly": assembly.pk,
response = self.head(request, slug)
response.data = {
"scratch": ScratchWithMetadataSerializer(scratch, context={ "request": request }).data,
}
return response
serializer = ScratchSerializer(data=scratch_data)
serializer.is_valid(raise_exception=True)
serializer.save()
db_scratch = Scratch.objects.get(slug=scratch_data["slug"])
return Response({
"scratch": ScratchWithMetadataSerializer(db_scratch, context={ "request": request }).data,
}, status=status.HTTP_201_CREATED)
elif request.method == "PATCH":
if not slug:
return Response({"error": "Missing slug"}, status=status.HTTP_400_BAD_REQUEST)
def patch(self, request: Request, slug: str):
required_params = ["compiler", "cc_opts", "source_code", "context"]
for param in required_params:
if param not in request.data:
return Response({"error": f"Missing parameter: {param}"}, status=status.HTTP_400_BAD_REQUEST)
db_scratch = get_object_or_404(Scratch, slug=slug)
scratch = get_object_or_404(Scratch, slug=slug)
if db_scratch.owner and db_scratch.owner != request.profile:
return Response(status=status.HTTP_403_FORBIDDEN)
if scratch.owner and scratch.owner != request.profile:
response = self.get(request, slug)
response.status_code = status.HTTP_403_FORBIDDEN
return response
# TODO validate
db_scratch.compiler = request.data["compiler"]
db_scratch.cc_opts = request.data["cc_opts"]
db_scratch.source_code = request.data["source_code"]
db_scratch.context = request.data["context"]
db_scratch.save()
return Response(status=status.HTTP_202_ACCEPTED)
scratch.compiler = request.data["compiler"]
scratch.cc_opts = request.data["cc_opts"]
scratch.source_code = request.data["source_code"]
scratch.context = request.data["context"]
scratch.save()
return self.get(request, slug)
@api_view(["POST"])
def create_scratch(request):
"""
Create a scratch
"""
ser = ScratchCreateSerializer(data=request.data)
ser.is_valid(raise_exception=True)
data = ser.validated_data
arch = data.get("arch")
compiler = data.get("compiler", "")
if compiler:
arch = CompilerWrapper.arch_from_compiler(compiler)
if not arch:
raise serializers.ValidationError("Unknown compiler")
elif not arch:
raise serializers.ValidationError("arch not provided")
target_asm = data["target_asm"]
context = data["context"]
asm = get_db_asm(target_asm)
assembly, err = CompilerWrapper.assemble_asm(arch, asm)
if not assembly:
assert isinstance(err, str)
errors = []
for line in err.splitlines():
if "asm.s:" in line:
errors.append(line[line.find("asm.s:") + len("asm.s:") :].strip())
else:
errors.append(line)
return Response({
"error": "as_error",
"error_description": "Error when assembling target asm",
"as_errors": errors,
}, status=status.HTTP_400_BAD_REQUEST)
source_code = data.get("source_code")
if not source_code:
source_code = "void func() {}\n"
if arch == "mips":
source_code = M2CWrapper.decompile(asm.data, context) or source_code
cc_opts = data.get("compiler_flags", "")
if compiler and cc_opts:
cc_opts = CompilerWrapper.filter_cc_opts(compiler, cc_opts)
scratch_data = {
"slug": gen_scratch_id(),
"arch": arch,
"compiler": compiler,
"cc_opts": cc_opts,
"context": context,
"source_code": source_code,
"target_assembly": assembly.pk,
}
serializer = ScratchSerializer(data=scratch_data)
serializer.is_valid(raise_exception=True)
serializer.save()
db_scratch = Scratch.objects.get(slug=scratch_data["slug"])
return Response({
"scratch": ScratchWithMetadataSerializer(db_scratch, context={ "request": request }).data,
}, status=status.HTTP_201_CREATED)
@api_view(["POST"])
def compile(request, slug):
@@ -181,12 +191,12 @@ def compile(request, slug):
if compilation:
diff_output = AsmDifferWrapper.diff(scratch.target_assembly, compilation)
response_obj = {
"diff_output": diff_output,
"errors": errors,
}
return Response(response_obj)
return Response({
"compilation": {
"diff_output": diff_output,
"errors": errors,
},
})
@api_view(["POST"])
def fork(request, slug):